Bug Summary

File:build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Warning:line 4209, column 24
2nd function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SPIRVOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16 -I tools/mlir/lib/Dialect/SPIRV/IR -I /build/source/mlir/lib/Dialect/SPIRV/IR -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem 
/usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1671487667 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-12-20-010714-16201-1 -x c++ /build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/FunctionImplementation.h"
25#include "mlir/IR/OpDefinition.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/Operation.h"
28#include "mlir/IR/TypeUtilities.h"
29#include "mlir/Interfaces/CallInterfaces.h"
30#include "llvm/ADT/APFloat.h"
31#include "llvm/ADT/APInt.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/Support/FormatVariadic.h"
36#include <cassert>
37#include <numeric>
38
39using namespace mlir;
40
// TODO: generate these strings using ODS.
//
// Attribute and keyword spellings shared by the custom parsers, printers and
// verifiers in this file. Namespace-scope constexpr arrays are implicitly
// const and therefore have internal linkage: each name is private to this
// translation unit.

// Memory access related names (used by the memory access helpers below).
constexpr char kAlignmentAttrName[] = "alignment";
constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kSourceAlignmentAttrName[] = "source_alignment";
constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";

// Keyword introducing an explicit control clause (see parseControlAttribute).
constexpr char kControl[] = "control";

// Remaining attribute names, in alphabetical order.
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kClusterSize[] = "cluster_size";
constexpr char kCompositeSpecConstituentsName[] = "constituents";
constexpr char kDefaultValueAttrName[] = "default_value";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kPackedVectorFormatAttrName[] = "format";
constexpr char kSemanticsAttrName[] = "semantics";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
constexpr char kValueAttrName[] = "value";
constexpr char kValuesAttrName[] = "values";
67
68//===----------------------------------------------------------------------===//
69// Common utility functions
70//===----------------------------------------------------------------------===//
71
72static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
73 OperationState &result) {
74 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
75 Type type;
76 // If the operand list is in-between parentheses, then we have a generic form.
77 // (see the fallback in `printOneResultOp`).
78 SMLoc loc = parser.getCurrentLocation();
79 if (!parser.parseOptionalLParen()) {
80 if (parser.parseOperandList(ops) || parser.parseRParen() ||
81 parser.parseOptionalAttrDict(result.attributes) ||
82 parser.parseColon() || parser.parseType(type))
83 return failure();
84 auto fnType = type.dyn_cast<FunctionType>();
85 if (!fnType) {
86 parser.emitError(loc, "expected function type");
87 return failure();
88 }
89 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
90 return failure();
91 result.addTypes(fnType.getResults());
92 return success();
93 }
94 return failure(parser.parseOperandList(ops) ||
95 parser.parseOptionalAttrDict(result.attributes) ||
96 parser.parseColonType(type) ||
97 parser.resolveOperands(ops, type, result.operands) ||
98 parser.addTypeToList(type, result.types));
99}
100
101static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
102 assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 &&
"op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 102, __extension__
__PRETTY_FUNCTION__))
;
103
104 // If not all the operand and result types are the same, just use the
105 // generic assembly form to avoid omitting information in printing.
106 auto resultType = op->getResult(0).getType();
107 if (llvm::any_of(op->getOperandTypes(),
108 [&](Type type) { return type != resultType; })) {
109 p.printGenericOp(op, /*printOpName=*/false);
110 return;
111 }
112
113 p << ' ';
114 p.printOperands(op->getOperands());
115 p.printOptionalAttrDict(op->getAttrs());
116 // Now we can output only one type for all operands and the result.
117 p << " : " << resultType;
118}
119
120/// Returns true if the given op is a function-like op or nested in a
121/// function-like op without a module-like op in the middle.
122static bool isNestedInFunctionOpInterface(Operation *op) {
123 if (!op)
124 return false;
125 if (op->hasTrait<OpTrait::SymbolTable>())
126 return false;
127 if (isa<FunctionOpInterface>(op))
128 return true;
129 return isNestedInFunctionOpInterface(op->getParentOp());
130}
131
132/// Returns true if the given op is an module-like op that maintains a symbol
133/// table.
134static bool isDirectInModuleLikeOp(Operation *op) {
135 return op && op->hasTrait<OpTrait::SymbolTable>();
136}
137
138static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
139 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
140 if (!constOp) {
141 return failure();
142 }
143 auto valueAttr = constOp.getValue();
144 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
145 if (!integerValueAttr) {
146 return failure();
147 }
148
149 if (integerValueAttr.getType().isSignlessInteger())
150 value = integerValueAttr.getInt();
151 else
152 value = integerValueAttr.getSInt();
153
154 return success();
155}
156
157template <typename Ty>
158static ArrayAttr
159getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
160 function_ref<StringRef(Ty)> stringifyFn) {
161 if (enumValues.empty()) {
162 return nullptr;
163 }
164 SmallVector<StringRef, 1> enumValStrs;
165 enumValStrs.reserve(enumValues.size());
166 for (auto val : enumValues) {
167 enumValStrs.emplace_back(stringifyFn(val));
168 }
169 return builder.getStrArrayAttr(enumValStrs);
170}
171
172/// Parses the next string attribute in `parser` as an enumerant of the given
173/// `EnumClass`.
174template <typename EnumClass>
175static ParseResult
176parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
177 StringRef attrName = spirv::attributeName<EnumClass>()) {
178 Attribute attrVal;
179 NamedAttrList attr;
180 auto loc = parser.getCurrentLocation();
181 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
3
Taking false branch
182 attrName, attr))
183 return failure();
184 if (!attrVal.isa<StringAttr>())
4
Taking true branch
185 return parser.emitError(loc, "expected ")
5
Returning without writing to 'value'
186 << attrName << " attribute specified as string";
187 auto attrOptional =
188 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
189 if (!attrOptional)
190 return parser.emitError(loc, "invalid ")
191 << attrName << " attribute specification: " << attrVal;
192 value = *attrOptional;
193 return success();
194}
195
196/// Parses the next string attribute in `parser` as an enumerant of the given
197/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
198/// attribute with the enum class's name as attribute name.
199template <typename EnumAttrClass,
200 typename EnumClass = typename EnumAttrClass::ValueType>
201static ParseResult
202parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
203 StringRef attrName = spirv::attributeName<EnumClass>()) {
204 if (parseEnumStrAttr(value, parser))
205 return failure();
206 state.addAttribute(attrName,
207 parser.getBuilder().getAttr<EnumAttrClass>(value));
208 return success();
209}
210
211/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
212/// and inserts the enumerant into `state` as an 32-bit integer attribute with
213/// the enum class's name as attribute name.
214template <typename EnumAttrClass,
215 typename EnumClass = typename EnumAttrClass::ValueType>
216static ParseResult
217parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
218 OperationState &state,
219 StringRef attrName = spirv::attributeName<EnumClass>()) {
220 if (parseEnumKeywordAttr(value, parser))
221 return failure();
222 state.addAttribute(attrName,
223 parser.getBuilder().getAttr<EnumAttrClass>(value));
224 return success();
225}
226
227/// Parses Function, Selection and Loop control attributes. If no control is
228/// specified, "None" is used as a default.
229template <typename EnumAttrClass, typename EnumClass>
230static ParseResult
231parseControlAttribute(OpAsmParser &parser, OperationState &state,
232 StringRef attrName = spirv::attributeName<EnumClass>()) {
233 if (succeeded(parser.parseOptionalKeyword(kControl))) {
234 EnumClass control;
235 if (parser.parseLParen() ||
236 parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
237 parser.parseRParen())
238 return failure();
239 return success();
240 }
241 // Set control to "None" otherwise.
242 Builder builder = parser.getBuilder();
243 state.addAttribute(attrName,
244 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
245 return success();
246}
247
248/// Parses optional memory access attributes attached to a memory access
249/// operand/pointer. Specifically, parses the following syntax:
250/// (`[` memory-access `]`)?
251/// where:
252/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
253/// integer-literal | `"NonTemporal"`
254static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
255 OperationState &state) {
256 // Parse an optional list of attributes staring with '['
257 if (parser.parseOptionalLSquare()) {
258 // Nothing to do
259 return success();
260 }
261
262 spirv::MemoryAccess memoryAccessAttr;
263 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
264 kMemoryAccessAttrName))
265 return failure();
266
267 if (spirv::bitEnumContainsAll(memoryAccessAttr,
268 spirv::MemoryAccess::Aligned)) {
269 // Parse integer attribute for alignment.
270 Attribute alignmentAttr;
271 Type i32Type = parser.getBuilder().getIntegerType(32);
272 if (parser.parseComma() ||
273 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
274 state.attributes)) {
275 return failure();
276 }
277 }
278 return parser.parseRSquare();
279}
280
281// TODO Make sure to merge this and the previous function into one template
282// parameterized by memory access attribute name and alignment. Doing so now
283// results in VS2017 in producing an internal error (at the call site) that's
284// not detailed enough to understand what is happening.
285static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
286 OperationState &state) {
287 // Parse an optional list of attributes staring with '['
288 if (parser.parseOptionalLSquare()) {
289 // Nothing to do
290 return success();
291 }
292
293 spirv::MemoryAccess memoryAccessAttr;
294 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
295 kSourceMemoryAccessAttrName))
296 return failure();
297
298 if (spirv::bitEnumContainsAll(memoryAccessAttr,
299 spirv::MemoryAccess::Aligned)) {
300 // Parse integer attribute for alignment.
301 Attribute alignmentAttr;
302 Type i32Type = parser.getBuilder().getIntegerType(32);
303 if (parser.parseComma() ||
304 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
305 state.attributes)) {
306 return failure();
307 }
308 }
309 return parser.parseRSquare();
310}
311
312template <typename MemoryOpTy>
313static void printMemoryAccessAttribute(
314 MemoryOpTy memoryOp, OpAsmPrinter &printer,
315 SmallVectorImpl<StringRef> &elidedAttrs,
316 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
317 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
318 // Print optional memory access attribute.
319 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
320 : memoryOp.getMemoryAccess())) {
321 elidedAttrs.push_back(kMemoryAccessAttrName);
322
323 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
324
325 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
326 // Print integer alignment attribute.
327 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
328 : memoryOp.getAlignment())) {
329 elidedAttrs.push_back(kAlignmentAttrName);
330 printer << ", " << *alignment;
331 }
332 }
333 printer << "]";
334 }
335 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
336}
337
338// TODO Make sure to merge this and the previous function into one template
339// parameterized by memory access attribute name and alignment. Doing so now
340// results in VS2017 in producing an internal error (at the call site) that's
341// not detailed enough to understand what is happening.
342template <typename MemoryOpTy>
343static void printSourceMemoryAccessAttribute(
344 MemoryOpTy memoryOp, OpAsmPrinter &printer,
345 SmallVectorImpl<StringRef> &elidedAttrs,
346 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
347 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
348
349 printer << ", ";
350
351 // Print optional memory access attribute.
352 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
353 : memoryOp.getMemoryAccess())) {
354 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
355
356 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
357
358 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
359 // Print integer alignment attribute.
360 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
361 : memoryOp.getAlignment())) {
362 elidedAttrs.push_back(kSourceAlignmentAttrName);
363 printer << ", " << *alignment;
364 }
365 }
366 printer << "]";
367 }
368 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
369}
370
371static ParseResult parseImageOperands(OpAsmParser &parser,
372 spirv::ImageOperandsAttr &attr) {
373 // Expect image operands
374 if (parser.parseOptionalLSquare())
375 return success();
376
377 spirv::ImageOperands imageOperands;
378 if (parseEnumStrAttr(imageOperands, parser))
379 return failure();
380
381 attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
382
383 return parser.parseRSquare();
384}
385
386static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
387 spirv::ImageOperandsAttr attr) {
388 if (attr) {
389 auto strImageOperands = stringifyImageOperands(attr.getValue());
390 printer << "[\"" << strImageOperands << "\"]";
391 }
392}
393
394template <typename Op>
395static LogicalResult verifyImageOperands(Op imageOp,
396 spirv::ImageOperandsAttr attr,
397 Operation::operand_range operands) {
398 if (!attr) {
399 if (operands.empty())
400 return success();
401
402 return imageOp.emitError("the Image Operands should encode what operands "
403 "follow, as per Image Operands");
404 }
405
406 // TODO: Add the validation rules for the following Image Operands.
407 spirv::ImageOperands noSupportOperands =
408 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
409 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
410 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
411 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
412 spirv::ImageOperands::MakeTexelAvailable |
413 spirv::ImageOperands::MakeTexelVisible |
414 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
415
416 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
417 llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 417)
;
418
419 return success();
420}
421
422static LogicalResult verifyCastOp(Operation *op,
423 bool requireSameBitWidth = true,
424 bool skipBitWidthCheck = false) {
425 // Some CastOps have no limit on bit widths for result and operand type.
426 if (skipBitWidthCheck)
427 return success();
428
429 Type operandType = op->getOperand(0).getType();
430 Type resultType = op->getResult(0).getType();
431
432 // ODS checks that result type and operand type have the same shape.
433 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
434 operandType = vectorType.getElementType();
435 resultType = resultType.cast<VectorType>().getElementType();
436 }
437
438 if (auto coopMatrixType =
439 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
440 operandType = coopMatrixType.getElementType();
441 resultType =
442 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
443 }
444
445 if (auto jointMatrixType =
446 operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
447 operandType = jointMatrixType.getElementType();
448 resultType =
449 resultType.cast<spirv::JointMatrixINTELType>().getElementType();
450 }
451
452 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
453 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
454 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
455
456 if (requireSameBitWidth) {
457 if (!isSameBitWidth) {
458 return op->emitOpError(
459 "expected the same bit widths for operand type and result "
460 "type, but provided ")
461 << operandType << " and " << resultType;
462 }
463 return success();
464 }
465
466 if (isSameBitWidth) {
467 return op->emitOpError(
468 "expected the different bit widths for operand type and result "
469 "type, but provided ")
470 << operandType << " and " << resultType;
471 }
472 return success();
473}
474
475template <typename MemoryOpTy>
476static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
477 // ODS checks for attributes values. Just need to verify that if the
478 // memory-access attribute is Aligned, then the alignment attribute must be
479 // present.
480 auto *op = memoryOp.getOperation();
481 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
482 if (!memAccessAttr) {
483 // Alignment attribute shouldn't be present if memory access attribute is
484 // not present.
485 if (op->getAttr(kAlignmentAttrName)) {
486 return memoryOp.emitOpError(
487 "invalid alignment specification without aligned memory access "
488 "specification");
489 }
490 return success();
491 }
492
493 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
494
495 if (!memAccess) {
496 return memoryOp.emitOpError("invalid memory access specifier: ")
497 << memAccessAttr;
498 }
499
500 if (spirv::bitEnumContainsAll(memAccess.getValue(),
501 spirv::MemoryAccess::Aligned)) {
502 if (!op->getAttr(kAlignmentAttrName)) {
503 return memoryOp.emitOpError("missing alignment value");
504 }
505 } else {
506 if (op->getAttr(kAlignmentAttrName)) {
507 return memoryOp.emitOpError(
508 "invalid alignment specification with non-aligned memory access "
509 "specification");
510 }
511 }
512 return success();
513}
514
515// TODO Make sure to merge this and the previous function into one template
516// parameterized by memory access attribute name and alignment. Doing so now
517// results in VS2017 in producing an internal error (at the call site) that's
518// not detailed enough to understand what is happening.
519template <typename MemoryOpTy>
520static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
521 // ODS checks for attributes values. Just need to verify that if the
522 // memory-access attribute is Aligned, then the alignment attribute must be
523 // present.
524 auto *op = memoryOp.getOperation();
525 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
526 if (!memAccessAttr) {
527 // Alignment attribute shouldn't be present if memory access attribute is
528 // not present.
529 if (op->getAttr(kSourceAlignmentAttrName)) {
530 return memoryOp.emitOpError(
531 "invalid alignment specification without aligned memory access "
532 "specification");
533 }
534 return success();
535 }
536
537 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
538
539 if (!memAccess) {
540 return memoryOp.emitOpError("invalid memory access specifier: ")
541 << memAccess;
542 }
543
544 if (spirv::bitEnumContainsAll(memAccess.getValue(),
545 spirv::MemoryAccess::Aligned)) {
546 if (!op->getAttr(kSourceAlignmentAttrName)) {
547 return memoryOp.emitOpError("missing alignment value");
548 }
549 } else {
550 if (op->getAttr(kSourceAlignmentAttrName)) {
551 return memoryOp.emitOpError(
552 "invalid alignment specification with non-aligned memory access "
553 "specification");
554 }
555 }
556 return success();
557}
558
559static LogicalResult
560verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
561 // According to the SPIR-V specification:
562 // "Despite being a mask and allowing multiple bits to be combined, it is
563 // invalid for more than one of these four bits to be set: Acquire, Release,
564 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
565 // Release semantics is done by setting the AcquireRelease bit, not by setting
566 // two bits."
567 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
568 spirv::MemorySemantics::Release |
569 spirv::MemorySemantics::AcquireRelease |
570 spirv::MemorySemantics::SequentiallyConsistent;
571
572 auto bitCount = llvm::countPopulation(
573 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
574 if (bitCount > 1) {
575 return op->emitError(
576 "expected at most one of these four memory constraints "
577 "to be set: `Acquire`, `Release`,"
578 "`AcquireRelease` or `SequentiallyConsistent`");
579 }
580 return success();
581}
582
583template <typename LoadStoreOpTy>
584static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
585 Value val) {
586 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
587 // type of the pointer and the type of the value are the same
588 //
589 // TODO: Check that the value type satisfies restrictions of
590 // SPIR-V OpLoad/OpStore operations
591 if (val.getType() !=
592 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
593 return op.emitOpError("mismatch in result type and pointer type");
594 }
595 return success();
596}
597
598template <typename BlockReadWriteOpTy>
599static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
600 Value ptr, Value val) {
601 auto valType = val.getType();
602 if (auto valVecTy = valType.dyn_cast<VectorType>())
603 valType = valVecTy.getElementType();
604
605 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
606 return op.emitOpError("mismatch in result type and pointer type");
607 }
608 return success();
609}
610
611static ParseResult parseVariableDecorations(OpAsmParser &parser,
612 OperationState &state) {
613 auto builtInName = llvm::convertToSnakeFromCamelCase(
614 stringifyDecoration(spirv::Decoration::BuiltIn));
615 if (succeeded(parser.parseOptionalKeyword("bind"))) {
616 Attribute set, binding;
617 // Parse optional descriptor binding
618 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
619 stringifyDecoration(spirv::Decoration::DescriptorSet));
620 auto bindingName = llvm::convertToSnakeFromCamelCase(
621 stringifyDecoration(spirv::Decoration::Binding));
622 Type i32Type = parser.getBuilder().getIntegerType(32);
623 if (parser.parseLParen() ||
624 parser.parseAttribute(set, i32Type, descriptorSetName,
625 state.attributes) ||
626 parser.parseComma() ||
627 parser.parseAttribute(binding, i32Type, bindingName,
628 state.attributes) ||
629 parser.parseRParen()) {
630 return failure();
631 }
632 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
633 StringAttr builtIn;
634 if (parser.parseLParen() ||
635 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
636 parser.parseRParen()) {
637 return failure();
638 }
639 }
640
641 // Parse other attributes
642 if (parser.parseOptionalAttrDict(state.attributes))
643 return failure();
644
645 return success();
646}
647
648static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
649 SmallVectorImpl<StringRef> &elidedAttrs) {
650 // Print optional descriptor binding
651 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
652 stringifyDecoration(spirv::Decoration::DescriptorSet));
653 auto bindingName = llvm::convertToSnakeFromCamelCase(
654 stringifyDecoration(spirv::Decoration::Binding));
655 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
656 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
657 if (descriptorSet && binding) {
658 elidedAttrs.push_back(descriptorSetName);
659 elidedAttrs.push_back(bindingName);
660 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
661 << ")";
662 }
663
664 // Print BuiltIn attribute if present
665 auto builtInName = llvm::convertToSnakeFromCamelCase(
666 stringifyDecoration(spirv::Decoration::BuiltIn));
667 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
668 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
669 elidedAttrs.push_back(builtInName);
670 }
671
672 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
673}
674
675// Get bit width of types.
676static unsigned getBitWidth(Type type) {
677 if (type.isa<spirv::PointerType>()) {
678 // Just return 64 bits for pointer types for now.
679 // TODO: Make sure not caller relies on the actual pointer width value.
680 return 64;
681 }
682
683 if (type.isIntOrFloat())
684 return type.getIntOrFloatBitWidth();
685
686 if (auto vectorType = type.dyn_cast<VectorType>()) {
687 assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 687, __extension__
__PRETTY_FUNCTION__))
;
688 return vectorType.getNumElements() *
689 vectorType.getElementType().getIntOrFloatBitWidth();
690 }
691 llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 691)
;
692}
693
694/// Walks the given type hierarchy with the given indices, potentially down
695/// to component granularity, to select an element type. Returns null type and
696/// emits errors with the given loc on failure.
697static Type
698getElementType(Type type, ArrayRef<int32_t> indices,
699 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
700 if (indices.empty()) {
701 emitErrorFn("expected at least one index for spirv.CompositeExtract");
702 return nullptr;
703 }
704
705 for (auto index : indices) {
706 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
707 if (cType.hasCompileTimeKnownNumElements() &&
708 (index < 0 ||
709 static_cast<uint64_t>(index) >= cType.getNumElements())) {
710 emitErrorFn("index ") << index << " out of bounds for " << type;
711 return nullptr;
712 }
713 type = cType.getElementType(index);
714 } else {
715 emitErrorFn("cannot extract from non-composite type ")
716 << type << " with index " << index;
717 return nullptr;
718 }
719 }
720 return type;
721}
722
723static Type
724getElementType(Type type, Attribute indices,
725 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
726 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
727 if (!indicesArrayAttr) {
728 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
729 return nullptr;
730 }
731 if (indicesArrayAttr.empty()) {
732 emitErrorFn("expected at least one index for spirv.CompositeExtract");
733 return nullptr;
734 }
735
736 SmallVector<int32_t, 2> indexVals;
737 for (auto indexAttr : indicesArrayAttr) {
738 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
739 if (!indexIntAttr) {
740 emitErrorFn("expected an 32-bit integer for index, but found '")
741 << indexAttr << "'";
742 return nullptr;
743 }
744 indexVals.push_back(indexIntAttr.getInt());
745 }
746 return getElementType(type, indexVals, emitErrorFn);
747}
748
749static Type getElementType(Type type, Attribute indices, Location loc) {
750 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
751 return ::mlir::emitError(loc, err);
752 };
753 return getElementType(type, indices, errorFn);
754}
755
756static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
757 SMLoc loc) {
758 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
759 return parser.emitError(loc, err);
760 };
761 return getElementType(type, indices, errorFn);
762}
763
764/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
765static inline bool isMergeBlock(Block &block) {
766 return !block.empty() && std::next(block.begin()) == block.end() &&
767 isa<spirv::MergeOp>(block.front());
768}
769
770template <typename ExtendedBinaryOp>
771static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
772 auto resultType = op.getType().template cast<spirv::StructType>();
773 if (resultType.getNumElements() != 2)
774 return op.emitOpError("expected result struct type containing two members");
775
776 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
777 resultType.getElementType(0),
778 resultType.getElementType(1)}))
779 return op.emitOpError(
780 "expected all operand types and struct member types are the same");
781
782 return success();
783}
784
785static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
786 OperationState &result) {
787 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
788 if (parser.parseOptionalAttrDict(result.attributes) ||
789 parser.parseOperandList(operands) || parser.parseColon())
790 return failure();
791
792 Type resultType;
793 SMLoc loc = parser.getCurrentLocation();
794 if (parser.parseType(resultType))
795 return failure();
796
797 auto structType = resultType.dyn_cast<spirv::StructType>();
798 if (!structType || structType.getNumElements() != 2)
799 return parser.emitError(loc, "expected spirv.struct type with two members");
800
801 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
802 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
803 return failure();
804
805 result.addTypes(resultType);
806 return success();
807}
808
809static void printArithmeticExtendedBinaryOp(Operation *op,
810 OpAsmPrinter &printer) {
811 printer << ' ';
812 printer.printOptionalAttrDict(op->getAttrs());
813 printer.printOperands(op->getOperands());
814 printer << " : " << op->getResultTypes().front();
815}
816
817//===----------------------------------------------------------------------===//
818// Common parsers and printers
819//===----------------------------------------------------------------------===//
820
821// Parses an atomic update op. If the update op does not take a value (like
822// AtomicIIncrement) `hasValue` must be false.
823static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
824 OperationState &state, bool hasValue) {
825 spirv::Scope scope;
826 spirv::MemorySemantics memoryScope;
827 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
828 OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
829 Type type;
830 SMLoc loc;
831 if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
832 kMemoryScopeAttrName) ||
833 parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
834 kSemanticsAttrName) ||
835 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
836 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
837 return failure();
838
839 auto ptrType = type.dyn_cast<spirv::PointerType>();
840 if (!ptrType)
841 return parser.emitError(loc, "expected pointer type");
842
843 SmallVector<Type, 2> operandTypes;
844 operandTypes.push_back(ptrType);
845 if (hasValue)
846 operandTypes.push_back(ptrType.getPointeeType());
847 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
848 state.operands))
849 return failure();
850 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
851}
852
853// Prints an atomic update op.
854static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
855 printer << " \"";
856 auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
857 printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
858 auto memorySemanticsAttr =
859 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
860 printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
861 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
862}
863
864template <typename T>
865static StringRef stringifyTypeName();
866
867template <>
868StringRef stringifyTypeName<IntegerType>() {
869 return "integer";
870}
871
872template <>
873StringRef stringifyTypeName<FloatType>() {
874 return "float";
875}
876
877// Verifies an atomic update op.
878template <typename ExpectedElementType>
879static LogicalResult verifyAtomicUpdateOp(Operation *op) {
880 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
881 auto elementType = ptrType.getPointeeType();
882 if (!elementType.isa<ExpectedElementType>())
883 return op->emitOpError() << "pointer operand must point to an "
884 << stringifyTypeName<ExpectedElementType>()
885 << " value, found " << elementType;
886
887 if (op->getNumOperands() > 1) {
888 auto valueType = op->getOperand(1).getType();
889 if (valueType != elementType)
890 return op->emitOpError("expected value to have the same type as the "
891 "pointer operand's pointee type ")
892 << elementType << ", but found " << valueType;
893 }
894 auto memorySemantics =
895 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
896 .getValue();
897 if (failed(verifyMemorySemantics(op, memorySemantics))) {
898 return failure();
899 }
900 return success();
901}
902
903static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
904 OperationState &state) {
905 spirv::Scope executionScope;
906 spirv::GroupOperation groupOperation;
907 OpAsmParser::UnresolvedOperand valueInfo;
908 if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
909 kExecutionScopeAttrName) ||
910 parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
911 kGroupOperationAttrName) ||
912 parser.parseOperand(valueInfo))
913 return failure();
914
915 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
916 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
917 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
918 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
919 parser.parseRParen())
920 return failure();
921 }
922
923 Type resultType;
924 if (parser.parseColonType(resultType))
925 return failure();
926
927 if (parser.resolveOperand(valueInfo, resultType, state.operands))
928 return failure();
929
930 if (clusterSizeInfo) {
931 Type i32Type = parser.getBuilder().getIntegerType(32);
932 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
933 return failure();
934 }
935
936 return parser.addTypeToList(resultType, state.types);
937}
938
939static void printGroupNonUniformArithmeticOp(Operation *groupOp,
940 OpAsmPrinter &printer) {
941 printer
942 << " \""
943 << stringifyScope(
944 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
945 .getValue())
946 << "\" \""
947 << stringifyGroupOperation(groupOp
948 ->getAttrOfType<spirv::GroupOperationAttr>(
949 kGroupOperationAttrName)
950 .getValue())
951 << "\" " << groupOp->getOperand(0);
952
953 if (groupOp->getNumOperands() > 1)
954 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
955 printer << " : " << groupOp->getResult(0).getType();
956}
957
958static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
959 spirv::Scope scope =
960 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
961 .getValue();
962 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
963 return groupOp->emitOpError(
964 "execution scope must be 'Workgroup' or 'Subgroup'");
965
966 spirv::GroupOperation operation =
967 groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
968 .getValue();
969 if (operation == spirv::GroupOperation::ClusteredReduce &&
970 groupOp->getNumOperands() == 1)
971 return groupOp->emitOpError("cluster size operand must be provided for "
972 "'ClusteredReduce' group operation");
973 if (groupOp->getNumOperands() > 1) {
974 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
975 int32_t clusterSize = 0;
976
977 // TODO: support specialization constant here.
978 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
979 return groupOp->emitOpError(
980 "cluster size operand must come from a constant op");
981
982 if (!llvm::isPowerOf2_32(clusterSize))
983 return groupOp->emitOpError(
984 "cluster size operand must be a power of two");
985 }
986 return success();
987}
988
989/// Result of a logical op must be a scalar or vector of boolean type.
990static Type getUnaryOpResultType(Type operandType) {
991 Builder builder(operandType.getContext());
992 Type resultType = builder.getIntegerType(1);
993 if (auto vecType = operandType.dyn_cast<VectorType>())
994 return VectorType::get(vecType.getNumElements(), resultType);
995 return resultType;
996}
997
998static LogicalResult verifyShiftOp(Operation *op) {
999 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
1000 return op->emitError("expected the same type for the first operand and "
1001 "result, but provided ")
1002 << op->getOperand(0).getType() << " and "
1003 << op->getResult(0).getType();
1004 }
1005 return success();
1006}
1007
1008static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
1009 Value lhs, Value rhs) {
1010 assert(lhs.getType() == rhs.getType())(static_cast <bool> (lhs.getType() == rhs.getType()) ? void
(0) : __assert_fail ("lhs.getType() == rhs.getType()", "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp"
, 1010, __extension__ __PRETTY_FUNCTION__))
;
1011
1012 Type boolType = builder.getI1Type();
1013 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
1014 boolType = VectorType::get(vecType.getShape(), boolType);
1015 state.addTypes(boolType);
1016
1017 state.addOperands({lhs, rhs});
1018}
1019
1020static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
1021 Value value) {
1022 Type boolType = builder.getI1Type();
1023 if (auto vecType = value.getType().dyn_cast<VectorType>())
1024 boolType = VectorType::get(vecType.getShape(), boolType);
1025 state.addTypes(boolType);
1026
1027 state.addOperands(value);
1028}
1029
1030//===----------------------------------------------------------------------===//
1031// spirv.AccessChainOp
1032//===----------------------------------------------------------------------===//
1033
1034static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
1035 auto ptrType = type.dyn_cast<spirv::PointerType>();
1036 if (!ptrType) {
1037 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
1038 "to composite type, but provided ")
1039 << type;
1040 return nullptr;
1041 }
1042
1043 auto resultType = ptrType.getPointeeType();
1044 auto resultStorageClass = ptrType.getStorageClass();
1045 int32_t index = 0;
1046
1047 for (auto indexSSA : indices) {
1048 auto cType = resultType.dyn_cast<spirv::CompositeType>();
1049 if (!cType) {
1050 emitError(
1051 baseLoc,
1052 "'spirv.AccessChain' op cannot extract from non-composite type ")
1053 << resultType << " with index " << index;
1054 return nullptr;
1055 }
1056 index = 0;
1057 if (resultType.isa<spirv::StructType>()) {
1058 Operation *op = indexSSA.getDefiningOp();
1059 if (!op) {
1060 emitError(baseLoc, "'spirv.AccessChain' op index must be an "
1061 "integer spirv.Constant to access "
1062 "element of spirv.struct");
1063 return nullptr;
1064 }
1065
1066 // TODO: this should be relaxed to allow
1067 // integer literals of other bitwidths.
1068 if (failed(extractValueFromConstOp(op, index))) {
1069 emitError(
1070 baseLoc,
1071 "'spirv.AccessChain' index must be an integer spirv.Constant to "
1072 "access element of spirv.struct, but provided ")
1073 << op->getName();
1074 return nullptr;
1075 }
1076 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1077 emitError(baseLoc, "'spirv.AccessChain' op index ")
1078 << index << " out of bounds for " << resultType;
1079 return nullptr;
1080 }
1081 }
1082 resultType = cType.getElementType(index);
1083 }
1084 return spirv::PointerType::get(resultType, resultStorageClass);
1085}
1086
1087void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1088 Value basePtr, ValueRange indices) {
1089 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1090 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1090, __extension__
__PRETTY_FUNCTION__))
;
1091 build(builder, state, type, basePtr, indices);
1092}
1093
1094ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1095 OperationState &result) {
1096 OpAsmParser::UnresolvedOperand ptrInfo;
1097 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
1098 Type type;
1099 auto loc = parser.getCurrentLocation();
1100 SmallVector<Type, 4> indicesTypes;
1101
1102 if (parser.parseOperand(ptrInfo) ||
1103 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1104 parser.parseColonType(type) ||
1105 parser.resolveOperand(ptrInfo, type, result.operands)) {
1106 return failure();
1107 }
1108
1109 // Check that the provided indices list is not empty before parsing their
1110 // type list.
1111 if (indicesInfo.empty()) {
1112 return mlir::emitError(result.location,
1113 "'spirv.AccessChain' op expected at "
1114 "least one index ");
1115 }
1116
1117 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1118 return failure();
1119
1120 // Check that the indices types list is not empty and that it has a one-to-one
1121 // mapping to the provided indices.
1122 if (indicesTypes.size() != indicesInfo.size()) {
1123 return mlir::emitError(
1124 result.location, "'spirv.AccessChain' op indices types' count must be "
1125 "equal to indices info count");
1126 }
1127
1128 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1129 return failure();
1130
1131 auto resultType = getElementPtrType(
1132 type, llvm::makeArrayRef(result.operands).drop_front(), result.location);
1133 if (!resultType) {
1134 return failure();
1135 }
1136
1137 result.addTypes(resultType);
1138 return success();
1139}
1140
1141template <typename Op>
1142static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1143 printer << ' ' << op.getBasePtr() << '[' << indices
1144 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
1145}
1146
1147void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1148 printAccessChain(*this, getIndices(), printer);
1149}
1150
1151template <typename Op>
1152static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1153 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
1154 indices, accessChainOp.getLoc());
1155 if (!resultType)
1156 return failure();
1157
1158 auto providedResultType =
1159 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1160 if (!providedResultType)
1161 return accessChainOp.emitOpError(
1162 "result type must be a pointer, but provided")
1163 << providedResultType;
1164
1165 if (resultType != providedResultType)
1166 return accessChainOp.emitOpError("invalid result type: expected ")
1167 << resultType << ", but provided " << providedResultType;
1168
1169 return success();
1170}
1171
1172LogicalResult spirv::AccessChainOp::verify() {
1173 return verifyAccessChain(*this, getIndices());
1174}
1175
1176//===----------------------------------------------------------------------===//
1177// spirv.mlir.addressof
1178//===----------------------------------------------------------------------===//
1179
1180void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1181 spirv::GlobalVariableOp var) {
1182 build(builder, state, var.getType(), SymbolRefAttr::get(var));
1183}
1184
1185LogicalResult spirv::AddressOfOp::verify() {
1186 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1187 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1188 getVariableAttr()));
1189 if (!varOp) {
1190 return emitOpError("expected spirv.GlobalVariable symbol");
1191 }
1192 if (getPointer().getType() != varOp.getType()) {
1193 return emitOpError(
1194 "result type mismatch with the referenced global variable's type");
1195 }
1196 return success();
1197}
1198
1199template <typename T>
1200static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1201 printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
1202 << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
1203 << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
1204 << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
1205}
1206
1207static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1208 OperationState &state) {
1209 spirv::Scope memoryScope;
1210 spirv::MemorySemantics equalSemantics, unequalSemantics;
1211 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1212 Type type;
1213 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1214 kMemoryScopeAttrName) ||
1215 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1216 equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1217 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1218 unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1219 parser.parseOperandList(operandInfo, 3))
1220 return failure();
1221
1222 auto loc = parser.getCurrentLocation();
1223 if (parser.parseColonType(type))
1224 return failure();
1225
1226 auto ptrType = type.dyn_cast<spirv::PointerType>();
1227 if (!ptrType)
1228 return parser.emitError(loc, "expected pointer type");
1229
1230 if (parser.resolveOperands(
1231 operandInfo,
1232 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1233 parser.getNameLoc(), state.operands))
1234 return failure();
1235
1236 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1237}
1238
1239template <typename T>
1240static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1241 // According to the spec:
1242 // "The type of Value must be the same as Result Type. The type of the value
1243 // pointed to by Pointer must be the same as Result Type. This type must also
1244 // match the type of Comparator."
1245 if (atomOp.getType() != atomOp.getValue().getType())
1246 return atomOp.emitOpError("value operand must have the same type as the op "
1247 "result, but found ")
1248 << atomOp.getValue().getType() << " vs " << atomOp.getType();
1249
1250 if (atomOp.getType() != atomOp.getComparator().getType())
1251 return atomOp.emitOpError(
1252 "comparator operand must have the same type as the op "
1253 "result, but found ")
1254 << atomOp.getComparator().getType() << " vs " << atomOp.getType();
1255
1256 Type pointeeType = atomOp.getPointer()
1257 .getType()
1258 .template cast<spirv::PointerType>()
1259 .getPointeeType();
1260 if (atomOp.getType() != pointeeType)
1261 return atomOp.emitOpError(
1262 "pointer operand's pointee type must have the same "
1263 "as the op result type, but found ")
1264 << pointeeType << " vs " << atomOp.getType();
1265
1266 // TODO: Unequal cannot be set to Release or Acquire and Release.
1267 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1268
1269 return success();
1270}
1271
1272//===----------------------------------------------------------------------===//
1273// spirv.AtomicAndOp
1274//===----------------------------------------------------------------------===//
1275
1276LogicalResult spirv::AtomicAndOp::verify() {
1277 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1278}
1279
1280ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1281 OperationState &result) {
1282 return ::parseAtomicUpdateOp(parser, result, true);
1283}
1284void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1285 ::printAtomicUpdateOp(*this, p);
1286}
1287
1288//===----------------------------------------------------------------------===//
1289// spirv.AtomicCompareExchangeOp
1290//===----------------------------------------------------------------------===//
1291
1292LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1293 return ::verifyAtomicCompareExchangeImpl(*this);
1294}
1295
1296ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1297 OperationState &result) {
1298 return ::parseAtomicCompareExchangeImpl(parser, result);
1299}
1300void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1301 ::printAtomicCompareExchangeImpl(*this, p);
1302}
1303
1304//===----------------------------------------------------------------------===//
1305// spirv.AtomicCompareExchangeWeakOp
1306//===----------------------------------------------------------------------===//
1307
1308LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1309 return ::verifyAtomicCompareExchangeImpl(*this);
1310}
1311
1312ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1313 OperationState &result) {
1314 return ::parseAtomicCompareExchangeImpl(parser, result);
1315}
1316void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1317 ::printAtomicCompareExchangeImpl(*this, p);
1318}
1319
1320//===----------------------------------------------------------------------===//
1321// spirv.AtomicExchange
1322//===----------------------------------------------------------------------===//
1323
1324void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1325 printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
1326 << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
1327 << " : " << getPointer().getType();
1328}
1329
1330ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1331 OperationState &result) {
1332 spirv::Scope memoryScope;
1333 spirv::MemorySemantics semantics;
1334 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1335 Type type;
1336 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1337 kMemoryScopeAttrName) ||
1338 parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1339 kSemanticsAttrName) ||
1340 parser.parseOperandList(operandInfo, 2))
1341 return failure();
1342
1343 auto loc = parser.getCurrentLocation();
1344 if (parser.parseColonType(type))
1345 return failure();
1346
1347 auto ptrType = type.dyn_cast<spirv::PointerType>();
1348 if (!ptrType)
1349 return parser.emitError(loc, "expected pointer type");
1350
1351 if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1352 parser.getNameLoc(), result.operands))
1353 return failure();
1354
1355 return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1356}
1357
1358LogicalResult spirv::AtomicExchangeOp::verify() {
1359 if (getType() != getValue().getType())
1360 return emitOpError("value operand must have the same type as the op "
1361 "result, but found ")
1362 << getValue().getType() << " vs " << getType();
1363
1364 Type pointeeType =
1365 getPointer().getType().cast<spirv::PointerType>().getPointeeType();
1366 if (getType() != pointeeType)
1367 return emitOpError("pointer operand's pointee type must have the same "
1368 "as the op result type, but found ")
1369 << pointeeType << " vs " << getType();
1370
1371 return success();
1372}
1373
1374//===----------------------------------------------------------------------===//
1375// spirv.AtomicIAddOp
1376//===----------------------------------------------------------------------===//
1377
1378LogicalResult spirv::AtomicIAddOp::verify() {
1379 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1380}
1381
1382ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1383 OperationState &result) {
1384 return ::parseAtomicUpdateOp(parser, result, true);
1385}
1386void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1387 ::printAtomicUpdateOp(*this, p);
1388}
1389
1390//===----------------------------------------------------------------------===//
1391// spirv.EXT.AtomicFAddOp
1392//===----------------------------------------------------------------------===//
1393
1394LogicalResult spirv::EXTAtomicFAddOp::verify() {
1395 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1396}
1397
1398ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
1399 OperationState &result) {
1400 return ::parseAtomicUpdateOp(parser, result, true);
1401}
1402void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
1403 ::printAtomicUpdateOp(*this, p);
1404}
1405
1406//===----------------------------------------------------------------------===//
1407// spirv.AtomicIDecrementOp
1408//===----------------------------------------------------------------------===//
1409
1410LogicalResult spirv::AtomicIDecrementOp::verify() {
1411 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1412}
1413
1414ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1415 OperationState &result) {
1416 return ::parseAtomicUpdateOp(parser, result, false);
1417}
1418void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1419 ::printAtomicUpdateOp(*this, p);
1420}
1421
1422//===----------------------------------------------------------------------===//
1423// spirv.AtomicIIncrementOp
1424//===----------------------------------------------------------------------===//
1425
1426LogicalResult spirv::AtomicIIncrementOp::verify() {
1427 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1428}
1429
1430ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1431 OperationState &result) {
1432 return ::parseAtomicUpdateOp(parser, result, false);
1433}
1434void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1435 ::printAtomicUpdateOp(*this, p);
1436}
1437
1438//===----------------------------------------------------------------------===//
1439// spirv.AtomicISubOp
1440//===----------------------------------------------------------------------===//
1441
1442LogicalResult spirv::AtomicISubOp::verify() {
1443 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1444}
1445
1446ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1447 OperationState &result) {
1448 return ::parseAtomicUpdateOp(parser, result, true);
1449}
1450void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1451 ::printAtomicUpdateOp(*this, p);
1452}
1453
1454//===----------------------------------------------------------------------===//
1455// spirv.AtomicOrOp
1456//===----------------------------------------------------------------------===//
1457
1458LogicalResult spirv::AtomicOrOp::verify() {
1459 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1460}
1461
1462ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1463 OperationState &result) {
1464 return ::parseAtomicUpdateOp(parser, result, true);
1465}
1466void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1467 ::printAtomicUpdateOp(*this, p);
1468}
1469
1470//===----------------------------------------------------------------------===//
1471// spirv.AtomicSMaxOp
1472//===----------------------------------------------------------------------===//
1473
1474LogicalResult spirv::AtomicSMaxOp::verify() {
1475 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1476}
1477
1478ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1479 OperationState &result) {
1480 return ::parseAtomicUpdateOp(parser, result, true);
1481}
1482void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1483 ::printAtomicUpdateOp(*this, p);
1484}
1485
1486//===----------------------------------------------------------------------===//
1487// spirv.AtomicSMinOp
1488//===----------------------------------------------------------------------===//
1489
1490LogicalResult spirv::AtomicSMinOp::verify() {
1491 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1492}
1493
1494ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1495 OperationState &result) {
1496 return ::parseAtomicUpdateOp(parser, result, true);
1497}
1498void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1499 ::printAtomicUpdateOp(*this, p);
1500}
1501
1502//===----------------------------------------------------------------------===//
1503// spirv.AtomicUMaxOp
1504//===----------------------------------------------------------------------===//
1505
1506LogicalResult spirv::AtomicUMaxOp::verify() {
1507 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1508}
1509
1510ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1511 OperationState &result) {
1512 return ::parseAtomicUpdateOp(parser, result, true);
1513}
1514void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1515 ::printAtomicUpdateOp(*this, p);
1516}
1517
1518//===----------------------------------------------------------------------===//
1519// spirv.AtomicUMinOp
1520//===----------------------------------------------------------------------===//
1521
1522LogicalResult spirv::AtomicUMinOp::verify() {
1523 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1524}
1525
1526ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1527 OperationState &result) {
1528 return ::parseAtomicUpdateOp(parser, result, true);
1529}
1530void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1531 ::printAtomicUpdateOp(*this, p);
1532}
1533
1534//===----------------------------------------------------------------------===//
1535// spirv.AtomicXorOp
1536//===----------------------------------------------------------------------===//
1537
1538LogicalResult spirv::AtomicXorOp::verify() {
1539 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1540}
1541
1542ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1543 OperationState &result) {
1544 return ::parseAtomicUpdateOp(parser, result, true);
1545}
1546void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1547 ::printAtomicUpdateOp(*this, p);
1548}
1549
1550//===----------------------------------------------------------------------===//
1551// spirv.BitcastOp
1552//===----------------------------------------------------------------------===//
1553
1554LogicalResult spirv::BitcastOp::verify() {
1555 // TODO: The SPIR-V spec validation rules are different for different
1556 // versions.
1557 auto operandType = getOperand().getType();
1558 auto resultType = getResult().getType();
1559 if (operandType == resultType) {
1560 return emitError("result type must be different from operand type");
1561 }
1562 if (operandType.isa<spirv::PointerType>() &&
1563 !resultType.isa<spirv::PointerType>()) {
1564 return emitError(
1565 "unhandled bit cast conversion from pointer type to non-pointer type");
1566 }
1567 if (!operandType.isa<spirv::PointerType>() &&
1568 resultType.isa<spirv::PointerType>()) {
1569 return emitError(
1570 "unhandled bit cast conversion from non-pointer type to pointer type");
1571 }
1572 auto operandBitWidth = getBitWidth(operandType);
1573 auto resultBitWidth = getBitWidth(resultType);
1574 if (operandBitWidth != resultBitWidth) {
1575 return emitOpError("mismatch in result type bitwidth ")
1576 << resultBitWidth << " and operand type bitwidth "
1577 << operandBitWidth;
1578 }
1579 return success();
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// spirv.PtrCastToGenericOp
1584//===----------------------------------------------------------------------===//
1585
1586LogicalResult spirv::PtrCastToGenericOp::verify() {
1587 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1588 auto resultType = getResult().getType().cast<spirv::PointerType>();
1589
1590 spirv::StorageClass operandStorage = operandType.getStorageClass();
1591 if (operandStorage != spirv::StorageClass::Workgroup &&
1592 operandStorage != spirv::StorageClass::CrossWorkgroup &&
1593 operandStorage != spirv::StorageClass::Function)
1594 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
1595 ", or Function Storage Class");
1596
1597 spirv::StorageClass resultStorage = resultType.getStorageClass();
1598 if (resultStorage != spirv::StorageClass::Generic)
1599 return emitError("result type must be of storage class Generic");
1600
1601 Type operandPointeeType = operandType.getPointeeType();
1602 Type resultPointeeType = resultType.getPointeeType();
1603 if (operandPointeeType != resultPointeeType)
1604 return emitOpError("pointer operand's pointee type must have the same "
1605 "as the op result type, but found ")
1606 << operandPointeeType << " vs " << resultPointeeType;
1607 return success();
1608}
1609
1610//===----------------------------------------------------------------------===//
1611// spirv.GenericCastToPtrOp
1612//===----------------------------------------------------------------------===//
1613
1614LogicalResult spirv::GenericCastToPtrOp::verify() {
1615 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1616 auto resultType = getResult().getType().cast<spirv::PointerType>();
1617
1618 spirv::StorageClass operandStorage = operandType.getStorageClass();
1619 if (operandStorage != spirv::StorageClass::Generic)
1620 return emitError("pointer type must be of storage class Generic");
1621
1622 spirv::StorageClass resultStorage = resultType.getStorageClass();
1623 if (resultStorage != spirv::StorageClass::Workgroup &&
1624 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1625 resultStorage != spirv::StorageClass::Function)
1626 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1627 "or Function Storage Class");
1628
1629 Type operandPointeeType = operandType.getPointeeType();
1630 Type resultPointeeType = resultType.getPointeeType();
1631 if (operandPointeeType != resultPointeeType)
1632 return emitOpError("pointer operand's pointee type must have the same "
1633 "as the op result type, but found ")
1634 << operandPointeeType << " vs " << resultPointeeType;
1635 return success();
1636}
1637
1638//===----------------------------------------------------------------------===//
1639// spirv.GenericCastToPtrExplicitOp
1640//===----------------------------------------------------------------------===//
1641
1642LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
1643 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1644 auto resultType = getResult().getType().cast<spirv::PointerType>();
1645
1646 spirv::StorageClass operandStorage = operandType.getStorageClass();
1647 if (operandStorage != spirv::StorageClass::Generic)
1648 return emitError("pointer type must be of storage class Generic");
1649
1650 spirv::StorageClass resultStorage = resultType.getStorageClass();
1651 if (resultStorage != spirv::StorageClass::Workgroup &&
1652 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1653 resultStorage != spirv::StorageClass::Function)
1654 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1655 "or Function Storage Class");
1656
1657 Type operandPointeeType = operandType.getPointeeType();
1658 Type resultPointeeType = resultType.getPointeeType();
1659 if (operandPointeeType != resultPointeeType)
1660 return emitOpError("pointer operand's pointee type must have the same "
1661 "as the op result type, but found ")
1662 << operandPointeeType << " vs " << resultPointeeType;
1663 return success();
1664}
1665
1666//===----------------------------------------------------------------------===//
1667// spirv.BranchOp
1668//===----------------------------------------------------------------------===//
1669
1670SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1671 assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index"
) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1671, __extension__
__PRETTY_FUNCTION__))
;
1672 return SuccessorOperands(0, getTargetOperandsMutable());
1673}
1674
1675//===----------------------------------------------------------------------===//
1676// spirv.BranchConditionalOp
1677//===----------------------------------------------------------------------===//
1678
1679SuccessorOperands
1680spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1681 assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index"
) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1681, __extension__
__PRETTY_FUNCTION__))
;
1682 return SuccessorOperands(index == kTrueIndex
1683 ? getTrueTargetOperandsMutable()
1684 : getFalseTargetOperandsMutable());
1685}
1686
1687ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1688 OperationState &result) {
1689 auto &builder = parser.getBuilder();
1690 OpAsmParser::UnresolvedOperand condInfo;
1691 Block *dest;
1692
1693 // Parse the condition.
1694 Type boolTy = builder.getI1Type();
1695 if (parser.parseOperand(condInfo) ||
1696 parser.resolveOperand(condInfo, boolTy, result.operands))
1697 return failure();
1698
1699 // Parse the optional branch weights.
1700 if (succeeded(parser.parseOptionalLSquare())) {
1701 IntegerAttr trueWeight, falseWeight;
1702 NamedAttrList weights;
1703
1704 auto i32Type = builder.getIntegerType(32);
1705 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1706 parser.parseComma() ||
1707 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1708 parser.parseRSquare())
1709 return failure();
1710
1711 result.addAttribute(kBranchWeightAttrName,
1712 builder.getArrayAttr({trueWeight, falseWeight}));
1713 }
1714
1715 // Parse the true branch.
1716 SmallVector<Value, 4> trueOperands;
1717 if (parser.parseComma() ||
1718 parser.parseSuccessorAndUseList(dest, trueOperands))
1719 return failure();
1720 result.addSuccessors(dest);
1721 result.addOperands(trueOperands);
1722
1723 // Parse the false branch.
1724 SmallVector<Value, 4> falseOperands;
1725 if (parser.parseComma() ||
1726 parser.parseSuccessorAndUseList(dest, falseOperands))
1727 return failure();
1728 result.addSuccessors(dest);
1729 result.addOperands(falseOperands);
1730 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1731 builder.getDenseI32ArrayAttr(
1732 {1, static_cast<int32_t>(trueOperands.size()),
1733 static_cast<int32_t>(falseOperands.size())}));
1734
1735 return success();
1736}
1737
1738void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1739 printer << ' ' << getCondition();
1740
1741 if (auto weights = getBranchWeights()) {
1742 printer << " [";
1743 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1744 printer << a.cast<IntegerAttr>().getInt();
1745 });
1746 printer << "]";
1747 }
1748
1749 printer << ", ";
1750 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1751 printer << ", ";
1752 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1753}
1754
1755LogicalResult spirv::BranchConditionalOp::verify() {
1756 if (auto weights = getBranchWeights()) {
1757 if (weights->getValue().size() != 2) {
1758 return emitOpError("must have exactly two branch weights");
1759 }
1760 if (llvm::all_of(*weights, [](Attribute attr) {
1761 return attr.cast<IntegerAttr>().getValue().isNullValue();
1762 }))
1763 return emitOpError("branch weights cannot both be zero");
1764 }
1765
1766 return success();
1767}
1768
1769//===----------------------------------------------------------------------===//
1770// spirv.CompositeConstruct
1771//===----------------------------------------------------------------------===//
1772
1773LogicalResult spirv::CompositeConstructOp::verify() {
1774 auto cType = getType().cast<spirv::CompositeType>();
1775 operand_range constituents = this->getConstituents();
1776
1777 if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1778 if (constituents.size() != 1)
1779 return emitOpError("has incorrect number of operands: expected ")
1780 << "1, but provided " << constituents.size();
1781 if (coopType.getElementType() != constituents.front().getType())
1782 return emitOpError("operand type mismatch: expected operand type ")
1783 << coopType.getElementType() << ", but provided "
1784 << constituents.front().getType();
1785 return success();
1786 }
1787
1788 if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1789 if (constituents.size() != 1)
1790 return emitOpError("has incorrect number of operands: expected ")
1791 << "1, but provided " << constituents.size();
1792 if (jointType.getElementType() != constituents.front().getType())
1793 return emitOpError("operand type mismatch: expected operand type ")
1794 << jointType.getElementType() << ", but provided "
1795 << constituents.front().getType();
1796 return success();
1797 }
1798
1799 if (constituents.size() == cType.getNumElements()) {
1800 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1801 if (constituents[index].getType() != cType.getElementType(index)) {
1802 return emitOpError("operand type mismatch: expected operand type ")
1803 << cType.getElementType(index) << ", but provided "
1804 << constituents[index].getType();
1805 }
1806 }
1807 return success();
1808 }
1809
1810 // If not constructing a cooperative matrix type, then we must be constructing
1811 // a vector type.
1812 auto resultType = cType.dyn_cast<VectorType>();
1813 if (!resultType)
1814 return emitOpError(
1815 "expected to return a vector or cooperative matrix when the number of "
1816 "constituents is less than what the result needs");
1817
1818 SmallVector<unsigned> sizes;
1819 for (Value component : constituents) {
1820 if (!component.getType().isa<VectorType>() &&
1821 !component.getType().isIntOrFloat())
1822 return emitOpError("operand type mismatch: expected operand to have "
1823 "a scalar or vector type, but provided ")
1824 << component.getType();
1825
1826 Type elementType = component.getType();
1827 if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1828 sizes.push_back(vectorType.getNumElements());
1829 elementType = vectorType.getElementType();
1830 } else {
1831 sizes.push_back(1);
1832 }
1833
1834 if (elementType != resultType.getElementType())
1835 return emitOpError("operand element type mismatch: expected to be ")
1836 << resultType.getElementType() << ", but provided " << elementType;
1837 }
1838 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1839 if (totalCount != cType.getNumElements())
1840 return emitOpError("has incorrect number of operands: expected ")
1841 << cType.getNumElements() << ", but provided " << totalCount;
1842 return success();
1843}
1844
1845//===----------------------------------------------------------------------===//
1846// spirv.CompositeExtractOp
1847//===----------------------------------------------------------------------===//
1848
1849void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1850 Value composite,
1851 ArrayRef<int32_t> indices) {
1852 auto indexAttr = builder.getI32ArrayAttr(indices);
1853 auto elementType =
1854 getElementType(composite.getType(), indexAttr, state.location);
1855 if (!elementType) {
1856 return;
1857 }
1858 build(builder, state, elementType, composite, indexAttr);
1859}
1860
1861ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1862 OperationState &result) {
1863 OpAsmParser::UnresolvedOperand compositeInfo;
1864 Attribute indicesAttr;
1865 Type compositeType;
1866 SMLoc attrLocation;
1867
1868 if (parser.parseOperand(compositeInfo) ||
1869 parser.getCurrentLocation(&attrLocation) ||
1870 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1871 parser.parseColonType(compositeType) ||
1872 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1873 return failure();
1874 }
1875
1876 Type resultType =
1877 getElementType(compositeType, indicesAttr, parser, attrLocation);
1878 if (!resultType) {
1879 return failure();
1880 }
1881 result.addTypes(resultType);
1882 return success();
1883}
1884
1885void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1886 printer << ' ' << getComposite() << getIndices() << " : "
1887 << getComposite().getType();
1888}
1889
1890LogicalResult spirv::CompositeExtractOp::verify() {
1891 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1892 auto resultType =
1893 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1894 if (!resultType)
1895 return failure();
1896
1897 if (resultType != getType()) {
1898 return emitOpError("invalid result type: expected ")
1899 << resultType << " but provided " << getType();
1900 }
1901
1902 return success();
1903}
1904
1905//===----------------------------------------------------------------------===//
1906// spirv.CompositeInsert
1907//===----------------------------------------------------------------------===//
1908
1909void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1910 Value object, Value composite,
1911 ArrayRef<int32_t> indices) {
1912 auto indexAttr = builder.getI32ArrayAttr(indices);
1913 build(builder, state, composite.getType(), object, composite, indexAttr);
1914}
1915
1916ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1917 OperationState &result) {
1918 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1919 Type objectType, compositeType;
1920 Attribute indicesAttr;
1921 auto loc = parser.getCurrentLocation();
1922
1923 return failure(
1924 parser.parseOperandList(operands, 2) ||
1925 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1926 parser.parseColonType(objectType) ||
1927 parser.parseKeywordType("into", compositeType) ||
1928 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1929 result.operands) ||
1930 parser.addTypesToList(compositeType, result.types));
1931}
1932
1933LogicalResult spirv::CompositeInsertOp::verify() {
1934 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1935 auto objectType =
1936 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1937 if (!objectType)
1938 return failure();
1939
1940 if (objectType != getObject().getType()) {
1941 return emitOpError("object operand type should be ")
1942 << objectType << ", but found " << getObject().getType();
1943 }
1944
1945 if (getComposite().getType() != getType()) {
1946 return emitOpError("result type should be the same as "
1947 "the composite type, but found ")
1948 << getComposite().getType() << " vs " << getType();
1949 }
1950
1951 return success();
1952}
1953
1954void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1955 printer << " " << getObject() << ", " << getComposite() << getIndices()
1956 << " : " << getObject().getType() << " into "
1957 << getComposite().getType();
1958}
1959
1960//===----------------------------------------------------------------------===//
1961// spirv.Constant
1962//===----------------------------------------------------------------------===//
1963
1964ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1965 OperationState &result) {
1966 Attribute value;
1967 if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1968 return failure();
1969
1970 Type type = NoneType::get(parser.getContext());
1971 if (auto typedAttr = value.dyn_cast<TypedAttr>())
1972 type = typedAttr.getType();
1973 if (type.isa<NoneType, TensorType>()) {
1974 if (parser.parseColonType(type))
1975 return failure();
1976 }
1977
1978 return parser.addTypeToList(type, result.types);
1979}
1980
1981void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1982 printer << ' ' << getValue();
1983 if (getType().isa<spirv::ArrayType>())
1984 printer << " : " << getType();
1985}
1986
1987static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1988 Type opType) {
1989 if (value.isa<IntegerAttr, FloatAttr>()) {
1990 auto valueType = value.cast<TypedAttr>().getType();
1991 if (valueType != opType)
1992 return op.emitOpError("result type (")
1993 << opType << ") does not match value type (" << valueType << ")";
1994 return success();
1995 }
1996 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1997 auto valueType = value.cast<TypedAttr>().getType();
1998 if (valueType == opType)
1999 return success();
2000 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
2001 auto shapedType = valueType.dyn_cast<ShapedType>();
2002 if (!arrayType)
2003 return op.emitOpError("result or element type (")
2004 << opType << ") does not match value type (" << valueType
2005 << "), must be the same or spirv.array";
2006
2007 int numElements = arrayType.getNumElements();
2008 auto opElemType = arrayType.getElementType();
2009 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
2010 numElements *= t.getNumElements();
2011 opElemType = t.getElementType();
2012 }
2013 if (!opElemType.isIntOrFloat())
2014 return op.emitOpError("only support nested array result type");
2015
2016 auto valueElemType = shapedType.getElementType();
2017 if (valueElemType != opElemType) {
2018 return op.emitOpError("result element type (")
2019 << opElemType << ") does not match value element type ("
2020 << valueElemType << ")";
2021 }
2022
2023 if (numElements != shapedType.getNumElements()) {
2024 return op.emitOpError("result number of elements (")
2025 << numElements << ") does not match value number of elements ("
2026 << shapedType.getNumElements() << ")";
2027 }
2028 return success();
2029 }
2030 if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
2031 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
2032 if (!arrayType)
2033 return op.emitOpError(
2034 "must have spirv.array result type for array value");
2035 Type elemType = arrayType.getElementType();
2036 for (Attribute element : arrayAttr.getValue()) {
2037 // Verify array elements recursively.
2038 if (failed(verifyConstantType(op, element, elemType)))
2039 return failure();
2040 }
2041 return success();
2042 }
2043 return op.emitOpError("cannot have attribute: ") << value;
2044}
2045
2046LogicalResult spirv::ConstantOp::verify() {
2047 // ODS already generates checks to make sure the result type is valid. We just
2048 // need to additionally check that the value's attribute type is consistent
2049 // with the result type.
2050 return verifyConstantType(*this, getValueAttr(), getType());
2051}
2052
2053bool spirv::ConstantOp::isBuildableWith(Type type) {
2054 // Must be valid SPIR-V type first.
2055 if (!type.isa<spirv::SPIRVType>())
2056 return false;
2057
2058 if (isa<SPIRVDialect>(type.getDialect())) {
2059 // TODO: support constant struct
2060 return type.isa<spirv::ArrayType>();
2061 }
2062
2063 return true;
2064}
2065
2066spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
2067 OpBuilder &builder) {
2068 if (auto intType = type.dyn_cast<IntegerType>()) {
2069 unsigned width = intType.getWidth();
2070 if (width == 1)
2071 return builder.create<spirv::ConstantOp>(loc, type,
2072 builder.getBoolAttr(false));
2073 return builder.create<spirv::ConstantOp>(
2074 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
2075 }
2076 if (auto floatType = type.dyn_cast<FloatType>()) {
2077 return builder.create<spirv::ConstantOp>(
2078 loc, type, builder.getFloatAttr(floatType, 0.0));
2079 }
2080 if (auto vectorType = type.dyn_cast<VectorType>()) {
2081 Type elemType = vectorType.getElementType();
2082 if (elemType.isa<IntegerType>()) {
2083 return builder.create<spirv::ConstantOp>(
2084 loc, type,
2085 DenseElementsAttr::get(vectorType,
2086 IntegerAttr::get(elemType, 0).getValue()));
2087 }
2088 if (elemType.isa<FloatType>()) {
2089 return builder.create<spirv::ConstantOp>(
2090 loc, type,
2091 DenseFPElementsAttr::get(vectorType,
2092 FloatAttr::get(elemType, 0.0).getValue()));
2093 }
2094 }
2095
2096 llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2096)
;
2097}
2098
2099spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
2100 OpBuilder &builder) {
2101 if (auto intType = type.dyn_cast<IntegerType>()) {
2102 unsigned width = intType.getWidth();
2103 if (width == 1)
2104 return builder.create<spirv::ConstantOp>(loc, type,
2105 builder.getBoolAttr(true));
2106 return builder.create<spirv::ConstantOp>(
2107 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
2108 }
2109 if (auto floatType = type.dyn_cast<FloatType>()) {
2110 return builder.create<spirv::ConstantOp>(
2111 loc, type, builder.getFloatAttr(floatType, 1.0));
2112 }
2113 if (auto vectorType = type.dyn_cast<VectorType>()) {
2114 Type elemType = vectorType.getElementType();
2115 if (elemType.isa<IntegerType>()) {
2116 return builder.create<spirv::ConstantOp>(
2117 loc, type,
2118 DenseElementsAttr::get(vectorType,
2119 IntegerAttr::get(elemType, 1).getValue()));
2120 }
2121 if (elemType.isa<FloatType>()) {
2122 return builder.create<spirv::ConstantOp>(
2123 loc, type,
2124 DenseFPElementsAttr::get(vectorType,
2125 FloatAttr::get(elemType, 1.0).getValue()));
2126 }
2127 }
2128
2129 llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2129)
;
2130}
2131
2132void mlir::spirv::ConstantOp::getAsmResultNames(
2133 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2134 Type type = getType();
2135
2136 SmallString<32> specialNameBuffer;
2137 llvm::raw_svector_ostream specialName(specialNameBuffer);
2138 specialName << "cst";
2139
2140 IntegerType intTy = type.dyn_cast<IntegerType>();
2141
2142 if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2143 if (intTy && intTy.getWidth() == 1) {
2144 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2145 }
2146
2147 if (intTy.isSignless()) {
2148 specialName << intCst.getInt();
2149 } else {
2150 specialName << intCst.getSInt();
2151 }
2152 }
2153
2154 if (intTy || type.isa<FloatType>()) {
2155 specialName << '_' << type;
2156 }
2157
2158 if (auto vecType = type.dyn_cast<VectorType>()) {
2159 specialName << "_vec_";
2160 specialName << vecType.getDimSize(0);
2161
2162 Type elementType = vecType.getElementType();
2163
2164 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2165 specialName << "x" << elementType;
2166 }
2167 }
2168
2169 setNameFn(getResult(), specialName.str());
2170}
2171
2172void mlir::spirv::AddressOfOp::getAsmResultNames(
2173 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2174 SmallString<32> specialNameBuffer;
2175 llvm::raw_svector_ostream specialName(specialNameBuffer);
2176 specialName << getVariable() << "_addr";
2177 setNameFn(getResult(), specialName.str());
2178}
2179
2180//===----------------------------------------------------------------------===//
2181// spirv.ControlBarrierOp
2182//===----------------------------------------------------------------------===//
2183
2184LogicalResult spirv::ControlBarrierOp::verify() {
2185 return verifyMemorySemantics(getOperation(), getMemorySemantics());
2186}
2187
2188//===----------------------------------------------------------------------===//
2189// spirv.ConvertFToSOp
2190//===----------------------------------------------------------------------===//
2191
2192LogicalResult spirv::ConvertFToSOp::verify() {
2193 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2194 /*skipBitWidthCheck=*/true);
2195}
2196
2197//===----------------------------------------------------------------------===//
2198// spirv.ConvertFToUOp
2199//===----------------------------------------------------------------------===//
2200
2201LogicalResult spirv::ConvertFToUOp::verify() {
2202 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2203 /*skipBitWidthCheck=*/true);
2204}
2205
2206//===----------------------------------------------------------------------===//
2207// spirv.ConvertSToFOp
2208//===----------------------------------------------------------------------===//
2209
2210LogicalResult spirv::ConvertSToFOp::verify() {
2211 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2212 /*skipBitWidthCheck=*/true);
2213}
2214
2215//===----------------------------------------------------------------------===//
2216// spirv.ConvertUToFOp
2217//===----------------------------------------------------------------------===//
2218
2219LogicalResult spirv::ConvertUToFOp::verify() {
2220 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2221 /*skipBitWidthCheck=*/true);
2222}
2223
2224//===----------------------------------------------------------------------===//
2225// spirv.EntryPoint
2226//===----------------------------------------------------------------------===//
2227
2228void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2229 spirv::ExecutionModel executionModel,
2230 spirv::FuncOp function,
2231 ArrayRef<Attribute> interfaceVars) {
2232 build(builder, state,
2233 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2234 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2235}
2236
2237ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2238 OperationState &result) {
2239 spirv::ExecutionModel execModel;
2240 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2241 SmallVector<Type, 0> idTypes;
2242 SmallVector<Attribute, 4> interfaceVars;
2243
2244 FlatSymbolRefAttr fn;
2245 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2246 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2247 return failure();
2248 }
2249
2250 if (!parser.parseOptionalComma()) {
2251 // Parse the interface variables
2252 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2253 // The name of the interface variable attribute isnt important
2254 FlatSymbolRefAttr var;
2255 NamedAttrList attrs;
2256 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2257 return failure();
2258 interfaceVars.push_back(var);
2259 return success();
2260 }))
2261 return failure();
2262 }
2263 result.addAttribute(kInterfaceAttrName,
2264 parser.getBuilder().getArrayAttr(interfaceVars));
2265 return success();
2266}
2267
2268void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2269 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
2270 printer.printSymbolName(getFn());
2271 auto interfaceVars = getInterface().getValue();
2272 if (!interfaceVars.empty()) {
2273 printer << ", ";
2274 llvm::interleaveComma(interfaceVars, printer);
2275 }
2276}
2277
2278LogicalResult spirv::EntryPointOp::verify() {
2279 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2280 // verification.
2281 return success();
2282}
2283
2284//===----------------------------------------------------------------------===//
2285// spirv.ExecutionMode
2286//===----------------------------------------------------------------------===//
2287
2288void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2289 spirv::FuncOp function,
2290 spirv::ExecutionMode executionMode,
2291 ArrayRef<int32_t> params) {
2292 build(builder, state, SymbolRefAttr::get(function),
2293 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2294 builder.getI32ArrayAttr(params));
2295}
2296
2297ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2298 OperationState &result) {
2299 spirv::ExecutionMode execMode;
2300 Attribute fn;
2301 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2302 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2303 return failure();
2304 }
2305
2306 SmallVector<int32_t, 4> values;
2307 Type i32Type = parser.getBuilder().getIntegerType(32);
2308 while (!parser.parseOptionalComma()) {
2309 NamedAttrList attr;
2310 Attribute value;
2311 if (parser.parseAttribute(value, i32Type, "value", attr)) {
2312 return failure();
2313 }
2314 values.push_back(value.cast<IntegerAttr>().getInt());
2315 }
2316 result.addAttribute(kValuesAttrName,
2317 parser.getBuilder().getI32ArrayAttr(values));
2318 return success();
2319}
2320
2321void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2322 printer << " ";
2323 printer.printSymbolName(getFn());
2324 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
2325 auto values = this->getValues();
2326 if (values.empty())
2327 return;
2328 printer << ", ";
2329 llvm::interleaveComma(values, printer, [&](Attribute a) {
2330 printer << a.cast<IntegerAttr>().getInt();
2331 });
2332}
2333
2334//===----------------------------------------------------------------------===//
2335// spirv.FConvertOp
2336//===----------------------------------------------------------------------===//
2337
2338LogicalResult spirv::FConvertOp::verify() {
2339 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2340}
2341
2342//===----------------------------------------------------------------------===//
2343// spirv.SConvertOp
2344//===----------------------------------------------------------------------===//
2345
2346LogicalResult spirv::SConvertOp::verify() {
2347 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2348}
2349
2350//===----------------------------------------------------------------------===//
2351// spirv.UConvertOp
2352//===----------------------------------------------------------------------===//
2353
2354LogicalResult spirv::UConvertOp::verify() {
2355 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2356}
2357
2358//===----------------------------------------------------------------------===//
2359// spirv.func
2360//===----------------------------------------------------------------------===//
2361
2362ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2363 SmallVector<OpAsmParser::Argument> entryArgs;
2364 SmallVector<DictionaryAttr> resultAttrs;
2365 SmallVector<Type> resultTypes;
2366 auto &builder = parser.getBuilder();
2367
2368 // Parse the name as a symbol.
2369 StringAttr nameAttr;
2370 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2371 result.attributes))
2372 return failure();
2373
2374 // Parse the function signature.
2375 bool isVariadic = false;
2376 if (function_interface_impl::parseFunctionSignature(
2377 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2378 resultAttrs))
2379 return failure();
2380
2381 SmallVector<Type> argTypes;
2382 for (auto &arg : entryArgs)
2383 argTypes.push_back(arg.type);
2384 auto fnType = builder.getFunctionType(argTypes, resultTypes);
2385 result.addAttribute(getFunctionTypeAttrName(result.name),
2386 TypeAttr::get(fnType));
2387
2388 // Parse the optional function control keyword.
2389 spirv::FunctionControl fnControl;
2390 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2391 return failure();
2392
2393 // If additional attributes are present, parse them.
2394 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2395 return failure();
2396
2397 // Add the attributes to the function arguments.
2398 assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes.
size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2398, __extension__
__PRETTY_FUNCTION__))
;
2399 function_interface_impl::addArgAndResultAttrs(
2400 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
2401 getResAttrsAttrName(result.name));
2402
2403 // Parse the optional function body.
2404 auto *body = result.addRegion();
2405 OptionalParseResult parseResult =
2406 parser.parseOptionalRegion(*body, entryArgs);
2407 return failure(parseResult.has_value() && failed(*parseResult));
2408}
2409
2410void spirv::FuncOp::print(OpAsmPrinter &printer) {
2411 // Print function name, signature, and control.
2412 printer << " ";
2413 printer.printSymbolName(getSymName());
2414 auto fnType = getFunctionType();
2415 function_interface_impl::printFunctionSignature(
2416 printer, *this, fnType.getInputs(),
2417 /*isVariadic=*/false, fnType.getResults());
2418 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
2419 << "\"";
2420 function_interface_impl::printFunctionAttributes(
2421 printer, *this,
2422 {spirv::attributeName<spirv::FunctionControl>(),
2423 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2424 getFunctionControlAttrName()});
2425
2426 // Print the body if this is not an external function.
2427 Region &body = this->getBody();
2428 if (!body.empty()) {
2429 printer << ' ';
2430 printer.printRegion(body, /*printEntryBlockArgs=*/false,
2431 /*printBlockTerminators=*/true);
2432 }
2433}
2434
2435LogicalResult spirv::FuncOp::verifyType() {
2436 if (getFunctionType().getNumResults() > 1)
2437 return emitOpError("cannot have more than one result");
2438 return success();
2439}
2440
2441LogicalResult spirv::FuncOp::verifyBody() {
2442 FunctionType fnType = getFunctionType();
2443
2444 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2445 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2446 if (fnType.getNumResults() != 0)
2447 return retOp.emitOpError("cannot be used in functions returning value");
2448 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2449 if (fnType.getNumResults() != 1)
2450 return retOp.emitOpError(
2451 "returns 1 value but enclosing function requires ")
2452 << fnType.getNumResults() << " results";
2453
2454 auto retOperandType = retOp.getValue().getType();
2455 auto fnResultType = fnType.getResult(0);
2456 if (retOperandType != fnResultType)
2457 return retOp.emitOpError(" return value's type (")
2458 << retOperandType << ") mismatch with function's result type ("
2459 << fnResultType << ")";
2460 }
2461 return WalkResult::advance();
2462 });
2463
2464 // TODO: verify other bits like linkage type.
2465
2466 return failure(walkResult.wasInterrupted());
2467}
2468
2469void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2470 StringRef name, FunctionType type,
2471 spirv::FunctionControl control,
2472 ArrayRef<NamedAttribute> attrs) {
2473 state.addAttribute(SymbolTable::getSymbolAttrName(),
2474 builder.getStringAttr(name));
2475 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
2476 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2477 builder.getAttr<spirv::FunctionControlAttr>(control));
2478 state.attributes.append(attrs.begin(), attrs.end());
2479 state.addRegion();
2480}
2481
2482// CallableOpInterface
2483Region *spirv::FuncOp::getCallableRegion() {
2484 return isExternal() ? nullptr : &getBody();
2485}
2486
2487// CallableOpInterface
2488ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2489 return getFunctionType().getResults();
2490}
2491
2492//===----------------------------------------------------------------------===//
2493// spirv.FunctionCall
2494//===----------------------------------------------------------------------===//
2495
2496LogicalResult spirv::FunctionCallOp::verify() {
2497 auto fnName = getCalleeAttr();
2498
2499 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2500 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2501 if (!funcOp) {
2502 return emitOpError("callee function '")
2503 << fnName.getValue() << "' not found in nearest symbol table";
2504 }
2505
2506 auto functionType = funcOp.getFunctionType();
2507
2508 if (getNumResults() > 1) {
2509 return emitOpError(
2510 "expected callee function to have 0 or 1 result, but provided ")
2511 << getNumResults();
2512 }
2513
2514 if (functionType.getNumInputs() != getNumOperands()) {
2515 return emitOpError("has incorrect number of operands for callee: expected ")
2516 << functionType.getNumInputs() << ", but provided "
2517 << getNumOperands();
2518 }
2519
2520 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2521 if (getOperand(i).getType() != functionType.getInput(i)) {
2522 return emitOpError("operand type mismatch: expected operand type ")
2523 << functionType.getInput(i) << ", but provided "
2524 << getOperand(i).getType() << " for operand number " << i;
2525 }
2526 }
2527
2528 if (functionType.getNumResults() != getNumResults()) {
2529 return emitOpError(
2530 "has incorrect number of results has for callee: expected ")
2531 << functionType.getNumResults() << ", but provided "
2532 << getNumResults();
2533 }
2534
2535 if (getNumResults() &&
2536 (getResult(0).getType() != functionType.getResult(0))) {
2537 return emitOpError("result type mismatch: expected ")
2538 << functionType.getResult(0) << ", but provided "
2539 << getResult(0).getType();
2540 }
2541
2542 return success();
2543}
2544
2545CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2546 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2547}
2548
2549Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2550 return getArguments();
2551}
2552
2553//===----------------------------------------------------------------------===//
2554// spirv.GLFClampOp
2555//===----------------------------------------------------------------------===//
2556
2557ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2558 OperationState &result) {
2559 return parseOneResultSameOperandTypeOp(parser, result);
2560}
2561void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2562
2563//===----------------------------------------------------------------------===//
2564// spirv.GLUClampOp
2565//===----------------------------------------------------------------------===//
2566
2567ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2568 OperationState &result) {
2569 return parseOneResultSameOperandTypeOp(parser, result);
2570}
2571void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2572
2573//===----------------------------------------------------------------------===//
2574// spirv.GLSClampOp
2575//===----------------------------------------------------------------------===//
2576
2577ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2578 OperationState &result) {
2579 return parseOneResultSameOperandTypeOp(parser, result);
2580}
2581void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2582
2583//===----------------------------------------------------------------------===//
2584// spirv.GLFmaOp
2585//===----------------------------------------------------------------------===//
2586
2587ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2588 return parseOneResultSameOperandTypeOp(parser, result);
2589}
2590void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2591
2592//===----------------------------------------------------------------------===//
2593// spirv.GlobalVariable
2594//===----------------------------------------------------------------------===//
2595
2596void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2597 Type type, StringRef name,
2598 unsigned descriptorSet, unsigned binding) {
2599 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2600 state.addAttribute(
2601 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2602 builder.getI32IntegerAttr(descriptorSet));
2603 state.addAttribute(
2604 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2605 builder.getI32IntegerAttr(binding));
2606}
2607
2608void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2609 Type type, StringRef name,
2610 spirv::BuiltIn builtin) {
2611 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2612 state.addAttribute(
2613 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2614 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2615}
2616
2617ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2618 OperationState &result) {
2619 // Parse variable name.
2620 StringAttr nameAttr;
2621 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2622 result.attributes)) {
2623 return failure();
2624 }
2625
2626 // Parse optional initializer
2627 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2628 FlatSymbolRefAttr initSymbol;
2629 if (parser.parseLParen() ||
2630 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2631 result.attributes) ||
2632 parser.parseRParen())
2633 return failure();
2634 }
2635
2636 if (parseVariableDecorations(parser, result)) {
2637 return failure();
2638 }
2639
2640 Type type;
2641 auto loc = parser.getCurrentLocation();
2642 if (parser.parseColonType(type)) {
2643 return failure();
2644 }
2645 if (!type.isa<spirv::PointerType>()) {
2646 return parser.emitError(loc, "expected spirv.ptr type");
2647 }
2648 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2649
2650 return success();
2651}
2652
2653void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2654 SmallVector<StringRef, 4> elidedAttrs{
2655 spirv::attributeName<spirv::StorageClass>()};
2656
2657 // Print variable name.
2658 printer << ' ';
2659 printer.printSymbolName(getSymName());
2660 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2661
2662 // Print optional initializer
2663 if (auto initializer = this->getInitializer()) {
2664 printer << " " << kInitializerAttrName << '(';
2665 printer.printSymbolName(*initializer);
2666 printer << ')';
2667 elidedAttrs.push_back(kInitializerAttrName);
2668 }
2669
2670 elidedAttrs.push_back(kTypeAttrName);
2671 printVariableDecorations(*this, printer, elidedAttrs);
2672 printer << " : " << getType();
2673}
2674
2675LogicalResult spirv::GlobalVariableOp::verify() {
2676 if (!getType().isa<spirv::PointerType>())
2677 return emitOpError("result must be of a !spv.ptr type");
2678
2679 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2680 // object. It cannot be Generic. It must be the same as the Storage Class
2681 // operand of the Result Type."
2682 // Also, Function storage class is reserved by spirv.Variable.
2683 auto storageClass = this->storageClass();
2684 if (storageClass == spirv::StorageClass::Generic ||
2685 storageClass == spirv::StorageClass::Function) {
2686 return emitOpError("storage class cannot be '")
2687 << stringifyStorageClass(storageClass) << "'";
2688 }
2689
2690 if (auto init =
2691 (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2692 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2693 (*this)->getParentOp(), init.getAttr());
2694 // TODO: Currently only variable initialization with specialization
2695 // constants and other variables is supported. They could be normal
2696 // constants in the module scope as well.
2697 if (!initOp ||
2698 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2699 return emitOpError("initializer must be result of a "
2700 "spirv.SpecConstant or spirv.GlobalVariable op");
2701 }
2702 }
2703
2704 return success();
2705}
2706
2707//===----------------------------------------------------------------------===//
2708// spirv.GroupBroadcast
2709//===----------------------------------------------------------------------===//
2710
2711LogicalResult spirv::GroupBroadcastOp::verify() {
2712 spirv::Scope scope = getExecutionScope();
2713 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2714 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2715
2716 if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2717 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2718 return emitOpError("localid is a vector and can be with only "
2719 " 2 or 3 components, actual number is ")
2720 << localIdTy.getNumElements();
2721
2722 return success();
2723}
2724
2725//===----------------------------------------------------------------------===//
2726// spirv.GroupNonUniformBallotOp
2727//===----------------------------------------------------------------------===//
2728
2729LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2730 spirv::Scope scope = getExecutionScope();
2731 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2732 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2733
2734 return success();
2735}
2736
2737//===----------------------------------------------------------------------===//
2738// spirv.GroupNonUniformBroadcast
2739//===----------------------------------------------------------------------===//
2740
2741LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2742 spirv::Scope scope = getExecutionScope();
2743 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2744 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2745
2746 // SPIR-V spec: "Before version 1.5, Id must come from a
2747 // constant instruction.
2748 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2749 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2750 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2751
2752 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2753 auto *idOp = getId().getDefiningOp();
2754 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2755 spirv::ReferenceOfOp>(idOp)) // for spec constant
2756 return emitOpError("id must be the result of a constant op");
2757 }
2758
2759 return success();
2760}
2761
2762//===----------------------------------------------------------------------===//
2763// spirv.GroupNonUniformShuffle*
2764//===----------------------------------------------------------------------===//
2765
2766template <typename OpTy>
2767static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
2768 spirv::Scope scope = op.getExecutionScope();
2769 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2770 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2771
2772 if (op.getOperands().back().getType().isSignedInteger())
2773 return op.emitOpError("second operand must be a singless/unsigned integer");
2774
2775 return success();
2776}
2777
2778LogicalResult spirv::GroupNonUniformShuffleOp::verify() {
2779 return verifyGroupNonUniformShuffleOp(*this);
2780}
2781LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() {
2782 return verifyGroupNonUniformShuffleOp(*this);
2783}
2784LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() {
2785 return verifyGroupNonUniformShuffleOp(*this);
2786}
2787LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
2788 return verifyGroupNonUniformShuffleOp(*this);
2789}
2790
2791//===----------------------------------------------------------------------===//
2792// spirv.INTEL.SubgroupBlockRead
2793//===----------------------------------------------------------------------===//
2794
2795ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
2796 OperationState &result) {
2797 // Parse the storage class specification
2798 spirv::StorageClass storageClass;
2799 OpAsmParser::UnresolvedOperand ptrInfo;
2800 Type elementType;
2801 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2802 parser.parseColon() || parser.parseType(elementType)) {
2803 return failure();
2804 }
2805
2806 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2807 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2808 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2809
2810 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2811 return failure();
2812 }
2813
2814 result.addTypes(elementType);
2815 return success();
2816}
2817
2818void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
2819 printer << " " << getPtr() << " : " << getType();
2820}
2821
2822LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
2823 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2824 return failure();
2825
2826 return success();
2827}
2828
2829//===----------------------------------------------------------------------===//
2830// spirv.INTEL.SubgroupBlockWrite
2831//===----------------------------------------------------------------------===//
2832
2833ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
2834 OperationState &result) {
2835 // Parse the storage class specification
2836 spirv::StorageClass storageClass;
2837 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2838 auto loc = parser.getCurrentLocation();
2839 Type elementType;
2840 if (parseEnumStrAttr(storageClass, parser) ||
2841 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2842 parser.parseType(elementType)) {
2843 return failure();
2844 }
2845
2846 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2847 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2848 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2849
2850 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2851 result.operands)) {
2852 return failure();
2853 }
2854 return success();
2855}
2856
2857void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
2858 printer << " " << getPtr() << ", " << getValue() << " : "
2859 << getValue().getType();
2860}
2861
2862LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
2863 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2864 return failure();
2865
2866 return success();
2867}
2868
2869//===----------------------------------------------------------------------===//
2870// spirv.GroupNonUniformElectOp
2871//===----------------------------------------------------------------------===//
2872
2873LogicalResult spirv::GroupNonUniformElectOp::verify() {
2874 spirv::Scope scope = getExecutionScope();
2875 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2876 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2877
2878 return success();
2879}
2880
2881//===----------------------------------------------------------------------===//
2882// spirv.GroupNonUniformFAddOp
2883//===----------------------------------------------------------------------===//
2884
2885LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2886 return verifyGroupNonUniformArithmeticOp(*this);
2887}
2888
2889ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2890 OperationState &result) {
2891 return parseGroupNonUniformArithmeticOp(parser, result);
2892}
2893void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2894 printGroupNonUniformArithmeticOp(*this, p);
2895}
2896
2897//===----------------------------------------------------------------------===//
2898// spirv.GroupNonUniformFMaxOp
2899//===----------------------------------------------------------------------===//
2900
2901LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2902 return verifyGroupNonUniformArithmeticOp(*this);
2903}
2904
2905ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2906 OperationState &result) {
2907 return parseGroupNonUniformArithmeticOp(parser, result);
2908}
2909void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2910 printGroupNonUniformArithmeticOp(*this, p);
2911}
2912
2913//===----------------------------------------------------------------------===//
2914// spirv.GroupNonUniformFMinOp
2915//===----------------------------------------------------------------------===//
2916
2917LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2918 return verifyGroupNonUniformArithmeticOp(*this);
2919}
2920
2921ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2922 OperationState &result) {
2923 return parseGroupNonUniformArithmeticOp(parser, result);
2924}
2925void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2926 printGroupNonUniformArithmeticOp(*this, p);
2927}
2928
2929//===----------------------------------------------------------------------===//
2930// spirv.GroupNonUniformFMulOp
2931//===----------------------------------------------------------------------===//
2932
2933LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2934 return verifyGroupNonUniformArithmeticOp(*this);
2935}
2936
2937ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2938 OperationState &result) {
2939 return parseGroupNonUniformArithmeticOp(parser, result);
2940}
2941void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2942 printGroupNonUniformArithmeticOp(*this, p);
2943}
2944
2945//===----------------------------------------------------------------------===//
2946// spirv.GroupNonUniformIAddOp
2947//===----------------------------------------------------------------------===//
2948
2949LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2950 return verifyGroupNonUniformArithmeticOp(*this);
2951}
2952
2953ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2954 OperationState &result) {
2955 return parseGroupNonUniformArithmeticOp(parser, result);
2956}
2957void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2958 printGroupNonUniformArithmeticOp(*this, p);
2959}
2960
2961//===----------------------------------------------------------------------===//
2962// spirv.GroupNonUniformIMulOp
2963//===----------------------------------------------------------------------===//
2964
2965LogicalResult spirv::GroupNonUniformIMulOp::verify() {
2966 return verifyGroupNonUniformArithmeticOp(*this);
2967}
2968
2969ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2970 OperationState &result) {
2971 return parseGroupNonUniformArithmeticOp(parser, result);
2972}
2973void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
2974 printGroupNonUniformArithmeticOp(*this, p);
2975}
2976
2977//===----------------------------------------------------------------------===//
2978// spirv.GroupNonUniformSMaxOp
2979//===----------------------------------------------------------------------===//
2980
2981LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
2982 return verifyGroupNonUniformArithmeticOp(*this);
2983}
2984
2985ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2986 OperationState &result) {
2987 return parseGroupNonUniformArithmeticOp(parser, result);
2988}
2989void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
2990 printGroupNonUniformArithmeticOp(*this, p);
2991}
2992
2993//===----------------------------------------------------------------------===//
2994// spirv.GroupNonUniformSMinOp
2995//===----------------------------------------------------------------------===//
2996
2997LogicalResult spirv::GroupNonUniformSMinOp::verify() {
2998 return verifyGroupNonUniformArithmeticOp(*this);
2999}
3000
3001ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
3002 OperationState &result) {
3003 return parseGroupNonUniformArithmeticOp(parser, result);
3004}
3005void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
3006 printGroupNonUniformArithmeticOp(*this, p);
3007}
3008
3009//===----------------------------------------------------------------------===//
3010// spirv.GroupNonUniformUMaxOp
3011//===----------------------------------------------------------------------===//
3012
3013LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
3014 return verifyGroupNonUniformArithmeticOp(*this);
3015}
3016
3017ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
3018 OperationState &result) {
3019 return parseGroupNonUniformArithmeticOp(parser, result);
3020}
3021void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
3022 printGroupNonUniformArithmeticOp(*this, p);
3023}
3024
3025//===----------------------------------------------------------------------===//
3026// spirv.GroupNonUniformUMinOp
3027//===----------------------------------------------------------------------===//
3028
3029LogicalResult spirv::GroupNonUniformUMinOp::verify() {
3030 return verifyGroupNonUniformArithmeticOp(*this);
3031}
3032
3033ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
3034 OperationState &result) {
3035 return parseGroupNonUniformArithmeticOp(parser, result);
3036}
3037void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
3038 printGroupNonUniformArithmeticOp(*this, p);
3039}
3040
3041//===----------------------------------------------------------------------===//
3042// spirv.IAddCarryOp
3043//===----------------------------------------------------------------------===//
3044
3045LogicalResult spirv::IAddCarryOp::verify() {
3046 return ::verifyArithmeticExtendedBinaryOp(*this);
3047}
3048
3049ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
3050 OperationState &result) {
3051 return ::parseArithmeticExtendedBinaryOp(parser, result);
3052}
3053
3054void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
3055 ::printArithmeticExtendedBinaryOp(*this, printer);
3056}
3057
3058//===----------------------------------------------------------------------===//
3059// spirv.ISubBorrowOp
3060//===----------------------------------------------------------------------===//
3061
3062LogicalResult spirv::ISubBorrowOp::verify() {
3063 return ::verifyArithmeticExtendedBinaryOp(*this);
3064}
3065
3066ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
3067 OperationState &result) {
3068 return ::parseArithmeticExtendedBinaryOp(parser, result);
3069}
3070
3071void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
3072 ::printArithmeticExtendedBinaryOp(*this, printer);
3073}
3074
3075//===----------------------------------------------------------------------===//
3076// spirv.SMulExtended
3077//===----------------------------------------------------------------------===//
3078
3079LogicalResult spirv::SMulExtendedOp::verify() {
3080 return ::verifyArithmeticExtendedBinaryOp(*this);
3081}
3082
3083ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
3084 OperationState &result) {
3085 return ::parseArithmeticExtendedBinaryOp(parser, result);
3086}
3087
3088void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
3089 ::printArithmeticExtendedBinaryOp(*this, printer);
3090}
3091
3092//===----------------------------------------------------------------------===//
3093// spirv.UMulExtended
3094//===----------------------------------------------------------------------===//
3095
3096LogicalResult spirv::UMulExtendedOp::verify() {
3097 return ::verifyArithmeticExtendedBinaryOp(*this);
3098}
3099
3100ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
3101 OperationState &result) {
3102 return ::parseArithmeticExtendedBinaryOp(parser, result);
3103}
3104
3105void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
3106 ::printArithmeticExtendedBinaryOp(*this, printer);
3107}
3108
3109//===----------------------------------------------------------------------===//
3110// spirv.LoadOp
3111//===----------------------------------------------------------------------===//
3112
3113void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
3114 Value basePtr, MemoryAccessAttr memoryAccess,
3115 IntegerAttr alignment) {
3116 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3117 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3118 alignment);
3119}
3120
3121ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3122 // Parse the storage class specification
3123 spirv::StorageClass storageClass;
3124 OpAsmParser::UnresolvedOperand ptrInfo;
3125 Type elementType;
3126 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3127 parseMemoryAccessAttributes(parser, result) ||
3128 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3129 parser.parseType(elementType)) {
3130 return failure();
3131 }
3132
3133 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3134 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3135 return failure();
3136 }
3137
3138 result.addTypes(elementType);
3139 return success();
3140}
3141
3142void spirv::LoadOp::print(OpAsmPrinter &printer) {
3143 SmallVector<StringRef, 4> elidedAttrs;
3144 StringRef sc = stringifyStorageClass(
3145 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3146 printer << " \"" << sc << "\" " << getPtr();
3147
3148 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3149
3150 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3151 printer << " : " << getType();
3152}
3153
3154LogicalResult spirv::LoadOp::verify() {
3155 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3156 // type with fixed size; i.e., it cannot be, nor include, any
3157 // OpTypeRuntimeArray types."
3158 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
3159 return failure();
3160 }
3161 return verifyMemoryAccessAttribute(*this);
3162}
3163
3164//===----------------------------------------------------------------------===//
3165// spirv.mlir.loop
3166//===----------------------------------------------------------------------===//
3167
3168void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3169 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3170 spirv::LoopControl::None));
3171 state.addRegion();
3172}
3173
3174ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3175 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3176 result))
3177 return failure();
3178 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3179}
3180
3181void spirv::LoopOp::print(OpAsmPrinter &printer) {
3182 auto control = getLoopControl();
3183 if (control != spirv::LoopControl::None)
3184 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3185 printer << ' ';
3186 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3187 /*printBlockTerminators=*/true);
3188}
3189
3190/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
3191/// given `dstBlock`.
3192static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3193 // Check that there is only one op in the `srcBlock`.
3194 if (!llvm::hasSingleElement(srcBlock))
3195 return false;
3196
3197 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3198 return branchOp && branchOp.getSuccessor() == &dstBlock;
3199}
3200
3201LogicalResult spirv::LoopOp::verifyRegions() {
3202 auto *op = getOperation();
3203
3204 // We need to verify that the blocks follow the following layout:
3205 //
3206 // +-------------+
3207 // | entry block |
3208 // +-------------+
3209 // |
3210 // v
3211 // +-------------+
3212 // | loop header | <-----+
3213 // +-------------+ |
3214 // |
3215 // ... |
3216 // \ | / |
3217 // v |
3218 // +---------------+ |
3219 // | loop continue | -----+
3220 // +---------------+
3221 //
3222 // ...
3223 // \ | /
3224 // v
3225 // +-------------+
3226 // | merge block |
3227 // +-------------+
3228
3229 auto &region = op->getRegion(0);
3230 // Allow empty region as a degenerated case, which can come from
3231 // optimizations.
3232 if (region.empty())
3233 return success();
3234
3235 // The last block is the merge block.
3236 Block &merge = region.back();
3237 if (!isMergeBlock(merge))
3238 return emitOpError("last block must be the merge block with only one "
3239 "'spirv.mlir.merge' op");
3240
3241 if (std::next(region.begin()) == region.end())
3242 return emitOpError(
3243 "must have an entry block branching to the loop header block");
3244 // The first block is the entry block.
3245 Block &entry = region.front();
3246
3247 if (std::next(region.begin(), 2) == region.end())
3248 return emitOpError(
3249 "must have a loop header block branched from the entry block");
3250 // The second block is the loop header block.
3251 Block &header = *std::next(region.begin(), 1);
3252
3253 if (!hasOneBranchOpTo(entry, header))
3254 return emitOpError(
3255 "entry block must only have one 'spirv.Branch' op to the second block");
3256
3257 if (std::next(region.begin(), 3) == region.end())
3258 return emitOpError(
3259 "requires a loop continue block branching to the loop header block");
3260 // The second to last block is the loop continue block.
3261 Block &cont = *std::prev(region.end(), 2);
3262
3263 // Make sure that we have a branch from the loop continue block to the loop
3264 // header block.
3265 if (llvm::none_of(
3266 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3267 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3268 return emitOpError("second to last block must be the loop continue "
3269 "block that branches to the loop header block");
3270
3271 // Make sure that no other blocks (except the entry and loop continue block)
3272 // branches to the loop header block.
3273 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3274 std::prev(region.end(), 2))) {
3275 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3276 if (block.getSuccessor(i) == &header) {
3277 return emitOpError("can only have the entry and loop continue "
3278 "block branching to the loop header block");
3279 }
3280 }
3281 }
3282
3283 return success();
3284}
3285
3286Block *spirv::LoopOp::getEntryBlock() {
3287 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3287, __extension__
__PRETTY_FUNCTION__))
;
3288 return &getBody().front();
3289}
3290
3291Block *spirv::LoopOp::getHeaderBlock() {
3292 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3292, __extension__
__PRETTY_FUNCTION__))
;
3293 // The second block is the loop header block.
3294 return &*std::next(getBody().begin());
3295}
3296
3297Block *spirv::LoopOp::getContinueBlock() {
3298 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3298, __extension__
__PRETTY_FUNCTION__))
;
3299 // The second to last block is the loop continue block.
3300 return &*std::prev(getBody().end(), 2);
3301}
3302
3303Block *spirv::LoopOp::getMergeBlock() {
3304 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3304, __extension__
__PRETTY_FUNCTION__))
;
3305 // The last block is the loop merge block.
3306 return &getBody().back();
3307}
3308
3309void spirv::LoopOp::addEntryAndMergeBlock() {
3310 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3310, __extension__
__PRETTY_FUNCTION__))
;
3311 getBody().push_back(new Block());
3312 auto *mergeBlock = new Block();
3313 getBody().push_back(mergeBlock);
3314 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3315
3316 // Add a spirv.mlir.merge op into the merge block.
3317 builder.create<spirv::MergeOp>(getLoc());
3318}
3319
3320//===----------------------------------------------------------------------===//
3321// spirv.MemoryBarrierOp
3322//===----------------------------------------------------------------------===//
3323
3324LogicalResult spirv::MemoryBarrierOp::verify() {
3325 return verifyMemorySemantics(getOperation(), getMemorySemantics());
3326}
3327
3328//===----------------------------------------------------------------------===//
3329// spirv.mlir.merge
3330//===----------------------------------------------------------------------===//
3331
3332LogicalResult spirv::MergeOp::verify() {
3333 auto *parentOp = (*this)->getParentOp();
3334 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3335 return emitOpError(
3336 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3337
3338 // TODO: This check should be done in `verifyRegions` of parent op.
3339 Block &parentLastBlock = (*this)->getParentRegion()->back();
3340 if (getOperation() != parentLastBlock.getTerminator())
3341 return emitOpError("can only be used in the last block of "
3342 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3343 return success();
3344}
3345
3346//===----------------------------------------------------------------------===//
3347// spirv.module
3348//===----------------------------------------------------------------------===//
3349
3350void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3351 std::optional<StringRef> name) {
3352 OpBuilder::InsertionGuard guard(builder);
3353 builder.createBlock(state.addRegion());
3354 if (name) {
3355 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3356 builder.getStringAttr(*name));
3357 }
3358}
3359
3360void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3361 spirv::AddressingModel addressingModel,
3362 spirv::MemoryModel memoryModel,
3363 std::optional<VerCapExtAttr> vceTriple,
3364 std::optional<StringRef> name) {
3365 state.addAttribute(
3366 "addressing_model",
3367 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3368 state.addAttribute("memory_model",
3369 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3370 OpBuilder::InsertionGuard guard(builder);
3371 builder.createBlock(state.addRegion());
3372 if (vceTriple)
3373 state.addAttribute(getVCETripleAttrName(), *vceTriple);
3374 if (name)
3375 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3376 builder.getStringAttr(*name));
3377}
3378
3379ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3380 OperationState &result) {
3381 Region *body = result.addRegion();
3382
3383 // If the name is present, parse it.
3384 StringAttr nameAttr;
3385 (void)parser.parseOptionalSymbolName(
3386 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3387
3388 // Parse attributes
3389 spirv::AddressingModel addrModel;
3390 spirv::MemoryModel memoryModel;
3391 if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3392 result) ||
3393 ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3394 result))
3395 return failure();
3396
3397 if (succeeded(parser.parseOptionalKeyword("requires"))) {
3398 spirv::VerCapExtAttr vceTriple;
3399 if (parser.parseAttribute(vceTriple,
3400 spirv::ModuleOp::getVCETripleAttrName(),
3401 result.attributes))
3402 return failure();
3403 }
3404
3405 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3406 parser.parseRegion(*body, /*arguments=*/{}))
3407 return failure();
3408
3409 // Make sure we have at least one block.
3410 if (body->empty())
3411 body->push_back(new Block());
3412
3413 return success();
3414}
3415
3416void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3417 if (std::optional<StringRef> name = getName()) {
3418 printer << ' ';
3419 printer.printSymbolName(*name);
3420 }
3421
3422 SmallVector<StringRef, 2> elidedAttrs;
3423
3424 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
3425 << spirv::stringifyMemoryModel(getMemoryModel());
3426 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3427 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3428 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3429 mlir::SymbolTable::getSymbolAttrName()});
3430
3431 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3432 printer << " requires " << *triple;
3433 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3434 }
3435
3436 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3437 printer << ' ';
3438 printer.printRegion(getRegion());
3439}
3440
3441LogicalResult spirv::ModuleOp::verifyRegions() {
3442 Dialect *dialect = (*this)->getDialect();
3443 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3444 entryPoints;
3445 mlir::SymbolTable table(*this);
3446
3447 for (auto &op : *getBody()) {
3448 if (op.getDialect() != dialect)
3449 return op.emitError("'spirv.module' can only contain spirv.* ops");
3450
3451 // For EntryPoint op, check that the function and execution model is not
3452 // duplicated in EntryPointOps. Also verify that the interface specified
3453 // comes from globalVariables here to make this check cheaper.
3454 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3455 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3456 if (!funcOp) {
3457 return entryPointOp.emitError("function '")
3458 << entryPointOp.getFn() << "' not found in 'spirv.module'";
3459 }
3460 if (auto interface = entryPointOp.getInterface()) {
3461 for (Attribute varRef : interface) {
3462 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3463 if (!varSymRef) {
3464 return entryPointOp.emitError(
3465 "expected symbol reference for interface "
3466 "specification instead of '")
3467 << varRef;
3468 }
3469 auto variableOp =
3470 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3471 if (!variableOp) {
3472 return entryPointOp.emitError("expected spirv.GlobalVariable "
3473 "symbol reference instead of'")
3474 << varSymRef << "'";
3475 }
3476 }
3477 }
3478
3479 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3480 funcOp, entryPointOp.getExecutionModel());
3481 auto entryPtIt = entryPoints.find(key);
3482 if (entryPtIt != entryPoints.end()) {
3483 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3484 }
3485 entryPoints[key] = entryPointOp;
3486 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3487 if (funcOp.isExternal())
3488 return op.emitError("'spirv.module' cannot contain external functions");
3489
3490 // TODO: move this check to spirv.func.
3491 for (auto &block : funcOp)
3492 for (auto &op : block) {
3493 if (op.getDialect() != dialect)
3494 return op.emitError(
3495 "functions in 'spirv.module' can only contain spirv.* ops");
3496 }
3497 }
3498 }
3499
3500 return success();
3501}
3502
3503//===----------------------------------------------------------------------===//
3504// spirv.mlir.referenceof
3505//===----------------------------------------------------------------------===//
3506
3507LogicalResult spirv::ReferenceOfOp::verify() {
3508 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3509 (*this)->getParentOp(), getSpecConstAttr());
3510 Type constType;
3511
3512 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3513 if (specConstOp)
3514 constType = specConstOp.getDefaultValue().getType();
3515
3516 auto specConstCompositeOp =
3517 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3518 if (specConstCompositeOp)
3519 constType = specConstCompositeOp.getType();
3520
3521 if (!specConstOp && !specConstCompositeOp)
3522 return emitOpError(
3523 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3524
3525 if (getReference().getType() != constType)
3526 return emitOpError("result type mismatch with the referenced "
3527 "specialization constant's type");
3528
3529 return success();
3530}
3531
3532//===----------------------------------------------------------------------===//
3533// spirv.Return
3534//===----------------------------------------------------------------------===//
3535
3536LogicalResult spirv::ReturnOp::verify() {
3537 // Verification is performed in spirv.func op.
3538 return success();
3539}
3540
3541//===----------------------------------------------------------------------===//
3542// spirv.ReturnValue
3543//===----------------------------------------------------------------------===//
3544
3545LogicalResult spirv::ReturnValueOp::verify() {
3546 // Verification is performed in spirv.func op.
3547 return success();
3548}
3549
3550//===----------------------------------------------------------------------===//
3551// spirv.Select
3552//===----------------------------------------------------------------------===//
3553
3554LogicalResult spirv::SelectOp::verify() {
3555 if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3556 auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3557 if (!resultVectorTy) {
3558 return emitOpError("result expected to be of vector type when "
3559 "condition is of vector type");
3560 }
3561 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3562 return emitOpError("result should have the same number of elements as "
3563 "the condition when condition is of vector type");
3564 }
3565 }
3566 return success();
3567}
3568
3569//===----------------------------------------------------------------------===//
3570// spirv.mlir.selection
3571//===----------------------------------------------------------------------===//
3572
3573ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3574 OperationState &result) {
3575 if (parseControlAttribute<spirv::SelectionControlAttr,
3576 spirv::SelectionControl>(parser, result))
3577 return failure();
3578 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3579}
3580
3581void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3582 auto control = getSelectionControl();
3583 if (control != spirv::SelectionControl::None)
3584 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3585 printer << ' ';
3586 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3587 /*printBlockTerminators=*/true);
3588}
3589
3590LogicalResult spirv::SelectionOp::verifyRegions() {
3591 auto *op = getOperation();
3592
3593 // We need to verify that the blocks follow the following layout:
3594 //
3595 // +--------------+
3596 // | header block |
3597 // +--------------+
3598 // / | \
3599 // ...
3600 //
3601 //
3602 // +---------+ +---------+ +---------+
3603 // | case #0 | | case #1 | | case #2 | ...
3604 // +---------+ +---------+ +---------+
3605 //
3606 //
3607 // ...
3608 // \ | /
3609 // v
3610 // +-------------+
3611 // | merge block |
3612 // +-------------+
3613
3614 auto &region = op->getRegion(0);
3615 // Allow empty region as a degenerated case, which can come from
3616 // optimizations.
3617 if (region.empty())
3618 return success();
3619
3620 // The last block is the merge block.
3621 if (!isMergeBlock(region.back()))
3622 return emitOpError("last block must be the merge block with only one "
3623 "'spirv.mlir.merge' op");
3624
3625 if (std::next(region.begin()) == region.end())
3626 return emitOpError("must have a selection header block");
3627
3628 return success();
3629}
3630
3631Block *spirv::SelectionOp::getHeaderBlock() {
3632 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3632, __extension__
__PRETTY_FUNCTION__))
;
3633 // The first block is the loop header block.
3634 return &getBody().front();
3635}
3636
3637Block *spirv::SelectionOp::getMergeBlock() {
3638 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3638, __extension__
__PRETTY_FUNCTION__))
;
3639 // The last block is the loop merge block.
3640 return &getBody().back();
3641}
3642
3643void spirv::SelectionOp::addMergeBlock() {
3644 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3644, __extension__
__PRETTY_FUNCTION__))
;
3645 auto *mergeBlock = new Block();
3646 getBody().push_back(mergeBlock);
3647 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3648
3649 // Add a spirv.mlir.merge op into the merge block.
3650 builder.create<spirv::MergeOp>(getLoc());
3651}
3652
3653spirv::SelectionOp spirv::SelectionOp::createIfThen(
3654 Location loc, Value condition,
3655 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3656 auto selectionOp =
3657 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3658
3659 selectionOp.addMergeBlock();
3660 Block *mergeBlock = selectionOp.getMergeBlock();
3661 Block *thenBlock = nullptr;
3662
3663 // Build the "then" block.
3664 {
3665 OpBuilder::InsertionGuard guard(builder);
3666 thenBlock = builder.createBlock(mergeBlock);
3667 thenBody(builder);
3668 builder.create<spirv::BranchOp>(loc, mergeBlock);
3669 }
3670
3671 // Build the header block.
3672 {
3673 OpBuilder::InsertionGuard guard(builder);
3674 builder.createBlock(thenBlock);
3675 builder.create<spirv::BranchConditionalOp>(
3676 loc, condition, thenBlock,
3677 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3678 /*falseArguments=*/ArrayRef<Value>());
3679 }
3680
3681 return selectionOp;
3682}
3683
3684//===----------------------------------------------------------------------===//
3685// spirv.SpecConstant
3686//===----------------------------------------------------------------------===//
3687
3688ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3689 OperationState &result) {
3690 StringAttr nameAttr;
3691 Attribute valueAttr;
3692
3693 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3694 result.attributes))
3695 return failure();
3696
3697 // Parse optional spec_id.
3698 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3699 IntegerAttr specIdAttr;
3700 if (parser.parseLParen() ||
3701 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3702 parser.parseRParen())
3703 return failure();
3704 }
3705
3706 if (parser.parseEqual() ||
3707 parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3708 result.attributes))
3709 return failure();
3710
3711 return success();
3712}
3713
3714void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3715 printer << ' ';
3716 printer.printSymbolName(getSymName());
3717 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3718 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3719 printer << " = " << getDefaultValue();
3720}
3721
3722LogicalResult spirv::SpecConstantOp::verify() {
3723 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3724 if (specID.getValue().isNegative())
3725 return emitOpError("SpecId cannot be negative");
3726
3727 auto value = getDefaultValue();
3728 if (value.isa<IntegerAttr, FloatAttr>()) {
3729 // Make sure bitwidth is allowed.
3730 if (!value.getType().isa<spirv::SPIRVType>())
3731 return emitOpError("default value bitwidth disallowed");
3732 return success();
3733 }
3734 return emitOpError(
3735 "default value can only be a bool, integer, or float scalar");
3736}
3737
3738//===----------------------------------------------------------------------===//
3739// spirv.StoreOp
3740//===----------------------------------------------------------------------===//
3741
3742ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3743 // Parse the storage class specification
3744 spirv::StorageClass storageClass;
3745 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3746 auto loc = parser.getCurrentLocation();
3747 Type elementType;
3748 if (parseEnumStrAttr(storageClass, parser) ||
3749 parser.parseOperandList(operandInfo, 2) ||
3750 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3751 parser.parseType(elementType)) {
3752 return failure();
3753 }
3754
3755 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3756 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3757 result.operands)) {
3758 return failure();
3759 }
3760 return success();
3761}
3762
3763void spirv::StoreOp::print(OpAsmPrinter &printer) {
3764 SmallVector<StringRef, 4> elidedAttrs;
3765 StringRef sc = stringifyStorageClass(
3766 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3767 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
3768
3769 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3770
3771 printer << " : " << getValue().getType();
3772 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3773}
3774
3775LogicalResult spirv::StoreOp::verify() {
3776 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3777 // OpTypePointer whose Type operand is the same as the type of Object."
3778 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
3779 return failure();
3780 return verifyMemoryAccessAttribute(*this);
3781}
3782
3783//===----------------------------------------------------------------------===//
3784// spirv.Unreachable
3785//===----------------------------------------------------------------------===//
3786
3787LogicalResult spirv::UnreachableOp::verify() {
3788 auto *block = (*this)->getBlock();
3789 // Fast track: if this is in entry block, its invalid. Otherwise, if no
3790 // predecessors, it's valid.
3791 if (block->isEntryBlock())
3792 return emitOpError("cannot be used in reachable block");
3793 if (block->hasNoPredecessors())
3794 return success();
3795
3796 // TODO: further verification needs to analyze reachability from
3797 // the entry block.
3798
3799 return success();
3800}
3801
3802//===----------------------------------------------------------------------===//
3803// spirv.Variable
3804//===----------------------------------------------------------------------===//
3805
3806ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3807 OperationState &result) {
3808 // Parse optional initializer
3809 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
3810 if (succeeded(parser.parseOptionalKeyword("init"))) {
3811 initInfo = OpAsmParser::UnresolvedOperand();
3812 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3813 parser.parseRParen())
3814 return failure();
3815 }
3816
3817 if (parseVariableDecorations(parser, result)) {
3818 return failure();
3819 }
3820
3821 // Parse result pointer type
3822 Type type;
3823 if (parser.parseColon())
3824 return failure();
3825 auto loc = parser.getCurrentLocation();
3826 if (parser.parseType(type))
3827 return failure();
3828
3829 auto ptrType = type.dyn_cast<spirv::PointerType>();
3830 if (!ptrType)
3831 return parser.emitError(loc, "expected spirv.ptr type");
3832 result.addTypes(ptrType);
3833
3834 // Resolve the initializer operand
3835 if (initInfo) {
3836 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3837 result.operands))
3838 return failure();
3839 }
3840
3841 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3842 ptrType.getStorageClass());
3843 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3844
3845 return success();
3846}
3847
3848void spirv::VariableOp::print(OpAsmPrinter &printer) {
3849 SmallVector<StringRef, 4> elidedAttrs{
3850 spirv::attributeName<spirv::StorageClass>()};
3851 // Print optional initializer
3852 if (getNumOperands() != 0)
3853 printer << " init(" << getInitializer() << ")";
3854
3855 printVariableDecorations(*this, printer, elidedAttrs);
3856 printer << " : " << getType();
3857}
3858
3859LogicalResult spirv::VariableOp::verify() {
3860 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3861 // object. It cannot be Generic. It must be the same as the Storage Class
3862 // operand of the Result Type."
3863 if (getStorageClass() != spirv::StorageClass::Function) {
3864 return emitOpError(
3865 "can only be used to model function-level variables. Use "
3866 "spirv.GlobalVariable for module-level variables.");
3867 }
3868
3869 auto pointerType = getPointer().getType().cast<spirv::PointerType>();
3870 if (getStorageClass() != pointerType.getStorageClass())
3871 return emitOpError(
3872 "storage class must match result pointer's storage class");
3873
3874 if (getNumOperands() != 0) {
3875 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3876 // a global (module scope) OpVariable instruction".
3877 auto *initOp = getOperand(0).getDefiningOp();
3878 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3879 spirv::ReferenceOfOp, // for spec constant
3880 spirv::AddressOfOp>(initOp))
3881 return emitOpError("initializer must be the result of a "
3882 "constant or spirv.GlobalVariable op");
3883 }
3884
3885 // TODO: generate these strings using ODS.
3886 auto *op = getOperation();
3887 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3888 stringifyDecoration(spirv::Decoration::DescriptorSet));
3889 auto bindingName = llvm::convertToSnakeFromCamelCase(
3890 stringifyDecoration(spirv::Decoration::Binding));
3891 auto builtInName = llvm::convertToSnakeFromCamelCase(
3892 stringifyDecoration(spirv::Decoration::BuiltIn));
3893
3894 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3895 if (op->getAttr(attr))
3896 return emitOpError("cannot have '")
3897 << attr << "' attribute (only allowed in spirv.GlobalVariable)";
3898 }
3899
3900 return success();
3901}
3902
3903//===----------------------------------------------------------------------===//
3904// spirv.VectorShuffle
3905//===----------------------------------------------------------------------===//
3906
3907LogicalResult spirv::VectorShuffleOp::verify() {
3908 VectorType resultType = getType().cast<VectorType>();
3909
3910 size_t numResultElements = resultType.getNumElements();
3911 if (numResultElements != getComponents().size())
3912 return emitOpError("result type element count (")
3913 << numResultElements
3914 << ") mismatch with the number of component selectors ("
3915 << getComponents().size() << ")";
3916
3917 size_t totalSrcElements =
3918 getVector1().getType().cast<VectorType>().getNumElements() +
3919 getVector2().getType().cast<VectorType>().getNumElements();
3920
3921 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3922 uint32_t index = selector.getZExtValue();
3923 if (index >= totalSrcElements &&
3924 index != std::numeric_limits<uint32_t>().max())
3925 return emitOpError("component selector ")
3926 << index << " out of range: expected to be in [0, "
3927 << totalSrcElements << ") or 0xffffffff";
3928 }
3929 return success();
3930}
3931
3932//===----------------------------------------------------------------------===//
3933// spirv.NV.CooperativeMatrixLoad
3934//===----------------------------------------------------------------------===//
3935
3936ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
3937 OperationState &result) {
3938 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3939 Type strideType = parser.getBuilder().getIntegerType(32);
3940 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3941 Type ptrType;
3942 Type elementType;
3943 if (parser.parseOperandList(operandInfo, 3) ||
3944 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3945 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3946 return failure();
3947 }
3948 if (parser.resolveOperands(operandInfo,
3949 {ptrType, strideType, columnMajorType},
3950 parser.getNameLoc(), result.operands)) {
3951 return failure();
3952 }
3953
3954 result.addTypes(elementType);
3955 return success();
3956}
3957
3958void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
3959 printer << " " << getPointer() << ", " << getStride() << ", "
3960 << getColumnmajor();
3961 // Print optional memory access attribute.
3962 if (auto memAccess = getMemoryAccess())
3963 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3964 printer << " : " << getPointer().getType() << " as " << getType();
3965}
3966
3967static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3968 Type coopMatrix) {
3969 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3970 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3971 return op->emitError(
3972 "Pointer must point to a scalar or vector type but provided ")
3973 << pointeeType;
3974 spirv::StorageClass storage =
3975 pointer.cast<spirv::PointerType>().getStorageClass();
3976 if (storage != spirv::StorageClass::Workgroup &&
3977 storage != spirv::StorageClass::StorageBuffer &&
3978 storage != spirv::StorageClass::PhysicalStorageBuffer)
3979 return op->emitError(
3980 "Pointer storage class must be Workgroup, StorageBuffer or "
3981 "PhysicalStorageBufferEXT but provided ")
3982 << stringifyStorageClass(storage);
3983 return success();
3984}
3985
3986LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
3987 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
3988 getResult().getType());
3989}
3990
3991//===----------------------------------------------------------------------===//
3992// spirv.NV.CooperativeMatrixStore
3993//===----------------------------------------------------------------------===//
3994
3995ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
3996 OperationState &result) {
3997 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
3998 Type strideType = parser.getBuilder().getIntegerType(32);
3999 Type columnMajorType = parser.getBuilder().getIntegerType(1);
4000 Type ptrType;
4001 Type elementType;
4002 if (parser.parseOperandList(operandInfo, 4) ||
4003 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
4004 parser.parseType(ptrType) || parser.parseComma() ||
4005 parser.parseType(elementType)) {
4006 return failure();
4007 }
4008 if (parser.resolveOperands(
4009 operandInfo, {ptrType, elementType, strideType, columnMajorType},
4010 parser.getNameLoc(), result.operands)) {
4011 return failure();
4012 }
4013
4014 return success();
4015}
4016
4017void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
4018 printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
4019 << ", " << getColumnmajor();
4020 // Print optional memory access attribute.
4021 if (auto memAccess = getMemoryAccess())
4022 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
4023 printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
4024}
4025
4026LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
4027 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4028 getObject().getType());
4029}
4030
4031//===----------------------------------------------------------------------===//
4032// spirv.NV.CooperativeMatrixMulAdd
4033//===----------------------------------------------------------------------===//
4034
4035static LogicalResult
4036verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
4037 if (op.getC().getType() != op.getResult().getType())
4038 return op.emitOpError("result and third operand must have the same type");
4039 auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
4040 auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
4041 auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
4042 auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
4043 if (typeA.getRows() != typeR.getRows() ||
4044 typeA.getColumns() != typeB.getRows() ||
4045 typeB.getColumns() != typeR.getColumns())
4046 return op.emitOpError("matrix size must match");
4047 if (typeR.getScope() != typeA.getScope() ||
4048 typeR.getScope() != typeB.getScope() ||
4049 typeR.getScope() != typeC.getScope())
4050 return op.emitOpError("matrix scope must match");
4051 if (typeA.getElementType() != typeB.getElementType() ||
4052 typeR.getElementType() != typeC.getElementType())
4053 return op.emitOpError("matrix element type must match");
4054 return success();
4055}
4056
4057LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
4058 return verifyCoopMatrixMulAdd(*this);
4059}
4060
4061static LogicalResult
4062verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
4063 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4064 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4065 return op->emitError(
4066 "Pointer must point to a scalar or vector type but provided ")
4067 << pointeeType;
4068 spirv::StorageClass storage =
4069 pointer.cast<spirv::PointerType>().getStorageClass();
4070 if (storage != spirv::StorageClass::Workgroup &&
4071 storage != spirv::StorageClass::CrossWorkgroup &&
4072 storage != spirv::StorageClass::UniformConstant &&
4073 storage != spirv::StorageClass::Generic)
4074 return op->emitError("Pointer storage class must be Workgroup or "
4075 "CrossWorkgroup but provided ")
4076 << stringifyStorageClass(storage);
4077 return success();
4078}
4079
4080//===----------------------------------------------------------------------===//
4081// spirv.INTEL.JointMatrixLoad
4082//===----------------------------------------------------------------------===//
4083
4084LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
4085 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4086 getResult().getType());
4087}
4088
4089//===----------------------------------------------------------------------===//
4090// spirv.INTEL.JointMatrixStore
4091//===----------------------------------------------------------------------===//
4092
4093LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
4094 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4095 getObject().getType());
4096}
4097
4098//===----------------------------------------------------------------------===//
4099// spirv.INTEL.JointMatrixMad
4100//===----------------------------------------------------------------------===//
4101
4102static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
4103 if (op.getC().getType() != op.getResult().getType())
4104 return op.emitOpError("result and third operand must have the same type");
4105 auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
4106 auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
4107 auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
4108 auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
4109 if (typeA.getRows() != typeR.getRows() ||
4110 typeA.getColumns() != typeB.getRows() ||
4111 typeB.getColumns() != typeR.getColumns())
4112 return op.emitOpError("matrix size must match");
4113 if (typeR.getScope() != typeA.getScope() ||
4114 typeR.getScope() != typeB.getScope() ||
4115 typeR.getScope() != typeC.getScope())
4116 return op.emitOpError("matrix scope must match");
4117 if (typeA.getElementType() != typeB.getElementType() ||
4118 typeR.getElementType() != typeC.getElementType())
4119 return op.emitOpError("matrix element type must match");
4120 return success();
4121}
4122
4123LogicalResult spirv::INTELJointMatrixMadOp::verify() {
4124 return verifyJointMatrixMad(*this);
4125}
4126
4127//===----------------------------------------------------------------------===//
4128// spirv.MatrixTimesScalar
4129//===----------------------------------------------------------------------===//
4130
4131LogicalResult spirv::MatrixTimesScalarOp::verify() {
4132 if (auto inputCoopmat =
4133 getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
4134 if (inputCoopmat.getElementType() != getScalar().getType())
4135 return emitError("input matrix components' type and scaling value must "
4136 "have the same type");
4137 return success();
4138 }
4139
4140 // Check that the scalar type is the same as the matrix element type.
4141 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4142 if (getScalar().getType() != inputMatrix.getElementType())
4143 return emitError("input matrix components' type and scaling value must "
4144 "have the same type");
4145
4146 return success();
4147}
4148
4149//===----------------------------------------------------------------------===//
4150// spirv.CopyMemory
4151//===----------------------------------------------------------------------===//
4152
4153void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
4154 printer << ' ';
4155
4156 StringRef targetStorageClass = stringifyStorageClass(
4157 getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4158 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
4159
4160 StringRef sourceStorageClass = stringifyStorageClass(
4161 getSource().getType().cast<spirv::PointerType>().getStorageClass());
4162 printer << " \"" << sourceStorageClass << "\" " << getSource();
4163
4164 SmallVector<StringRef, 4> elidedAttrs;
4165 printMemoryAccessAttribute(*this, printer, elidedAttrs);
4166 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4167 getSourceMemoryAccess(),
4168 getSourceAlignment());
4169
4170 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4171
4172 Type pointeeType =
4173 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4174 printer << " : " << pointeeType;
4175}
4176
4177ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4178 OperationState &result) {
4179 spirv::StorageClass targetStorageClass;
4180 OpAsmParser::UnresolvedOperand targetPtrInfo;
4181
4182 spirv::StorageClass sourceStorageClass;
1
'sourceStorageClass' declared without an initial value
4183 OpAsmParser::UnresolvedOperand sourcePtrInfo;
4184
4185 Type elementType;
4186
4187 if (parseEnumStrAttr(targetStorageClass, parser) ||
7
Taking false branch
4188 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4189 parseEnumStrAttr(sourceStorageClass, parser) ||
2
Calling 'parseEnumStrAttr<mlir::spirv::StorageClass>'
6
Returning from 'parseEnumStrAttr<mlir::spirv::StorageClass>'
4190 parser.parseOperand(sourcePtrInfo) ||
4191 parseMemoryAccessAttributes(parser, result)) {
4192 return failure();
4193 }
4194
4195 if (!parser.parseOptionalComma()) {
4196 // Parse 2nd memory access attributes.
4197 if (parseSourceMemoryAccessAttributes(parser, result)) {
4198 return failure();
4199 }
4200 }
4201
4202 if (parser.parseColon() || parser.parseType(elementType))
8
Taking false branch
4203 return failure();
4204
4205 if (parser.parseOptionalAttrDict(result.attributes))
9
Taking false branch
4206 return failure();
4207
4208 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4209 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
10
2nd function call argument is an uninitialized value
4210
4211 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4212 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4213 return failure();
4214 }
4215
4216 return success();
4217}
4218
4219LogicalResult spirv::CopyMemoryOp::verify() {
4220 Type targetType =
4221 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4222
4223 Type sourceType =
4224 getSource().getType().cast<spirv::PointerType>().getPointeeType();
4225
4226 if (targetType != sourceType)
4227 return emitOpError("both operands must be pointers to the same type");
4228
4229 if (failed(verifyMemoryAccessAttribute(*this)))
4230 return failure();
4231
4232 // TODO - According to the spec:
4233 //
4234 // If two masks are present, the first applies to Target and cannot include
4235 // MakePointerVisible, and the second applies to Source and cannot include
4236 // MakePointerAvailable.
4237 //
4238 // Add such verification here.
4239
4240 return verifySourceMemoryAccessAttribute(*this);
4241}
4242
4243//===----------------------------------------------------------------------===//
4244// spirv.Transpose
4245//===----------------------------------------------------------------------===//
4246
4247LogicalResult spirv::TransposeOp::verify() {
4248 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4249 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4250
4251 // Verify that the input and output matrices have correct shapes.
4252 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4253 return emitError("input matrix rows count must be equal to "
4254 "output matrix columns count");
4255
4256 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4257 return emitError("input matrix columns count must be equal to "
4258 "output matrix rows count");
4259
4260 // Verify that the input and output matrices have the same component type
4261 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4262 return emitError("input and output matrices must have the same "
4263 "component type");
4264
4265 return success();
4266}
4267
4268//===----------------------------------------------------------------------===//
4269// spirv.MatrixTimesMatrix
4270//===----------------------------------------------------------------------===//
4271
4272LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4273 auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
4274 auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
4275 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4276
4277 // left matrix columns' count and right matrix rows' count must be equal
4278 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4279 return emitError("left matrix columns' count must be equal to "
4280 "the right matrix rows' count");
4281
4282 // right and result matrices columns' count must be the same
4283 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4284 return emitError(
4285 "right and result matrices must have equal columns' count");
4286
4287 // right and result matrices component type must be the same
4288 if (rightMatrix.getElementType() != resultMatrix.getElementType())
4289 return emitError("right and result matrices' component type must"
4290 " be the same");
4291
4292 // left and result matrices component type must be the same
4293 if (leftMatrix.getElementType() != resultMatrix.getElementType())
4294 return emitError("left and result matrices' component type"
4295 " must be the same");
4296
4297 // left and result matrices rows count must be the same
4298 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4299 return emitError("left and result matrices must have equal rows' count");
4300
4301 return success();
4302}
4303
4304//===----------------------------------------------------------------------===//
4305// spirv.SpecConstantComposite
4306//===----------------------------------------------------------------------===//
4307
4308ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4309 OperationState &result) {
4310
4311 StringAttr compositeName;
4312 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4313 result.attributes))
4314 return failure();
4315
4316 if (parser.parseLParen())
4317 return failure();
4318
4319 SmallVector<Attribute, 4> constituents;
4320
4321 do {
4322 // The name of the constituent attribute isn't important
4323 const char *attrName = "spec_const";
4324 FlatSymbolRefAttr specConstRef;
4325 NamedAttrList attrs;
4326
4327 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4328 return failure();
4329
4330 constituents.push_back(specConstRef);
4331 } while (!parser.parseOptionalComma());
4332
4333 if (parser.parseRParen())
4334 return failure();
4335
4336 result.addAttribute(kCompositeSpecConstituentsName,
4337 parser.getBuilder().getArrayAttr(constituents));
4338
4339 Type type;
4340 if (parser.parseColonType(type))
4341 return failure();
4342
4343 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4344
4345 return success();
4346}
4347
4348void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4349 printer << " ";
4350 printer.printSymbolName(getSymName());
4351 printer << " (";
4352 auto constituents = this->getConstituents().getValue();
4353
4354 if (!constituents.empty())
4355 llvm::interleaveComma(constituents, printer);
4356
4357 printer << ") : " << getType();
4358}
4359
4360LogicalResult spirv::SpecConstantCompositeOp::verify() {
4361 auto cType = getType().dyn_cast<spirv::CompositeType>();
4362 auto constituents = this->getConstituents().getValue();
4363
4364 if (!cType)
4365 return emitError("result type must be a composite type, but provided ")
4366 << getType();
4367
4368 if (cType.isa<spirv::CooperativeMatrixNVType>())
4369 return emitError("unsupported composite type ") << cType;
4370 if (cType.isa<spirv::JointMatrixINTELType>())
4371 return emitError("unsupported composite type ") << cType;
4372 if (constituents.size() != cType.getNumElements())
4373 return emitError("has incorrect number of operands: expected ")
4374 << cType.getNumElements() << ", but provided "
4375 << constituents.size();
4376
4377 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4378 auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4379
4380 auto constituentSpecConstOp =
4381 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4382 (*this)->getParentOp(), constituent.getAttr()));
4383
4384 if (constituentSpecConstOp.getDefaultValue().getType() !=
4385 cType.getElementType(index))
4386 return emitError("has incorrect types of operands: expected ")
4387 << cType.getElementType(index) << ", but provided "
4388 << constituentSpecConstOp.getDefaultValue().getType();
4389 }
4390
4391 return success();
4392}
4393
4394//===----------------------------------------------------------------------===//
4395// spirv.SpecConstantOperation
4396//===----------------------------------------------------------------------===//
4397
4398ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4399 OperationState &result) {
4400 Region *body = result.addRegion();
4401
4402 if (parser.parseKeyword("wraps"))
4403 return failure();
4404
4405 body->push_back(new Block);
4406 Block &block = body->back();
4407 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4408
4409 if (!wrappedOp)
4410 return failure();
4411
4412 OpBuilder builder(parser.getContext());
4413 builder.setInsertionPointToEnd(&block);
4414 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4415 result.location = wrappedOp->getLoc();
4416
4417 result.addTypes(wrappedOp->getResult(0).getType());
4418
4419 if (parser.parseOptionalAttrDict(result.attributes))
4420 return failure();
4421
4422 return success();
4423}
4424
4425void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4426 printer << " wraps ";
4427 printer.printGenericOp(&getBody().front().front());
4428}
4429
4430LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4431 Block &block = getRegion().getBlocks().front();
4432
4433 if (block.getOperations().size() != 2)
4434 return emitOpError("expected exactly 2 nested ops");
4435
4436 Operation &enclosedOp = block.getOperations().front();
4437
4438 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4439 return emitOpError("invalid enclosed op");
4440
4441 for (auto operand : enclosedOp.getOperands())
4442 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4443 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4444 return emitOpError(
4445 "invalid operand, must be defined by a constant operation");
4446
4447 return success();
4448}
4449
4450//===----------------------------------------------------------------------===//
4451// spirv.GL.FrexpStruct
4452//===----------------------------------------------------------------------===//
4453
4454LogicalResult spirv::GLFrexpStructOp::verify() {
4455 spirv::StructType structTy =
4456 getResult().getType().dyn_cast<spirv::StructType>();
4457
4458 if (structTy.getNumElements() != 2)
4459 return emitError("result type must be a struct type with two memebers");
4460
4461 Type significandTy = structTy.getElementType(0);
4462 Type exponentTy = structTy.getElementType(1);
4463 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4464 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4465
4466 Type operandTy = getOperand().getType();
4467 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4468 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4469
4470 if (significandTy != operandTy)
4471 return emitError("member zero of the resulting struct type must be the "
4472 "same type as the operand");
4473
4474 if (exponentVecTy) {
4475 IntegerType componentIntTy =
4476 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4477 if (!componentIntTy || componentIntTy.getWidth() != 32)
4478 return emitError("member one of the resulting struct type must"
4479 "be a scalar or vector of 32 bit integer type");
4480 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4481 return emitError("member one of the resulting struct type "
4482 "must be a scalar or vector of 32 bit integer type");
4483 }
4484
4485 // Check that the two member types have the same number of components
4486 if (operandVecTy && exponentVecTy &&
4487 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4488 return success();
4489
4490 if (operandFTy && exponentIntTy)
4491 return success();
4492
4493 return emitError("member one of the resulting struct type must have the same "
4494 "number of components as the operand type");
4495}
4496
4497//===----------------------------------------------------------------------===//
4498// spirv.GL.Ldexp
4499//===----------------------------------------------------------------------===//
4500
4501LogicalResult spirv::GLLdexpOp::verify() {
4502 Type significandType = getX().getType();
4503 Type exponentType = getExp().getType();
4504
4505 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4506 return emitOpError("operands must both be scalars or vectors");
4507
4508 auto getNumElements = [](Type type) -> unsigned {
4509 if (auto vectorType = type.dyn_cast<VectorType>())
4510 return vectorType.getNumElements();
4511 return 1;
4512 };
4513
4514 if (getNumElements(significandType) != getNumElements(exponentType))
4515 return emitOpError("operands must have the same number of elements");
4516
4517 return success();
4518}
4519
4520//===----------------------------------------------------------------------===//
4521// spirv.ImageDrefGather
4522//===----------------------------------------------------------------------===//
4523
4524LogicalResult spirv::ImageDrefGatherOp::verify() {
4525 VectorType resultType = getResult().getType().cast<VectorType>();
4526 auto sampledImageType =
4527 getSampledimage().getType().cast<spirv::SampledImageType>();
4528 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4529
4530 if (resultType.getNumElements() != 4)
4531 return emitOpError("result type must be a vector of four components");
4532
4533 Type elementType = resultType.getElementType();
4534 Type sampledElementType = imageType.getElementType();
4535 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4536 return emitOpError(
4537 "the component type of result must be the same as sampled type of the "
4538 "underlying image type");
4539
4540 spirv::Dim imageDim = imageType.getDim();
4541 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4542
4543 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4544 imageDim != spirv::Dim::Rect)
4545 return emitOpError(
4546 "the Dim operand of the underlying image type must be 2D, Cube, or "
4547 "Rect");
4548
4549 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4550 return emitOpError("the MS operand of the underlying image type must be 0");
4551
4552 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4553 auto operandArguments = getOperandArguments();
4554
4555 return verifyImageOperands(*this, attr, operandArguments);
4556}
4557
4558//===----------------------------------------------------------------------===//
4559// spirv.ShiftLeftLogicalOp
4560//===----------------------------------------------------------------------===//
4561
4562LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4563 return verifyShiftOp(*this);
4564}
4565
4566//===----------------------------------------------------------------------===//
4567// spirv.ShiftRightArithmeticOp
4568//===----------------------------------------------------------------------===//
4569
4570LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4571 return verifyShiftOp(*this);
4572}
4573
4574//===----------------------------------------------------------------------===//
4575// spirv.ShiftRightLogicalOp
4576//===----------------------------------------------------------------------===//
4577
4578LogicalResult spirv::ShiftRightLogicalOp::verify() {
4579 return verifyShiftOp(*this);
4580}
4581
4582//===----------------------------------------------------------------------===//
4583// spirv.ImageQuerySize
4584//===----------------------------------------------------------------------===//
4585
4586LogicalResult spirv::ImageQuerySizeOp::verify() {
4587 spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
4588 Type resultType = getResult().getType();
4589
4590 spirv::Dim dim = imageType.getDim();
4591 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4592 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4593 switch (dim) {
4594 case spirv::Dim::Dim1D:
4595 case spirv::Dim::Dim2D:
4596 case spirv::Dim::Dim3D:
4597 case spirv::Dim::Cube:
4598 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4599 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4600 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4601 return emitError(
4602 "if Dim is 1D, 2D, 3D, or Cube, "
4603 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4604 break;
4605 case spirv::Dim::Buffer:
4606 case spirv::Dim::Rect:
4607 break;
4608 default:
4609 return emitError("the Dim operand of the image type must "
4610 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4611 }
4612
4613 unsigned componentNumber = 0;
4614 switch (dim) {
4615 case spirv::Dim::Dim1D:
4616 case spirv::Dim::Buffer:
4617 componentNumber = 1;
4618 break;
4619 case spirv::Dim::Dim2D:
4620 case spirv::Dim::Cube:
4621 case spirv::Dim::Rect:
4622 componentNumber = 2;
4623 break;
4624 case spirv::Dim::Dim3D:
4625 componentNumber = 3;
4626 break;
4627 default:
4628 break;
4629 }
4630
4631 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4632 componentNumber += 1;
4633
4634 unsigned resultComponentNumber = 1;
4635 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4636 resultComponentNumber = resultVectorType.getNumElements();
4637
4638 if (componentNumber != resultComponentNumber)
4639 return emitError("expected the result to have ")
4640 << componentNumber << " component(s), but found "
4641 << resultComponentNumber << " component(s)";
4642
4643 return success();
4644}
4645
4646static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4647 OpAsmParser &parser,
4648 OperationState &state) {
4649 OpAsmParser::UnresolvedOperand ptrInfo;
4650 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4651 Type type;
4652 auto loc = parser.getCurrentLocation();
4653 SmallVector<Type, 4> indicesTypes;
4654
4655 if (parser.parseOperand(ptrInfo) ||
4656 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4657 parser.parseColonType(type) ||
4658 parser.resolveOperand(ptrInfo, type, state.operands))
4659 return failure();
4660
4661 // Check that the provided indices list is not empty before parsing their
4662 // type list.
4663 if (indicesInfo.empty())
4664 return emitError(state.location) << opName << " expected element";
4665
4666 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4667 return failure();
4668
4669 // Check that the indices types list is not empty and that it has a one-to-one
4670 // mapping to the provided indices.
4671 if (indicesTypes.size() != indicesInfo.size())
4672 return emitError(state.location)
4673 << opName
4674 << " indices types' count must be equal to indices info count";
4675
4676 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4677 return failure();
4678
4679 auto resultType = getElementPtrType(
4680 type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4681 if (!resultType)
4682 return failure();
4683
4684 state.addTypes(resultType);
4685 return success();
4686}
4687
4688template <typename Op>
4689static auto concatElemAndIndices(Op op) {
4690 SmallVector<Value> ret(op.getIndices().size() + 1);
4691 ret[0] = op.getElement();
4692 llvm::copy(op.getIndices(), ret.begin() + 1);
4693 return ret;
4694}
4695
4696//===----------------------------------------------------------------------===//
4697// spirv.InBoundsPtrAccessChainOp
4698//===----------------------------------------------------------------------===//
4699
4700void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4701 OperationState &state,
4702 Value basePtr, Value element,
4703 ValueRange indices) {
4704 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4705 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4705, __extension__
__PRETTY_FUNCTION__))
;
4706 build(builder, state, type, basePtr, element, indices);
4707}
4708
4709ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4710 OperationState &result) {
4711 return parsePtrAccessChainOpImpl(
4712 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4713}
4714
4715void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4716 printAccessChain(*this, concatElemAndIndices(*this), printer);
4717}
4718
4719LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4720 return verifyAccessChain(*this, getIndices());
4721}
4722
4723//===----------------------------------------------------------------------===//
4724// spirv.PtrAccessChainOp
4725//===----------------------------------------------------------------------===//
4726
4727void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4728 Value basePtr, Value element,
4729 ValueRange indices) {
4730 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4731 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4731, __extension__
__PRETTY_FUNCTION__))
;
4732 build(builder, state, type, basePtr, element, indices);
4733}
4734
4735ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4736 OperationState &result) {
4737 return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4738 parser, result);
4739}
4740
4741void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4742 printAccessChain(*this, concatElemAndIndices(*this), printer);
4743}
4744
4745LogicalResult spirv::PtrAccessChainOp::verify() {
4746 return verifyAccessChain(*this, getIndices());
4747}
4748
4749//===----------------------------------------------------------------------===//
4750// spirv.VectorTimesScalarOp
4751//===----------------------------------------------------------------------===//
4752
4753LogicalResult spirv::VectorTimesScalarOp::verify() {
4754 if (getVector().getType() != getType())
4755 return emitOpError("vector operand and result type mismatch");
4756 auto scalarType = getType().cast<VectorType>().getElementType();
4757 if (getScalar().getType() != scalarType)
4758 return emitOpError("scalar operand and result element type match");
4759 return success();
4760}
4761
4762//===----------------------------------------------------------------------===//
4763// Group ops
4764//===----------------------------------------------------------------------===//
4765
4766template <typename Op>
4767static LogicalResult verifyGroupOp(Op op) {
4768 spirv::Scope scope = op.getExecutionScope();
4769 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
4770 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
4771
4772 return success();
4773}
4774
4775LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); }
4776
4777LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); }
4778
4779LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); }
4780
4781LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); }
4782
4783LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); }
4784
4785LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); }
4786
4787LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); }
4788
4789LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); }
4790
4791LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
4792
4793LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
4794
4795//===----------------------------------------------------------------------===//
4796// Integer Dot Product ops
4797//===----------------------------------------------------------------------===//
4798
4799static LogicalResult verifyIntegerDotProduct(Operation *op) {
4800 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&(static_cast <bool> (llvm::is_contained({2u, 3u}, op->
getNumOperands()) && "Not an integer dot product op?"
) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4801, __extension__
__PRETTY_FUNCTION__))
4801 "Not an integer dot product op?")(static_cast <bool> (llvm::is_contained({2u, 3u}, op->
getNumOperands()) && "Not an integer dot product op?"
) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4801, __extension__
__PRETTY_FUNCTION__))
;
4802 assert(op->getNumResults() == 1 && "Expected a single result")(static_cast <bool> (op->getNumResults() == 1 &&
"Expected a single result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"Expected a single result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4802, __extension__
__PRETTY_FUNCTION__))
;
4803
4804 Type factorTy = op->getOperand(0).getType();
4805 if (op->getOperand(1).getType() != factorTy)
4806 return op->emitOpError("requires the same type for both vector operands");
4807
4808 unsigned expectedNumAttrs = 0;
4809 if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4810 ++expectedNumAttrs;
4811 auto packedVectorFormat =
4812 op->getAttr(kPackedVectorFormatAttrName)
4813 .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
4814 if (!packedVectorFormat)
4815 return op->emitOpError("requires Packed Vector Format attribute for "
4816 "integer vector operands");
4817
4818 assert(packedVectorFormat.getValue() ==(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__
__PRETTY_FUNCTION__))
4819 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__
__PRETTY_FUNCTION__))
4820 "Unknown Packed Vector Format")(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__
__PRETTY_FUNCTION__))
;
4821 if (intTy.getWidth() != 32)
4822 return op->emitOpError(
4823 llvm::formatv("with specified Packed Vector Format ({0}) requires "
4824 "integer vector operands to be 32-bits wide",
4825 packedVectorFormat.getValue()));
4826 } else {
4827 if (op->hasAttr(kPackedVectorFormatAttrName))
4828 return op->emitOpError(llvm::formatv(
4829 "with invalid format attribute for vector operands of type '{0}'",
4830 factorTy));
4831 }
4832
4833 if (op->getAttrs().size() > expectedNumAttrs)
4834 return op->emitError(
4835 "op only supports the 'format' #spirv.packed_vector_format attribute");
4836
4837 Type resultTy = op->getResultTypes().front();
4838 bool hasAccumulator = op->getNumOperands() == 3;
4839 if (hasAccumulator && op->getOperand(2).getType() != resultTy)
4840 return op->emitOpError(
4841 "requires the same accumulator operand and result types");
4842
4843 unsigned factorBitWidth = getBitWidth(factorTy);
4844 unsigned resultBitWidth = getBitWidth(resultTy);
4845 if (factorBitWidth > resultBitWidth)
4846 return op->emitOpError(
4847 llvm::formatv("result type has insufficient bit-width ({0} bits) "
4848 "for the specified vector operand type ({1} bits)",
4849 resultBitWidth, factorBitWidth));
4850
4851 return success();
4852}
4853
4854static Optional<spirv::Version> getIntegerDotProductMinVersion() {
4855 return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
4856}
4857
4858static Optional<spirv::Version> getIntegerDotProductMaxVersion() {
4859 return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
4860}
4861
4862static SmallVector<ArrayRef<spirv::Extension>, 1>
4863getIntegerDotProductExtensions() {
4864 // Requires the SPV_KHR_integer_dot_product extension, specified either
4865 // explicitly or implied by target env's SPIR-V version >= 1.6.
4866 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
4867 return {extension};
4868}
4869
4870static SmallVector<ArrayRef<spirv::Capability>, 1>
4871getIntegerDotProductCapabilities(Operation *op) {
4872 // Requires the the DotProduct capability and capabilities that depend on
4873 // exact op types.
4874 static const auto dotProductCap = spirv::Capability::DotProduct;
4875 static const auto dotProductInput4x8BitPackedCap =
4876 spirv::Capability::DotProductInput4x8BitPacked;
4877 static const auto dotProductInput4x8BitCap =
4878 spirv::Capability::DotProductInput4x8Bit;
4879 static const auto dotProductInputAllCap =
4880 spirv::Capability::DotProductInputAll;
4881
4882 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
4883
4884 Type factorTy = op->getOperand(0).getType();
4885 if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4886 auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
4887 .cast<spirv::PackedVectorFormatAttr>();
4888 if (formatAttr.getValue() ==
4889 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
4890 capabilities.push_back(dotProductInput4x8BitPackedCap);
4891
4892 return capabilities;
4893 }
4894
4895 auto vecTy = factorTy.cast<VectorType>();
4896 if (vecTy.getElementTypeBitWidth() == 8) {
4897 capabilities.push_back(dotProductInput4x8BitCap);
4898 return capabilities;
4899 }
4900
4901 capabilities.push_back(dotProductInputAllCap);
4902 return capabilities;
4903}
4904
4905#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
4906 LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
4907 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
4908 return getIntegerDotProductExtensions(); \
4909 } \
4910 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
4911 return getIntegerDotProductCapabilities(*this); \
4912 } \
4913 Optional<spirv::Version> OpName::getMinVersion() { \
4914 return getIntegerDotProductMinVersion(); \
4915 } \
4916 Optional<spirv::Version> OpName::getMaxVersion() { \
4917 return getIntegerDotProductMaxVersion(); \
4918 }
4919
4920SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
4921SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
4922SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
4923SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
4924SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
4925SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
4926
4927#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
4928
4929// TableGen'erated operation interfaces for querying versions, extensions, and
4930// capabilities.
4931#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4932
// TableGen'erated operation definitions.
4934#define GET_OP_CLASSES
4935#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4936
4937namespace mlir {
4938namespace spirv {
4939// TableGen'erated operation availability interface implementations.
4940#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4941} // namespace spirv
4942} // namespace mlir