Bug Summary

File: build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Warning: line 3732, column 18
2nd function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SPIRVOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/SPIRV/IR -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/SPIRV/IR -I include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/llvm/include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem 
/usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-10-03-140002-15933-1 -x c++ /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/FunctionImplementation.h"
25#include "mlir/IR/OpDefinition.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Interfaces/CallInterfaces.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/APInt.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/bit.h"
34#include <numeric>
35
36using namespace mlir;
37
38// TODO: generate these strings using ODS.
39constexpr char kMemoryAccessAttrName[] = "memory_access";
40constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
41constexpr char kAlignmentAttrName[] = "alignment";
42constexpr char kSourceAlignmentAttrName[] = "source_alignment";
43constexpr char kBranchWeightAttrName[] = "branch_weights";
44constexpr char kCallee[] = "callee";
45constexpr char kClusterSize[] = "cluster_size";
46constexpr char kControl[] = "control";
47constexpr char kDefaultValueAttrName[] = "default_value";
48constexpr char kExecutionScopeAttrName[] = "execution_scope";
49constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
50constexpr char kFnNameAttrName[] = "fn";
51constexpr char kGroupOperationAttrName[] = "group_operation";
52constexpr char kIndicesAttrName[] = "indices";
53constexpr char kInitializerAttrName[] = "initializer";
54constexpr char kInterfaceAttrName[] = "interface";
55constexpr char kMemoryScopeAttrName[] = "memory_scope";
56constexpr char kSemanticsAttrName[] = "semantics";
57constexpr char kSpecIdAttrName[] = "spec_id";
58constexpr char kTypeAttrName[] = "type";
59constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
60constexpr char kValueAttrName[] = "value";
61constexpr char kValuesAttrName[] = "values";
62constexpr char kCompositeSpecConstituentsName[] = "constituents";
63
64//===----------------------------------------------------------------------===//
65// Common utility functions
66//===----------------------------------------------------------------------===//
67
68static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
69 OperationState &result) {
70 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
71 Type type;
72 // If the operand list is in-between parentheses, then we have a generic form.
73 // (see the fallback in `printOneResultOp`).
74 SMLoc loc = parser.getCurrentLocation();
75 if (!parser.parseOptionalLParen()) {
76 if (parser.parseOperandList(ops) || parser.parseRParen() ||
77 parser.parseOptionalAttrDict(result.attributes) ||
78 parser.parseColon() || parser.parseType(type))
79 return failure();
80 auto fnType = type.dyn_cast<FunctionType>();
81 if (!fnType) {
82 parser.emitError(loc, "expected function type");
83 return failure();
84 }
85 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
86 return failure();
87 result.addTypes(fnType.getResults());
88 return success();
89 }
90 return failure(parser.parseOperandList(ops) ||
91 parser.parseOptionalAttrDict(result.attributes) ||
92 parser.parseColonType(type) ||
93 parser.resolveOperands(ops, type, result.operands) ||
94 parser.addTypeToList(type, result.types));
95}
96
97static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
98 assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 &&
"op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 98, __extension__
__PRETTY_FUNCTION__))
;
99
100 // If not all the operand and result types are the same, just use the
101 // generic assembly form to avoid omitting information in printing.
102 auto resultType = op->getResult(0).getType();
103 if (llvm::any_of(op->getOperandTypes(),
104 [&](Type type) { return type != resultType; })) {
105 p.printGenericOp(op, /*printOpName=*/false);
106 return;
107 }
108
109 p << ' ';
110 p.printOperands(op->getOperands());
111 p.printOptionalAttrDict(op->getAttrs());
112 // Now we can output only one type for all operands and the result.
113 p << " : " << resultType;
114}
115
116/// Returns true if the given op is a function-like op or nested in a
117/// function-like op without a module-like op in the middle.
118static bool isNestedInFunctionOpInterface(Operation *op) {
119 if (!op)
120 return false;
121 if (op->hasTrait<OpTrait::SymbolTable>())
122 return false;
123 if (isa<FunctionOpInterface>(op))
124 return true;
125 return isNestedInFunctionOpInterface(op->getParentOp());
126}
127
128/// Returns true if the given op is an module-like op that maintains a symbol
129/// table.
130static bool isDirectInModuleLikeOp(Operation *op) {
131 return op && op->hasTrait<OpTrait::SymbolTable>();
132}
133
134static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
135 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
136 if (!constOp) {
137 return failure();
138 }
139 auto valueAttr = constOp.getValue();
140 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
141 if (!integerValueAttr) {
142 return failure();
143 }
144
145 if (integerValueAttr.getType().isSignlessInteger())
146 value = integerValueAttr.getInt();
147 else
148 value = integerValueAttr.getSInt();
149
150 return success();
151}
152
153template <typename Ty>
154static ArrayAttr
155getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
156 function_ref<StringRef(Ty)> stringifyFn) {
157 if (enumValues.empty()) {
158 return nullptr;
159 }
160 SmallVector<StringRef, 1> enumValStrs;
161 enumValStrs.reserve(enumValues.size());
162 for (auto val : enumValues) {
163 enumValStrs.emplace_back(stringifyFn(val));
164 }
165 return builder.getStrArrayAttr(enumValStrs);
166}
167
168/// Parses the next string attribute in `parser` as an enumerant of the given
169/// `EnumClass`.
170template <typename EnumClass>
171static ParseResult
172parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
173 StringRef attrName = spirv::attributeName<EnumClass>()) {
174 Attribute attrVal;
175 NamedAttrList attr;
176 auto loc = parser.getCurrentLocation();
177 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
3
Taking false branch
178 attrName, attr))
179 return failure();
180 if (!attrVal.isa<StringAttr>())
4
Taking true branch
181 return parser.emitError(loc, "expected ")
5
Returning without writing to 'value'
182 << attrName << " attribute specified as string";
183 auto attrOptional =
184 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
185 if (!attrOptional)
186 return parser.emitError(loc, "invalid ")
187 << attrName << " attribute specification: " << attrVal;
188 value = *attrOptional;
189 return success();
190}
191
192/// Parses the next string attribute in `parser` as an enumerant of the given
193/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
194/// attribute with the enum class's name as attribute name.
195template <typename EnumAttrClass,
196 typename EnumClass = typename EnumAttrClass::ValueType>
197static ParseResult
198parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
199 StringRef attrName = spirv::attributeName<EnumClass>()) {
200 if (parseEnumStrAttr(value, parser))
201 return failure();
202 state.addAttribute(attrName,
203 parser.getBuilder().getAttr<EnumAttrClass>(value));
204 return success();
205}
206
207/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
208/// and inserts the enumerant into `state` as an 32-bit integer attribute with
209/// the enum class's name as attribute name.
210template <typename EnumAttrClass,
211 typename EnumClass = typename EnumAttrClass::ValueType>
212static ParseResult
213parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
214 OperationState &state,
215 StringRef attrName = spirv::attributeName<EnumClass>()) {
216 if (parseEnumKeywordAttr(value, parser))
217 return failure();
218 state.addAttribute(attrName,
219 parser.getBuilder().getAttr<EnumAttrClass>(value));
220 return success();
221}
222
223/// Parses Function, Selection and Loop control attributes. If no control is
224/// specified, "None" is used as a default.
225template <typename EnumAttrClass, typename EnumClass>
226static ParseResult
227parseControlAttribute(OpAsmParser &parser, OperationState &state,
228 StringRef attrName = spirv::attributeName<EnumClass>()) {
229 if (succeeded(parser.parseOptionalKeyword(kControl))) {
230 EnumClass control;
231 if (parser.parseLParen() ||
232 parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
233 parser.parseRParen())
234 return failure();
235 return success();
236 }
237 // Set control to "None" otherwise.
238 Builder builder = parser.getBuilder();
239 state.addAttribute(attrName,
240 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
241 return success();
242}
243
244/// Parses optional memory access attributes attached to a memory access
245/// operand/pointer. Specifically, parses the following syntax:
246/// (`[` memory-access `]`)?
247/// where:
248/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
249/// integer-literal | `"NonTemporal"`
250static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
251 OperationState &state) {
252 // Parse an optional list of attributes staring with '['
253 if (parser.parseOptionalLSquare()) {
254 // Nothing to do
255 return success();
256 }
257
258 spirv::MemoryAccess memoryAccessAttr;
259 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
260 kMemoryAccessAttrName))
261 return failure();
262
263 if (spirv::bitEnumContainsAll(memoryAccessAttr,
264 spirv::MemoryAccess::Aligned)) {
265 // Parse integer attribute for alignment.
266 Attribute alignmentAttr;
267 Type i32Type = parser.getBuilder().getIntegerType(32);
268 if (parser.parseComma() ||
269 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
270 state.attributes)) {
271 return failure();
272 }
273 }
274 return parser.parseRSquare();
275}
276
277// TODO Make sure to merge this and the previous function into one template
278// parameterized by memory access attribute name and alignment. Doing so now
279// results in VS2017 in producing an internal error (at the call site) that's
280// not detailed enough to understand what is happening.
281static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
282 OperationState &state) {
283 // Parse an optional list of attributes staring with '['
284 if (parser.parseOptionalLSquare()) {
285 // Nothing to do
286 return success();
287 }
288
289 spirv::MemoryAccess memoryAccessAttr;
290 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
291 kSourceMemoryAccessAttrName))
292 return failure();
293
294 if (spirv::bitEnumContainsAll(memoryAccessAttr,
295 spirv::MemoryAccess::Aligned)) {
296 // Parse integer attribute for alignment.
297 Attribute alignmentAttr;
298 Type i32Type = parser.getBuilder().getIntegerType(32);
299 if (parser.parseComma() ||
300 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
301 state.attributes)) {
302 return failure();
303 }
304 }
305 return parser.parseRSquare();
306}
307
308template <typename MemoryOpTy>
309static void printMemoryAccessAttribute(
310 MemoryOpTy memoryOp, OpAsmPrinter &printer,
311 SmallVectorImpl<StringRef> &elidedAttrs,
312 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
313 Optional<uint32_t> alignmentAttrValue = None) {
314 // Print optional memory access attribute.
315 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
316 : memoryOp.getMemoryAccess())) {
317 elidedAttrs.push_back(kMemoryAccessAttrName);
318
319 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
320
321 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
322 // Print integer alignment attribute.
323 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
324 : memoryOp.getAlignment())) {
325 elidedAttrs.push_back(kAlignmentAttrName);
326 printer << ", " << alignment;
327 }
328 }
329 printer << "]";
330 }
331 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
332}
333
334// TODO Make sure to merge this and the previous function into one template
335// parameterized by memory access attribute name and alignment. Doing so now
336// results in VS2017 in producing an internal error (at the call site) that's
337// not detailed enough to understand what is happening.
338template <typename MemoryOpTy>
339static void printSourceMemoryAccessAttribute(
340 MemoryOpTy memoryOp, OpAsmPrinter &printer,
341 SmallVectorImpl<StringRef> &elidedAttrs,
342 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
343 Optional<uint32_t> alignmentAttrValue = None) {
344
345 printer << ", ";
346
347 // Print optional memory access attribute.
348 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
349 : memoryOp.getMemoryAccess())) {
350 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
351
352 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
353
354 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
355 // Print integer alignment attribute.
356 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
357 : memoryOp.getAlignment())) {
358 elidedAttrs.push_back(kSourceAlignmentAttrName);
359 printer << ", " << alignment;
360 }
361 }
362 printer << "]";
363 }
364 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
365}
366
367static ParseResult parseImageOperands(OpAsmParser &parser,
368 spirv::ImageOperandsAttr &attr) {
369 // Expect image operands
370 if (parser.parseOptionalLSquare())
371 return success();
372
373 spirv::ImageOperands imageOperands;
374 if (parseEnumStrAttr(imageOperands, parser))
375 return failure();
376
377 attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
378
379 return parser.parseRSquare();
380}
381
382static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
383 spirv::ImageOperandsAttr attr) {
384 if (attr) {
385 auto strImageOperands = stringifyImageOperands(attr.getValue());
386 printer << "[\"" << strImageOperands << "\"]";
387 }
388}
389
390template <typename Op>
391static LogicalResult verifyImageOperands(Op imageOp,
392 spirv::ImageOperandsAttr attr,
393 Operation::operand_range operands) {
394 if (!attr) {
395 if (operands.empty())
396 return success();
397
398 return imageOp.emitError("the Image Operands should encode what operands "
399 "follow, as per Image Operands");
400 }
401
402 // TODO: Add the validation rules for the following Image Operands.
403 spirv::ImageOperands noSupportOperands =
404 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
405 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
406 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
407 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
408 spirv::ImageOperands::MakeTexelAvailable |
409 spirv::ImageOperands::MakeTexelVisible |
410 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
411
412 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
413 llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 413)
;
414
415 return success();
416}
417
418static LogicalResult verifyCastOp(Operation *op,
419 bool requireSameBitWidth = true,
420 bool skipBitWidthCheck = false) {
421 // Some CastOps have no limit on bit widths for result and operand type.
422 if (skipBitWidthCheck)
423 return success();
424
425 Type operandType = op->getOperand(0).getType();
426 Type resultType = op->getResult(0).getType();
427
428 // ODS checks that result type and operand type have the same shape.
429 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
430 operandType = vectorType.getElementType();
431 resultType = resultType.cast<VectorType>().getElementType();
432 }
433
434 if (auto coopMatrixType =
435 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
436 operandType = coopMatrixType.getElementType();
437 resultType =
438 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
439 }
440
441 if (auto jointMatrixType =
442 operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
443 operandType = jointMatrixType.getElementType();
444 resultType =
445 resultType.cast<spirv::JointMatrixINTELType>().getElementType();
446 }
447
448 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
449 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
450 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
451
452 if (requireSameBitWidth) {
453 if (!isSameBitWidth) {
454 return op->emitOpError(
455 "expected the same bit widths for operand type and result "
456 "type, but provided ")
457 << operandType << " and " << resultType;
458 }
459 return success();
460 }
461
462 if (isSameBitWidth) {
463 return op->emitOpError(
464 "expected the different bit widths for operand type and result "
465 "type, but provided ")
466 << operandType << " and " << resultType;
467 }
468 return success();
469}
470
471template <typename MemoryOpTy>
472static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
473 // ODS checks for attributes values. Just need to verify that if the
474 // memory-access attribute is Aligned, then the alignment attribute must be
475 // present.
476 auto *op = memoryOp.getOperation();
477 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
478 if (!memAccessAttr) {
479 // Alignment attribute shouldn't be present if memory access attribute is
480 // not present.
481 if (op->getAttr(kAlignmentAttrName)) {
482 return memoryOp.emitOpError(
483 "invalid alignment specification without aligned memory access "
484 "specification");
485 }
486 return success();
487 }
488
489 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
490
491 if (!memAccess) {
492 return memoryOp.emitOpError("invalid memory access specifier: ")
493 << memAccessAttr;
494 }
495
496 if (spirv::bitEnumContainsAll(memAccess.getValue(),
497 spirv::MemoryAccess::Aligned)) {
498 if (!op->getAttr(kAlignmentAttrName)) {
499 return memoryOp.emitOpError("missing alignment value");
500 }
501 } else {
502 if (op->getAttr(kAlignmentAttrName)) {
503 return memoryOp.emitOpError(
504 "invalid alignment specification with non-aligned memory access "
505 "specification");
506 }
507 }
508 return success();
509}
510
511// TODO Make sure to merge this and the previous function into one template
512// parameterized by memory access attribute name and alignment. Doing so now
513// results in VS2017 in producing an internal error (at the call site) that's
514// not detailed enough to understand what is happening.
515template <typename MemoryOpTy>
516static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
517 // ODS checks for attributes values. Just need to verify that if the
518 // memory-access attribute is Aligned, then the alignment attribute must be
519 // present.
520 auto *op = memoryOp.getOperation();
521 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
522 if (!memAccessAttr) {
523 // Alignment attribute shouldn't be present if memory access attribute is
524 // not present.
525 if (op->getAttr(kSourceAlignmentAttrName)) {
526 return memoryOp.emitOpError(
527 "invalid alignment specification without aligned memory access "
528 "specification");
529 }
530 return success();
531 }
532
533 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
534
535 if (!memAccess) {
536 return memoryOp.emitOpError("invalid memory access specifier: ")
537 << memAccess;
538 }
539
540 if (spirv::bitEnumContainsAll(memAccess.getValue(),
541 spirv::MemoryAccess::Aligned)) {
542 if (!op->getAttr(kSourceAlignmentAttrName)) {
543 return memoryOp.emitOpError("missing alignment value");
544 }
545 } else {
546 if (op->getAttr(kSourceAlignmentAttrName)) {
547 return memoryOp.emitOpError(
548 "invalid alignment specification with non-aligned memory access "
549 "specification");
550 }
551 }
552 return success();
553}
554
555static LogicalResult
556verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
557 // According to the SPIR-V specification:
558 // "Despite being a mask and allowing multiple bits to be combined, it is
559 // invalid for more than one of these four bits to be set: Acquire, Release,
560 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
561 // Release semantics is done by setting the AcquireRelease bit, not by setting
562 // two bits."
563 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
564 spirv::MemorySemantics::Release |
565 spirv::MemorySemantics::AcquireRelease |
566 spirv::MemorySemantics::SequentiallyConsistent;
567
568 auto bitCount = llvm::countPopulation(
569 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
570 if (bitCount > 1) {
571 return op->emitError(
572 "expected at most one of these four memory constraints "
573 "to be set: `Acquire`, `Release`,"
574 "`AcquireRelease` or `SequentiallyConsistent`");
575 }
576 return success();
577}
578
579template <typename LoadStoreOpTy>
580static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
581 Value val) {
582 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
583 // type of the pointer and the type of the value are the same
584 //
585 // TODO: Check that the value type satisfies restrictions of
586 // SPIR-V OpLoad/OpStore operations
587 if (val.getType() !=
588 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
589 return op.emitOpError("mismatch in result type and pointer type");
590 }
591 return success();
592}
593
594template <typename BlockReadWriteOpTy>
595static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
596 Value ptr, Value val) {
597 auto valType = val.getType();
598 if (auto valVecTy = valType.dyn_cast<VectorType>())
599 valType = valVecTy.getElementType();
600
601 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
602 return op.emitOpError("mismatch in result type and pointer type");
603 }
604 return success();
605}
606
607static ParseResult parseVariableDecorations(OpAsmParser &parser,
608 OperationState &state) {
609 auto builtInName = llvm::convertToSnakeFromCamelCase(
610 stringifyDecoration(spirv::Decoration::BuiltIn));
611 if (succeeded(parser.parseOptionalKeyword("bind"))) {
612 Attribute set, binding;
613 // Parse optional descriptor binding
614 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
615 stringifyDecoration(spirv::Decoration::DescriptorSet));
616 auto bindingName = llvm::convertToSnakeFromCamelCase(
617 stringifyDecoration(spirv::Decoration::Binding));
618 Type i32Type = parser.getBuilder().getIntegerType(32);
619 if (parser.parseLParen() ||
620 parser.parseAttribute(set, i32Type, descriptorSetName,
621 state.attributes) ||
622 parser.parseComma() ||
623 parser.parseAttribute(binding, i32Type, bindingName,
624 state.attributes) ||
625 parser.parseRParen()) {
626 return failure();
627 }
628 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
629 StringAttr builtIn;
630 if (parser.parseLParen() ||
631 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
632 parser.parseRParen()) {
633 return failure();
634 }
635 }
636
637 // Parse other attributes
638 if (parser.parseOptionalAttrDict(state.attributes))
639 return failure();
640
641 return success();
642}
643
644static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
645 SmallVectorImpl<StringRef> &elidedAttrs) {
646 // Print optional descriptor binding
647 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
648 stringifyDecoration(spirv::Decoration::DescriptorSet));
649 auto bindingName = llvm::convertToSnakeFromCamelCase(
650 stringifyDecoration(spirv::Decoration::Binding));
651 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
652 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
653 if (descriptorSet && binding) {
654 elidedAttrs.push_back(descriptorSetName);
655 elidedAttrs.push_back(bindingName);
656 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
657 << ")";
658 }
659
660 // Print BuiltIn attribute if present
661 auto builtInName = llvm::convertToSnakeFromCamelCase(
662 stringifyDecoration(spirv::Decoration::BuiltIn));
663 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
664 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
665 elidedAttrs.push_back(builtInName);
666 }
667
668 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
669}
670
671// Get bit width of types.
672static unsigned getBitWidth(Type type) {
673 if (type.isa<spirv::PointerType>()) {
674 // Just return 64 bits for pointer types for now.
675 // TODO: Make sure not caller relies on the actual pointer width value.
676 return 64;
677 }
678
679 if (type.isIntOrFloat())
680 return type.getIntOrFloatBitWidth();
681
682 if (auto vectorType = type.dyn_cast<VectorType>()) {
683 assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 683, __extension__
__PRETTY_FUNCTION__))
;
684 return vectorType.getNumElements() *
685 vectorType.getElementType().getIntOrFloatBitWidth();
686 }
687 llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 687)
;
688}
689
690/// Walks the given type hierarchy with the given indices, potentially down
691/// to component granularity, to select an element type. Returns null type and
692/// emits errors with the given loc on failure.
693static Type
694getElementType(Type type, ArrayRef<int32_t> indices,
695 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
696 if (indices.empty()) {
697 emitErrorFn("expected at least one index for spirv.CompositeExtract");
698 return nullptr;
699 }
700
701 for (auto index : indices) {
702 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
703 if (cType.hasCompileTimeKnownNumElements() &&
704 (index < 0 ||
705 static_cast<uint64_t>(index) >= cType.getNumElements())) {
706 emitErrorFn("index ") << index << " out of bounds for " << type;
707 return nullptr;
708 }
709 type = cType.getElementType(index);
710 } else {
711 emitErrorFn("cannot extract from non-composite type ")
712 << type << " with index " << index;
713 return nullptr;
714 }
715 }
716 return type;
717}
718
719static Type
720getElementType(Type type, Attribute indices,
721 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
722 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
723 if (!indicesArrayAttr) {
724 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
725 return nullptr;
726 }
727 if (indicesArrayAttr.empty()) {
728 emitErrorFn("expected at least one index for spirv.CompositeExtract");
729 return nullptr;
730 }
731
732 SmallVector<int32_t, 2> indexVals;
733 for (auto indexAttr : indicesArrayAttr) {
734 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
735 if (!indexIntAttr) {
736 emitErrorFn("expected an 32-bit integer for index, but found '")
737 << indexAttr << "'";
738 return nullptr;
739 }
740 indexVals.push_back(indexIntAttr.getInt());
741 }
742 return getElementType(type, indexVals, emitErrorFn);
743}
744
745static Type getElementType(Type type, Attribute indices, Location loc) {
746 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
747 return ::mlir::emitError(loc, err);
748 };
749 return getElementType(type, indices, errorFn);
750}
751
752static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
753 SMLoc loc) {
754 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
755 return parser.emitError(loc, err);
756 };
757 return getElementType(type, indices, errorFn);
758}
759
760/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
761static inline bool isMergeBlock(Block &block) {
762 return !block.empty() && std::next(block.begin()) == block.end() &&
763 isa<spirv::MergeOp>(block.front());
764}
765
766//===----------------------------------------------------------------------===//
767// Common parsers and printers
768//===----------------------------------------------------------------------===//
769
770// Parses an atomic update op. If the update op does not take a value (like
771// AtomicIIncrement) `hasValue` must be false.
772static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
773 OperationState &state, bool hasValue) {
774 spirv::Scope scope;
775 spirv::MemorySemantics memoryScope;
776 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
777 OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
778 Type type;
779 SMLoc loc;
780 if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
781 kMemoryScopeAttrName) ||
782 parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
783 kSemanticsAttrName) ||
784 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
785 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
786 return failure();
787
788 auto ptrType = type.dyn_cast<spirv::PointerType>();
789 if (!ptrType)
790 return parser.emitError(loc, "expected pointer type");
791
792 SmallVector<Type, 2> operandTypes;
793 operandTypes.push_back(ptrType);
794 if (hasValue)
795 operandTypes.push_back(ptrType.getPointeeType());
796 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
797 state.operands))
798 return failure();
799 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
800}
801
802// Prints an atomic update op.
803static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
804 printer << " \"";
805 auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
806 printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
807 auto memorySemanticsAttr =
808 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
809 printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
810 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
811}
812
813template <typename T>
814static StringRef stringifyTypeName();
815
816template <>
817StringRef stringifyTypeName<IntegerType>() {
818 return "integer";
819}
820
821template <>
822StringRef stringifyTypeName<FloatType>() {
823 return "float";
824}
825
826// Verifies an atomic update op.
827template <typename ExpectedElementType>
828static LogicalResult verifyAtomicUpdateOp(Operation *op) {
829 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
830 auto elementType = ptrType.getPointeeType();
831 if (!elementType.isa<ExpectedElementType>())
832 return op->emitOpError() << "pointer operand must point to an "
833 << stringifyTypeName<ExpectedElementType>()
834 << " value, found " << elementType;
835
836 if (op->getNumOperands() > 1) {
837 auto valueType = op->getOperand(1).getType();
838 if (valueType != elementType)
839 return op->emitOpError("expected value to have the same type as the "
840 "pointer operand's pointee type ")
841 << elementType << ", but found " << valueType;
842 }
843 auto memorySemantics =
844 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
845 .getValue();
846 if (failed(verifyMemorySemantics(op, memorySemantics))) {
847 return failure();
848 }
849 return success();
850}
851
852static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
853 OperationState &state) {
854 spirv::Scope executionScope;
855 spirv::GroupOperation groupOperation;
856 OpAsmParser::UnresolvedOperand valueInfo;
857 if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
858 kExecutionScopeAttrName) ||
859 parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
860 kGroupOperationAttrName) ||
861 parser.parseOperand(valueInfo))
862 return failure();
863
864 Optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
865 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
866 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
867 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
868 parser.parseRParen())
869 return failure();
870 }
871
872 Type resultType;
873 if (parser.parseColonType(resultType))
874 return failure();
875
876 if (parser.resolveOperand(valueInfo, resultType, state.operands))
877 return failure();
878
879 if (clusterSizeInfo) {
880 Type i32Type = parser.getBuilder().getIntegerType(32);
881 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
882 return failure();
883 }
884
885 return parser.addTypeToList(resultType, state.types);
886}
887
888static void printGroupNonUniformArithmeticOp(Operation *groupOp,
889 OpAsmPrinter &printer) {
890 printer
891 << " \""
892 << stringifyScope(
893 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
894 .getValue())
895 << "\" \""
896 << stringifyGroupOperation(groupOp
897 ->getAttrOfType<spirv::GroupOperationAttr>(
898 kGroupOperationAttrName)
899 .getValue())
900 << "\" " << groupOp->getOperand(0);
901
902 if (groupOp->getNumOperands() > 1)
903 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
904 printer << " : " << groupOp->getResult(0).getType();
905}
906
907static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
908 spirv::Scope scope =
909 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
910 .getValue();
911 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
912 return groupOp->emitOpError(
913 "execution scope must be 'Workgroup' or 'Subgroup'");
914
915 spirv::GroupOperation operation =
916 groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
917 .getValue();
918 if (operation == spirv::GroupOperation::ClusteredReduce &&
919 groupOp->getNumOperands() == 1)
920 return groupOp->emitOpError("cluster size operand must be provided for "
921 "'ClusteredReduce' group operation");
922 if (groupOp->getNumOperands() > 1) {
923 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
924 int32_t clusterSize = 0;
925
926 // TODO: support specialization constant here.
927 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
928 return groupOp->emitOpError(
929 "cluster size operand must come from a constant op");
930
931 if (!llvm::isPowerOf2_32(clusterSize))
932 return groupOp->emitOpError(
933 "cluster size operand must be a power of two");
934 }
935 return success();
936}
937
938/// Result of a logical op must be a scalar or vector of boolean type.
939static Type getUnaryOpResultType(Type operandType) {
940 Builder builder(operandType.getContext());
941 Type resultType = builder.getIntegerType(1);
942 if (auto vecType = operandType.dyn_cast<VectorType>())
943 return VectorType::get(vecType.getNumElements(), resultType);
944 return resultType;
945}
946
947static LogicalResult verifyShiftOp(Operation *op) {
948 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
949 return op->emitError("expected the same type for the first operand and "
950 "result, but provided ")
951 << op->getOperand(0).getType() << " and "
952 << op->getResult(0).getType();
953 }
954 return success();
955}
956
957static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
958 Value lhs, Value rhs) {
959 assert(lhs.getType() == rhs.getType())(static_cast <bool> (lhs.getType() == rhs.getType()) ? void
(0) : __assert_fail ("lhs.getType() == rhs.getType()", "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp"
, 959, __extension__ __PRETTY_FUNCTION__))
;
960
961 Type boolType = builder.getI1Type();
962 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
963 boolType = VectorType::get(vecType.getShape(), boolType);
964 state.addTypes(boolType);
965
966 state.addOperands({lhs, rhs});
967}
968
969static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
970 Value value) {
971 Type boolType = builder.getI1Type();
972 if (auto vecType = value.getType().dyn_cast<VectorType>())
973 boolType = VectorType::get(vecType.getShape(), boolType);
974 state.addTypes(boolType);
975
976 state.addOperands(value);
977}
978
979//===----------------------------------------------------------------------===//
980// spirv.AccessChainOp
981//===----------------------------------------------------------------------===//
982
983static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
984 auto ptrType = type.dyn_cast<spirv::PointerType>();
985 if (!ptrType) {
986 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
987 "to composite type, but provided ")
988 << type;
989 return nullptr;
990 }
991
992 auto resultType = ptrType.getPointeeType();
993 auto resultStorageClass = ptrType.getStorageClass();
994 int32_t index = 0;
995
996 for (auto indexSSA : indices) {
997 auto cType = resultType.dyn_cast<spirv::CompositeType>();
998 if (!cType) {
999 emitError(
1000 baseLoc,
1001 "'spirv.AccessChain' op cannot extract from non-composite type ")
1002 << resultType << " with index " << index;
1003 return nullptr;
1004 }
1005 index = 0;
1006 if (resultType.isa<spirv::StructType>()) {
1007 Operation *op = indexSSA.getDefiningOp();
1008 if (!op) {
1009 emitError(baseLoc, "'spirv.AccessChain' op index must be an "
1010 "integer spirv.Constant to access "
1011 "element of spirv.struct");
1012 return nullptr;
1013 }
1014
1015 // TODO: this should be relaxed to allow
1016 // integer literals of other bitwidths.
1017 if (failed(extractValueFromConstOp(op, index))) {
1018 emitError(
1019 baseLoc,
1020 "'spirv.AccessChain' index must be an integer spirv.Constant to "
1021 "access element of spirv.struct, but provided ")
1022 << op->getName();
1023 return nullptr;
1024 }
1025 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1026 emitError(baseLoc, "'spirv.AccessChain' op index ")
1027 << index << " out of bounds for " << resultType;
1028 return nullptr;
1029 }
1030 }
1031 resultType = cType.getElementType(index);
1032 }
1033 return spirv::PointerType::get(resultType, resultStorageClass);
1034}
1035
1036void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1037 Value basePtr, ValueRange indices) {
1038 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1039 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1039, __extension__
__PRETTY_FUNCTION__))
;
1040 build(builder, state, type, basePtr, indices);
1041}
1042
1043ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1044 OperationState &result) {
1045 OpAsmParser::UnresolvedOperand ptrInfo;
1046 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
1047 Type type;
1048 auto loc = parser.getCurrentLocation();
1049 SmallVector<Type, 4> indicesTypes;
1050
1051 if (parser.parseOperand(ptrInfo) ||
1052 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1053 parser.parseColonType(type) ||
1054 parser.resolveOperand(ptrInfo, type, result.operands)) {
1055 return failure();
1056 }
1057
1058 // Check that the provided indices list is not empty before parsing their
1059 // type list.
1060 if (indicesInfo.empty()) {
1061 return mlir::emitError(result.location,
1062 "'spirv.AccessChain' op expected at "
1063 "least one index ");
1064 }
1065
1066 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1067 return failure();
1068
1069 // Check that the indices types list is not empty and that it has a one-to-one
1070 // mapping to the provided indices.
1071 if (indicesTypes.size() != indicesInfo.size()) {
1072 return mlir::emitError(
1073 result.location, "'spirv.AccessChain' op indices types' count must be "
1074 "equal to indices info count");
1075 }
1076
1077 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1078 return failure();
1079
1080 auto resultType = getElementPtrType(
1081 type, llvm::makeArrayRef(result.operands).drop_front(), result.location);
1082 if (!resultType) {
1083 return failure();
1084 }
1085
1086 result.addTypes(resultType);
1087 return success();
1088}
1089
1090template <typename Op>
1091static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1092 printer << ' ' << op.getBasePtr() << '[' << indices
1093 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
1094}
1095
1096void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1097 printAccessChain(*this, getIndices(), printer);
1098}
1099
1100template <typename Op>
1101static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1102 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
1103 indices, accessChainOp.getLoc());
1104 if (!resultType)
1105 return failure();
1106
1107 auto providedResultType =
1108 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1109 if (!providedResultType)
1110 return accessChainOp.emitOpError(
1111 "result type must be a pointer, but provided")
1112 << providedResultType;
1113
1114 if (resultType != providedResultType)
1115 return accessChainOp.emitOpError("invalid result type: expected ")
1116 << resultType << ", but provided " << providedResultType;
1117
1118 return success();
1119}
1120
1121LogicalResult spirv::AccessChainOp::verify() {
1122 return verifyAccessChain(*this, getIndices());
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// spirv.mlir.addressof
1127//===----------------------------------------------------------------------===//
1128
1129void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1130 spirv::GlobalVariableOp var) {
1131 build(builder, state, var.getType(), SymbolRefAttr::get(var));
1132}
1133
1134LogicalResult spirv::AddressOfOp::verify() {
1135 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1136 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1137 getVariableAttr()));
1138 if (!varOp) {
1139 return emitOpError("expected spirv.GlobalVariable symbol");
1140 }
1141 if (getPointer().getType() != varOp.getType()) {
1142 return emitOpError(
1143 "result type mismatch with the referenced global variable's type");
1144 }
1145 return success();
1146}
1147
1148template <typename T>
1149static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1150 printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
1151 << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
1152 << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
1153 << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
1154}
1155
1156static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1157 OperationState &state) {
1158 spirv::Scope memoryScope;
1159 spirv::MemorySemantics equalSemantics, unequalSemantics;
1160 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1161 Type type;
1162 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1163 kMemoryScopeAttrName) ||
1164 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1165 equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1166 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1167 unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1168 parser.parseOperandList(operandInfo, 3))
1169 return failure();
1170
1171 auto loc = parser.getCurrentLocation();
1172 if (parser.parseColonType(type))
1173 return failure();
1174
1175 auto ptrType = type.dyn_cast<spirv::PointerType>();
1176 if (!ptrType)
1177 return parser.emitError(loc, "expected pointer type");
1178
1179 if (parser.resolveOperands(
1180 operandInfo,
1181 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1182 parser.getNameLoc(), state.operands))
1183 return failure();
1184
1185 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1186}
1187
1188template <typename T>
1189static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1190 // According to the spec:
1191 // "The type of Value must be the same as Result Type. The type of the value
1192 // pointed to by Pointer must be the same as Result Type. This type must also
1193 // match the type of Comparator."
1194 if (atomOp.getType() != atomOp.getValue().getType())
1195 return atomOp.emitOpError("value operand must have the same type as the op "
1196 "result, but found ")
1197 << atomOp.getValue().getType() << " vs " << atomOp.getType();
1198
1199 if (atomOp.getType() != atomOp.getComparator().getType())
1200 return atomOp.emitOpError(
1201 "comparator operand must have the same type as the op "
1202 "result, but found ")
1203 << atomOp.getComparator().getType() << " vs " << atomOp.getType();
1204
1205 Type pointeeType = atomOp.getPointer()
1206 .getType()
1207 .template cast<spirv::PointerType>()
1208 .getPointeeType();
1209 if (atomOp.getType() != pointeeType)
1210 return atomOp.emitOpError(
1211 "pointer operand's pointee type must have the same "
1212 "as the op result type, but found ")
1213 << pointeeType << " vs " << atomOp.getType();
1214
1215 // TODO: Unequal cannot be set to Release or Acquire and Release.
1216 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1217
1218 return success();
1219}
1220
1221//===----------------------------------------------------------------------===//
1222// spirv.AtomicAndOp
1223//===----------------------------------------------------------------------===//
1224
1225LogicalResult spirv::AtomicAndOp::verify() {
1226 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1227}
1228
1229ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1230 OperationState &result) {
1231 return ::parseAtomicUpdateOp(parser, result, true);
1232}
1233void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1234 ::printAtomicUpdateOp(*this, p);
1235}
1236
1237//===----------------------------------------------------------------------===//
1238// spirv.AtomicCompareExchangeOp
1239//===----------------------------------------------------------------------===//
1240
1241LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1242 return ::verifyAtomicCompareExchangeImpl(*this);
1243}
1244
1245ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1246 OperationState &result) {
1247 return ::parseAtomicCompareExchangeImpl(parser, result);
1248}
1249void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1250 ::printAtomicCompareExchangeImpl(*this, p);
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// spirv.AtomicCompareExchangeWeakOp
1255//===----------------------------------------------------------------------===//
1256
1257LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1258 return ::verifyAtomicCompareExchangeImpl(*this);
1259}
1260
1261ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1262 OperationState &result) {
1263 return ::parseAtomicCompareExchangeImpl(parser, result);
1264}
1265void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1266 ::printAtomicCompareExchangeImpl(*this, p);
1267}
1268
1269//===----------------------------------------------------------------------===//
1270// spirv.AtomicExchange
1271//===----------------------------------------------------------------------===//
1272
1273void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1274 printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
1275 << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
1276 << " : " << getPointer().getType();
1277}
1278
1279ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1280 OperationState &result) {
1281 spirv::Scope memoryScope;
1282 spirv::MemorySemantics semantics;
1283 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1284 Type type;
1285 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1286 kMemoryScopeAttrName) ||
1287 parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1288 kSemanticsAttrName) ||
1289 parser.parseOperandList(operandInfo, 2))
1290 return failure();
1291
1292 auto loc = parser.getCurrentLocation();
1293 if (parser.parseColonType(type))
1294 return failure();
1295
1296 auto ptrType = type.dyn_cast<spirv::PointerType>();
1297 if (!ptrType)
1298 return parser.emitError(loc, "expected pointer type");
1299
1300 if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1301 parser.getNameLoc(), result.operands))
1302 return failure();
1303
1304 return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1305}
1306
1307LogicalResult spirv::AtomicExchangeOp::verify() {
1308 if (getType() != getValue().getType())
1309 return emitOpError("value operand must have the same type as the op "
1310 "result, but found ")
1311 << getValue().getType() << " vs " << getType();
1312
1313 Type pointeeType =
1314 getPointer().getType().cast<spirv::PointerType>().getPointeeType();
1315 if (getType() != pointeeType)
1316 return emitOpError("pointer operand's pointee type must have the same "
1317 "as the op result type, but found ")
1318 << pointeeType << " vs " << getType();
1319
1320 return success();
1321}
1322
1323//===----------------------------------------------------------------------===//
1324// spirv.AtomicIAddOp
1325//===----------------------------------------------------------------------===//
1326
1327LogicalResult spirv::AtomicIAddOp::verify() {
1328 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1329}
1330
1331ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1332 OperationState &result) {
1333 return ::parseAtomicUpdateOp(parser, result, true);
1334}
1335void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1336 ::printAtomicUpdateOp(*this, p);
1337}
1338
1339//===----------------------------------------------------------------------===//
1340// spirv.EXT.AtomicFAddOp
1341//===----------------------------------------------------------------------===//
1342
1343LogicalResult spirv::EXTAtomicFAddOp::verify() {
1344 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1345}
1346
1347ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
1348 OperationState &result) {
1349 return ::parseAtomicUpdateOp(parser, result, true);
1350}
1351void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
1352 ::printAtomicUpdateOp(*this, p);
1353}
1354
1355//===----------------------------------------------------------------------===//
1356// spirv.AtomicIDecrementOp
1357//===----------------------------------------------------------------------===//
1358
1359LogicalResult spirv::AtomicIDecrementOp::verify() {
1360 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1361}
1362
1363ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1364 OperationState &result) {
1365 return ::parseAtomicUpdateOp(parser, result, false);
1366}
1367void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1368 ::printAtomicUpdateOp(*this, p);
1369}
1370
1371//===----------------------------------------------------------------------===//
1372// spirv.AtomicIIncrementOp
1373//===----------------------------------------------------------------------===//
1374
1375LogicalResult spirv::AtomicIIncrementOp::verify() {
1376 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1377}
1378
1379ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1380 OperationState &result) {
1381 return ::parseAtomicUpdateOp(parser, result, false);
1382}
1383void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1384 ::printAtomicUpdateOp(*this, p);
1385}
1386
1387//===----------------------------------------------------------------------===//
1388// spirv.AtomicISubOp
1389//===----------------------------------------------------------------------===//
1390
1391LogicalResult spirv::AtomicISubOp::verify() {
1392 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1393}
1394
1395ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1396 OperationState &result) {
1397 return ::parseAtomicUpdateOp(parser, result, true);
1398}
1399void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1400 ::printAtomicUpdateOp(*this, p);
1401}
1402
1403//===----------------------------------------------------------------------===//
1404// spirv.AtomicOrOp
1405//===----------------------------------------------------------------------===//
1406
1407LogicalResult spirv::AtomicOrOp::verify() {
1408 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1409}
1410
1411ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1412 OperationState &result) {
1413 return ::parseAtomicUpdateOp(parser, result, true);
1414}
1415void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1416 ::printAtomicUpdateOp(*this, p);
1417}
1418
1419//===----------------------------------------------------------------------===//
1420// spirv.AtomicSMaxOp
1421//===----------------------------------------------------------------------===//
1422
1423LogicalResult spirv::AtomicSMaxOp::verify() {
1424 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1425}
1426
1427ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1428 OperationState &result) {
1429 return ::parseAtomicUpdateOp(parser, result, true);
1430}
1431void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1432 ::printAtomicUpdateOp(*this, p);
1433}
1434
1435//===----------------------------------------------------------------------===//
1436// spirv.AtomicSMinOp
1437//===----------------------------------------------------------------------===//
1438
1439LogicalResult spirv::AtomicSMinOp::verify() {
1440 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1441}
1442
1443ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1444 OperationState &result) {
1445 return ::parseAtomicUpdateOp(parser, result, true);
1446}
1447void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1448 ::printAtomicUpdateOp(*this, p);
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// spirv.AtomicUMaxOp
1453//===----------------------------------------------------------------------===//
1454
1455LogicalResult spirv::AtomicUMaxOp::verify() {
1456 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1457}
1458
1459ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1460 OperationState &result) {
1461 return ::parseAtomicUpdateOp(parser, result, true);
1462}
1463void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1464 ::printAtomicUpdateOp(*this, p);
1465}
1466
1467//===----------------------------------------------------------------------===//
1468// spirv.AtomicUMinOp
1469//===----------------------------------------------------------------------===//
1470
1471LogicalResult spirv::AtomicUMinOp::verify() {
1472 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1473}
1474
1475ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1476 OperationState &result) {
1477 return ::parseAtomicUpdateOp(parser, result, true);
1478}
1479void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1480 ::printAtomicUpdateOp(*this, p);
1481}
1482
1483//===----------------------------------------------------------------------===//
1484// spirv.AtomicXorOp
1485//===----------------------------------------------------------------------===//
1486
1487LogicalResult spirv::AtomicXorOp::verify() {
1488 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1489}
1490
1491ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1492 OperationState &result) {
1493 return ::parseAtomicUpdateOp(parser, result, true);
1494}
1495void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1496 ::printAtomicUpdateOp(*this, p);
1497}
1498
1499//===----------------------------------------------------------------------===//
1500// spirv.BitcastOp
1501//===----------------------------------------------------------------------===//
1502
1503LogicalResult spirv::BitcastOp::verify() {
1504 // TODO: The SPIR-V spec validation rules are different for different
1505 // versions.
1506 auto operandType = getOperand().getType();
1507 auto resultType = getResult().getType();
1508 if (operandType == resultType) {
1509 return emitError("result type must be different from operand type");
1510 }
1511 if (operandType.isa<spirv::PointerType>() &&
1512 !resultType.isa<spirv::PointerType>()) {
1513 return emitError(
1514 "unhandled bit cast conversion from pointer type to non-pointer type");
1515 }
1516 if (!operandType.isa<spirv::PointerType>() &&
1517 resultType.isa<spirv::PointerType>()) {
1518 return emitError(
1519 "unhandled bit cast conversion from non-pointer type to pointer type");
1520 }
1521 auto operandBitWidth = getBitWidth(operandType);
1522 auto resultBitWidth = getBitWidth(resultType);
1523 if (operandBitWidth != resultBitWidth) {
1524 return emitOpError("mismatch in result type bitwidth ")
1525 << resultBitWidth << " and operand type bitwidth "
1526 << operandBitWidth;
1527 }
1528 return success();
1529}
1530
1531//===----------------------------------------------------------------------===//
1532// spirv.PtrCastToGenericOp
1533//===----------------------------------------------------------------------===//
1534
1535LogicalResult spirv::PtrCastToGenericOp::verify() {
1536 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1537 auto resultType = getResult().getType().cast<spirv::PointerType>();
1538
1539 spirv::StorageClass operandStorage = operandType.getStorageClass();
1540 if (operandStorage != spirv::StorageClass::Workgroup &&
1541 operandStorage != spirv::StorageClass::CrossWorkgroup &&
1542 operandStorage != spirv::StorageClass::Function)
1543 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
1544 ", or Function Storage Class");
1545
1546 spirv::StorageClass resultStorage = resultType.getStorageClass();
1547 if (resultStorage != spirv::StorageClass::Generic)
1548 return emitError("result type must be of storage class Generic");
1549
1550 Type operandPointeeType = operandType.getPointeeType();
1551 Type resultPointeeType = resultType.getPointeeType();
1552 if (operandPointeeType != resultPointeeType)
1553 return emitOpError("pointer operand's pointee type must have the same "
1554 "as the op result type, but found ")
1555 << operandPointeeType << " vs " << resultPointeeType;
1556 return success();
1557}
1558
1559//===----------------------------------------------------------------------===//
1560// spirv.GenericCastToPtrOp
1561//===----------------------------------------------------------------------===//
1562
1563LogicalResult spirv::GenericCastToPtrOp::verify() {
1564 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1565 auto resultType = getResult().getType().cast<spirv::PointerType>();
1566
1567 spirv::StorageClass operandStorage = operandType.getStorageClass();
1568 if (operandStorage != spirv::StorageClass::Generic)
1569 return emitError("pointer type must be of storage class Generic");
1570
1571 spirv::StorageClass resultStorage = resultType.getStorageClass();
1572 if (resultStorage != spirv::StorageClass::Workgroup &&
1573 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1574 resultStorage != spirv::StorageClass::Function)
1575 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1576 "or Function Storage Class");
1577
1578 Type operandPointeeType = operandType.getPointeeType();
1579 Type resultPointeeType = resultType.getPointeeType();
1580 if (operandPointeeType != resultPointeeType)
1581 return emitOpError("pointer operand's pointee type must have the same "
1582 "as the op result type, but found ")
1583 << operandPointeeType << " vs " << resultPointeeType;
1584 return success();
1585}
1586
1587//===----------------------------------------------------------------------===//
1588// spirv.GenericCastToPtrExplicitOp
1589//===----------------------------------------------------------------------===//
1590
1591LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
1592 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1593 auto resultType = getResult().getType().cast<spirv::PointerType>();
1594
1595 spirv::StorageClass operandStorage = operandType.getStorageClass();
1596 if (operandStorage != spirv::StorageClass::Generic)
1597 return emitError("pointer type must be of storage class Generic");
1598
1599 spirv::StorageClass resultStorage = resultType.getStorageClass();
1600 if (resultStorage != spirv::StorageClass::Workgroup &&
1601 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1602 resultStorage != spirv::StorageClass::Function)
1603 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1604 "or Function Storage Class");
1605
1606 Type operandPointeeType = operandType.getPointeeType();
1607 Type resultPointeeType = resultType.getPointeeType();
1608 if (operandPointeeType != resultPointeeType)
1609 return emitOpError("pointer operand's pointee type must have the same "
1610 "as the op result type, but found ")
1611 << operandPointeeType << " vs " << resultPointeeType;
1612 return success();
1613}
1614
1615//===----------------------------------------------------------------------===//
1616// spirv.BranchOp
1617//===----------------------------------------------------------------------===//
1618
1619SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1620 assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index"
) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1620, __extension__
__PRETTY_FUNCTION__))
;
1621 return SuccessorOperands(0, getTargetOperandsMutable());
1622}
1623
1624//===----------------------------------------------------------------------===//
1625// spirv.BranchConditionalOp
1626//===----------------------------------------------------------------------===//
1627
1628SuccessorOperands
1629spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1630 assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index"
) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1630, __extension__
__PRETTY_FUNCTION__))
;
1631 return SuccessorOperands(index == kTrueIndex
1632 ? getTrueTargetOperandsMutable()
1633 : getFalseTargetOperandsMutable());
1634}
1635
1636ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1637 OperationState &result) {
1638 auto &builder = parser.getBuilder();
1639 OpAsmParser::UnresolvedOperand condInfo;
1640 Block *dest;
1641
1642 // Parse the condition.
1643 Type boolTy = builder.getI1Type();
1644 if (parser.parseOperand(condInfo) ||
1645 parser.resolveOperand(condInfo, boolTy, result.operands))
1646 return failure();
1647
1648 // Parse the optional branch weights.
1649 if (succeeded(parser.parseOptionalLSquare())) {
1650 IntegerAttr trueWeight, falseWeight;
1651 NamedAttrList weights;
1652
1653 auto i32Type = builder.getIntegerType(32);
1654 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1655 parser.parseComma() ||
1656 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1657 parser.parseRSquare())
1658 return failure();
1659
1660 result.addAttribute(kBranchWeightAttrName,
1661 builder.getArrayAttr({trueWeight, falseWeight}));
1662 }
1663
1664 // Parse the true branch.
1665 SmallVector<Value, 4> trueOperands;
1666 if (parser.parseComma() ||
1667 parser.parseSuccessorAndUseList(dest, trueOperands))
1668 return failure();
1669 result.addSuccessors(dest);
1670 result.addOperands(trueOperands);
1671
1672 // Parse the false branch.
1673 SmallVector<Value, 4> falseOperands;
1674 if (parser.parseComma() ||
1675 parser.parseSuccessorAndUseList(dest, falseOperands))
1676 return failure();
1677 result.addSuccessors(dest);
1678 result.addOperands(falseOperands);
1679 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1680 builder.getDenseI32ArrayAttr(
1681 {1, static_cast<int32_t>(trueOperands.size()),
1682 static_cast<int32_t>(falseOperands.size())}));
1683
1684 return success();
1685}
1686
1687void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1688 printer << ' ' << getCondition();
1689
1690 if (auto weights = getBranchWeights()) {
1691 printer << " [";
1692 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1693 printer << a.cast<IntegerAttr>().getInt();
1694 });
1695 printer << "]";
1696 }
1697
1698 printer << ", ";
1699 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1700 printer << ", ";
1701 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1702}
1703
1704LogicalResult spirv::BranchConditionalOp::verify() {
1705 if (auto weights = getBranchWeights()) {
1706 if (weights->getValue().size() != 2) {
1707 return emitOpError("must have exactly two branch weights");
1708 }
1709 if (llvm::all_of(*weights, [](Attribute attr) {
1710 return attr.cast<IntegerAttr>().getValue().isNullValue();
1711 }))
1712 return emitOpError("branch weights cannot both be zero");
1713 }
1714
1715 return success();
1716}
1717
1718//===----------------------------------------------------------------------===//
1719// spirv.CompositeConstruct
1720//===----------------------------------------------------------------------===//
1721
1722LogicalResult spirv::CompositeConstructOp::verify() {
1723 auto cType = getType().cast<spirv::CompositeType>();
1724 operand_range constituents = this->getConstituents();
1725
1726 if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1727 if (constituents.size() != 1)
1728 return emitOpError("has incorrect number of operands: expected ")
1729 << "1, but provided " << constituents.size();
1730 if (coopType.getElementType() != constituents.front().getType())
1731 return emitOpError("operand type mismatch: expected operand type ")
1732 << coopType.getElementType() << ", but provided "
1733 << constituents.front().getType();
1734 return success();
1735 }
1736
1737 if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1738 if (constituents.size() != 1)
1739 return emitOpError("has incorrect number of operands: expected ")
1740 << "1, but provided " << constituents.size();
1741 if (jointType.getElementType() != constituents.front().getType())
1742 return emitOpError("operand type mismatch: expected operand type ")
1743 << jointType.getElementType() << ", but provided "
1744 << constituents.front().getType();
1745 return success();
1746 }
1747
1748 if (constituents.size() == cType.getNumElements()) {
1749 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1750 if (constituents[index].getType() != cType.getElementType(index)) {
1751 return emitOpError("operand type mismatch: expected operand type ")
1752 << cType.getElementType(index) << ", but provided "
1753 << constituents[index].getType();
1754 }
1755 }
1756 return success();
1757 }
1758
1759 // If not constructing a cooperative matrix type, then we must be constructing
1760 // a vector type.
1761 auto resultType = cType.dyn_cast<VectorType>();
1762 if (!resultType)
1763 return emitOpError(
1764 "expected to return a vector or cooperative matrix when the number of "
1765 "constituents is less than what the result needs");
1766
1767 SmallVector<unsigned> sizes;
1768 for (Value component : constituents) {
1769 if (!component.getType().isa<VectorType>() &&
1770 !component.getType().isIntOrFloat())
1771 return emitOpError("operand type mismatch: expected operand to have "
1772 "a scalar or vector type, but provided ")
1773 << component.getType();
1774
1775 Type elementType = component.getType();
1776 if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1777 sizes.push_back(vectorType.getNumElements());
1778 elementType = vectorType.getElementType();
1779 } else {
1780 sizes.push_back(1);
1781 }
1782
1783 if (elementType != resultType.getElementType())
1784 return emitOpError("operand element type mismatch: expected to be ")
1785 << resultType.getElementType() << ", but provided " << elementType;
1786 }
1787 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1788 if (totalCount != cType.getNumElements())
1789 return emitOpError("has incorrect number of operands: expected ")
1790 << cType.getNumElements() << ", but provided " << totalCount;
1791 return success();
1792}
1793
1794//===----------------------------------------------------------------------===//
1795// spirv.CompositeExtractOp
1796//===----------------------------------------------------------------------===//
1797
1798void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1799 Value composite,
1800 ArrayRef<int32_t> indices) {
1801 auto indexAttr = builder.getI32ArrayAttr(indices);
1802 auto elementType =
1803 getElementType(composite.getType(), indexAttr, state.location);
1804 if (!elementType) {
1805 return;
1806 }
1807 build(builder, state, elementType, composite, indexAttr);
1808}
1809
1810ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1811 OperationState &result) {
1812 OpAsmParser::UnresolvedOperand compositeInfo;
1813 Attribute indicesAttr;
1814 Type compositeType;
1815 SMLoc attrLocation;
1816
1817 if (parser.parseOperand(compositeInfo) ||
1818 parser.getCurrentLocation(&attrLocation) ||
1819 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1820 parser.parseColonType(compositeType) ||
1821 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1822 return failure();
1823 }
1824
1825 Type resultType =
1826 getElementType(compositeType, indicesAttr, parser, attrLocation);
1827 if (!resultType) {
1828 return failure();
1829 }
1830 result.addTypes(resultType);
1831 return success();
1832}
1833
1834void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1835 printer << ' ' << getComposite() << getIndices() << " : "
1836 << getComposite().getType();
1837}
1838
1839LogicalResult spirv::CompositeExtractOp::verify() {
1840 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1841 auto resultType =
1842 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1843 if (!resultType)
1844 return failure();
1845
1846 if (resultType != getType()) {
1847 return emitOpError("invalid result type: expected ")
1848 << resultType << " but provided " << getType();
1849 }
1850
1851 return success();
1852}
1853
1854//===----------------------------------------------------------------------===//
1855// spirv.CompositeInsert
1856//===----------------------------------------------------------------------===//
1857
1858void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1859 Value object, Value composite,
1860 ArrayRef<int32_t> indices) {
1861 auto indexAttr = builder.getI32ArrayAttr(indices);
1862 build(builder, state, composite.getType(), object, composite, indexAttr);
1863}
1864
1865ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1866 OperationState &result) {
1867 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1868 Type objectType, compositeType;
1869 Attribute indicesAttr;
1870 auto loc = parser.getCurrentLocation();
1871
1872 return failure(
1873 parser.parseOperandList(operands, 2) ||
1874 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1875 parser.parseColonType(objectType) ||
1876 parser.parseKeywordType("into", compositeType) ||
1877 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1878 result.operands) ||
1879 parser.addTypesToList(compositeType, result.types));
1880}
1881
1882LogicalResult spirv::CompositeInsertOp::verify() {
1883 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1884 auto objectType =
1885 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1886 if (!objectType)
1887 return failure();
1888
1889 if (objectType != getObject().getType()) {
1890 return emitOpError("object operand type should be ")
1891 << objectType << ", but found " << getObject().getType();
1892 }
1893
1894 if (getComposite().getType() != getType()) {
1895 return emitOpError("result type should be the same as "
1896 "the composite type, but found ")
1897 << getComposite().getType() << " vs " << getType();
1898 }
1899
1900 return success();
1901}
1902
1903void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1904 printer << " " << getObject() << ", " << getComposite() << getIndices()
1905 << " : " << getObject().getType() << " into "
1906 << getComposite().getType();
1907}
1908
1909//===----------------------------------------------------------------------===//
1910// spirv.Constant
1911//===----------------------------------------------------------------------===//
1912
1913ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1914 OperationState &result) {
1915 Attribute value;
1916 if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1917 return failure();
1918
1919 Type type = NoneType::get(parser.getContext());
1920 if (auto typedAttr = value.dyn_cast<TypedAttr>())
1921 type = typedAttr.getType();
1922 if (type.isa<NoneType, TensorType>()) {
1923 if (parser.parseColonType(type))
1924 return failure();
1925 }
1926
1927 return parser.addTypeToList(type, result.types);
1928}
1929
1930void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1931 printer << ' ' << getValue();
1932 if (getType().isa<spirv::ArrayType>())
1933 printer << " : " << getType();
1934}
1935
1936static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1937 Type opType) {
1938 if (value.isa<IntegerAttr, FloatAttr>()) {
1939 auto valueType = value.cast<TypedAttr>().getType();
1940 if (valueType != opType)
1941 return op.emitOpError("result type (")
1942 << opType << ") does not match value type (" << valueType << ")";
1943 return success();
1944 }
1945 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1946 auto valueType = value.cast<TypedAttr>().getType();
1947 if (valueType == opType)
1948 return success();
1949 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1950 auto shapedType = valueType.dyn_cast<ShapedType>();
1951 if (!arrayType)
1952 return op.emitOpError("result or element type (")
1953 << opType << ") does not match value type (" << valueType
1954 << "), must be the same or spirv.array";
1955
1956 int numElements = arrayType.getNumElements();
1957 auto opElemType = arrayType.getElementType();
1958 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1959 numElements *= t.getNumElements();
1960 opElemType = t.getElementType();
1961 }
1962 if (!opElemType.isIntOrFloat())
1963 return op.emitOpError("only support nested array result type");
1964
1965 auto valueElemType = shapedType.getElementType();
1966 if (valueElemType != opElemType) {
1967 return op.emitOpError("result element type (")
1968 << opElemType << ") does not match value element type ("
1969 << valueElemType << ")";
1970 }
1971
1972 if (numElements != shapedType.getNumElements()) {
1973 return op.emitOpError("result number of elements (")
1974 << numElements << ") does not match value number of elements ("
1975 << shapedType.getNumElements() << ")";
1976 }
1977 return success();
1978 }
1979 if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
1980 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1981 if (!arrayType)
1982 return op.emitOpError(
1983 "must have spirv.array result type for array value");
1984 Type elemType = arrayType.getElementType();
1985 for (Attribute element : arrayAttr.getValue()) {
1986 // Verify array elements recursively.
1987 if (failed(verifyConstantType(op, element, elemType)))
1988 return failure();
1989 }
1990 return success();
1991 }
1992 return op.emitOpError("cannot have attribute: ") << value;
1993}
1994
1995LogicalResult spirv::ConstantOp::verify() {
1996 // ODS already generates checks to make sure the result type is valid. We just
1997 // need to additionally check that the value's attribute type is consistent
1998 // with the result type.
1999 return verifyConstantType(*this, getValueAttr(), getType());
2000}
2001
2002bool spirv::ConstantOp::isBuildableWith(Type type) {
2003 // Must be valid SPIR-V type first.
2004 if (!type.isa<spirv::SPIRVType>())
2005 return false;
2006
2007 if (isa<SPIRVDialect>(type.getDialect())) {
2008 // TODO: support constant struct
2009 return type.isa<spirv::ArrayType>();
2010 }
2011
2012 return true;
2013}
2014
2015spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
2016 OpBuilder &builder) {
2017 if (auto intType = type.dyn_cast<IntegerType>()) {
2018 unsigned width = intType.getWidth();
2019 if (width == 1)
2020 return builder.create<spirv::ConstantOp>(loc, type,
2021 builder.getBoolAttr(false));
2022 return builder.create<spirv::ConstantOp>(
2023 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
2024 }
2025 if (auto floatType = type.dyn_cast<FloatType>()) {
2026 return builder.create<spirv::ConstantOp>(
2027 loc, type, builder.getFloatAttr(floatType, 0.0));
2028 }
2029 if (auto vectorType = type.dyn_cast<VectorType>()) {
2030 Type elemType = vectorType.getElementType();
2031 if (elemType.isa<IntegerType>()) {
2032 return builder.create<spirv::ConstantOp>(
2033 loc, type,
2034 DenseElementsAttr::get(vectorType,
2035 IntegerAttr::get(elemType, 0.0).getValue()));
2036 }
2037 if (elemType.isa<FloatType>()) {
2038 return builder.create<spirv::ConstantOp>(
2039 loc, type,
2040 DenseFPElementsAttr::get(vectorType,
2041 FloatAttr::get(elemType, 0.0).getValue()));
2042 }
2043 }
2044
2045 llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2045)
;
2046}
2047
2048spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
2049 OpBuilder &builder) {
2050 if (auto intType = type.dyn_cast<IntegerType>()) {
2051 unsigned width = intType.getWidth();
2052 if (width == 1)
2053 return builder.create<spirv::ConstantOp>(loc, type,
2054 builder.getBoolAttr(true));
2055 return builder.create<spirv::ConstantOp>(
2056 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
2057 }
2058 if (auto floatType = type.dyn_cast<FloatType>()) {
2059 return builder.create<spirv::ConstantOp>(
2060 loc, type, builder.getFloatAttr(floatType, 1.0));
2061 }
2062 if (auto vectorType = type.dyn_cast<VectorType>()) {
2063 Type elemType = vectorType.getElementType();
2064 if (elemType.isa<IntegerType>()) {
2065 return builder.create<spirv::ConstantOp>(
2066 loc, type,
2067 DenseElementsAttr::get(vectorType,
2068 IntegerAttr::get(elemType, 1.0).getValue()));
2069 }
2070 if (elemType.isa<FloatType>()) {
2071 return builder.create<spirv::ConstantOp>(
2072 loc, type,
2073 DenseFPElementsAttr::get(vectorType,
2074 FloatAttr::get(elemType, 1.0).getValue()));
2075 }
2076 }
2077
2078 llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2078)
;
2079}
2080
2081void mlir::spirv::ConstantOp::getAsmResultNames(
2082 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2083 Type type = getType();
2084
2085 SmallString<32> specialNameBuffer;
2086 llvm::raw_svector_ostream specialName(specialNameBuffer);
2087 specialName << "cst";
2088
2089 IntegerType intTy = type.dyn_cast<IntegerType>();
2090
2091 if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2092 if (intTy && intTy.getWidth() == 1) {
2093 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2094 }
2095
2096 if (intTy.isSignless()) {
2097 specialName << intCst.getInt();
2098 } else {
2099 specialName << intCst.getSInt();
2100 }
2101 }
2102
2103 if (intTy || type.isa<FloatType>()) {
2104 specialName << '_' << type;
2105 }
2106
2107 if (auto vecType = type.dyn_cast<VectorType>()) {
2108 specialName << "_vec_";
2109 specialName << vecType.getDimSize(0);
2110
2111 Type elementType = vecType.getElementType();
2112
2113 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2114 specialName << "x" << elementType;
2115 }
2116 }
2117
2118 setNameFn(getResult(), specialName.str());
2119}
2120
2121void mlir::spirv::AddressOfOp::getAsmResultNames(
2122 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2123 SmallString<32> specialNameBuffer;
2124 llvm::raw_svector_ostream specialName(specialNameBuffer);
2125 specialName << getVariable() << "_addr";
2126 setNameFn(getResult(), specialName.str());
2127}
2128
2129//===----------------------------------------------------------------------===//
2130// spirv.ControlBarrierOp
2131//===----------------------------------------------------------------------===//
2132
2133LogicalResult spirv::ControlBarrierOp::verify() {
2134 return verifyMemorySemantics(getOperation(), getMemorySemantics());
2135}
2136
2137//===----------------------------------------------------------------------===//
2138// spirv.ConvertFToSOp
2139//===----------------------------------------------------------------------===//
2140
2141LogicalResult spirv::ConvertFToSOp::verify() {
2142 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2143 /*skipBitWidthCheck=*/true);
2144}
2145
2146//===----------------------------------------------------------------------===//
2147// spirv.ConvertFToUOp
2148//===----------------------------------------------------------------------===//
2149
2150LogicalResult spirv::ConvertFToUOp::verify() {
2151 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2152 /*skipBitWidthCheck=*/true);
2153}
2154
2155//===----------------------------------------------------------------------===//
2156// spirv.ConvertSToFOp
2157//===----------------------------------------------------------------------===//
2158
2159LogicalResult spirv::ConvertSToFOp::verify() {
2160 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2161 /*skipBitWidthCheck=*/true);
2162}
2163
2164//===----------------------------------------------------------------------===//
2165// spirv.ConvertUToFOp
2166//===----------------------------------------------------------------------===//
2167
2168LogicalResult spirv::ConvertUToFOp::verify() {
2169 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2170 /*skipBitWidthCheck=*/true);
2171}
2172
2173//===----------------------------------------------------------------------===//
2174// spirv.EntryPoint
2175//===----------------------------------------------------------------------===//
2176
2177void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2178 spirv::ExecutionModel executionModel,
2179 spirv::FuncOp function,
2180 ArrayRef<Attribute> interfaceVars) {
2181 build(builder, state,
2182 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2183 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2184}
2185
2186ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2187 OperationState &result) {
2188 spirv::ExecutionModel execModel;
2189 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2190 SmallVector<Type, 0> idTypes;
2191 SmallVector<Attribute, 4> interfaceVars;
2192
2193 FlatSymbolRefAttr fn;
2194 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2195 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2196 return failure();
2197 }
2198
2199 if (!parser.parseOptionalComma()) {
2200 // Parse the interface variables
2201 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2202 // The name of the interface variable attribute isnt important
2203 FlatSymbolRefAttr var;
2204 NamedAttrList attrs;
2205 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2206 return failure();
2207 interfaceVars.push_back(var);
2208 return success();
2209 }))
2210 return failure();
2211 }
2212 result.addAttribute(kInterfaceAttrName,
2213 parser.getBuilder().getArrayAttr(interfaceVars));
2214 return success();
2215}
2216
2217void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2218 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
2219 printer.printSymbolName(getFn());
2220 auto interfaceVars = getInterface().getValue();
2221 if (!interfaceVars.empty()) {
2222 printer << ", ";
2223 llvm::interleaveComma(interfaceVars, printer);
2224 }
2225}
2226
2227LogicalResult spirv::EntryPointOp::verify() {
2228 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2229 // verification.
2230 return success();
2231}
2232
2233//===----------------------------------------------------------------------===//
2234// spirv.ExecutionMode
2235//===----------------------------------------------------------------------===//
2236
2237void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2238 spirv::FuncOp function,
2239 spirv::ExecutionMode executionMode,
2240 ArrayRef<int32_t> params) {
2241 build(builder, state, SymbolRefAttr::get(function),
2242 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2243 builder.getI32ArrayAttr(params));
2244}
2245
2246ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2247 OperationState &result) {
2248 spirv::ExecutionMode execMode;
2249 Attribute fn;
2250 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2251 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2252 return failure();
2253 }
2254
2255 SmallVector<int32_t, 4> values;
2256 Type i32Type = parser.getBuilder().getIntegerType(32);
2257 while (!parser.parseOptionalComma()) {
2258 NamedAttrList attr;
2259 Attribute value;
2260 if (parser.parseAttribute(value, i32Type, "value", attr)) {
2261 return failure();
2262 }
2263 values.push_back(value.cast<IntegerAttr>().getInt());
2264 }
2265 result.addAttribute(kValuesAttrName,
2266 parser.getBuilder().getI32ArrayAttr(values));
2267 return success();
2268}
2269
2270void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2271 printer << " ";
2272 printer.printSymbolName(getFn());
2273 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
2274 auto values = this->getValues();
2275 if (values.empty())
2276 return;
2277 printer << ", ";
2278 llvm::interleaveComma(values, printer, [&](Attribute a) {
2279 printer << a.cast<IntegerAttr>().getInt();
2280 });
2281}
2282
2283//===----------------------------------------------------------------------===//
2284// spirv.FConvertOp
2285//===----------------------------------------------------------------------===//
2286
2287LogicalResult spirv::FConvertOp::verify() {
2288 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2289}
2290
2291//===----------------------------------------------------------------------===//
2292// spirv.SConvertOp
2293//===----------------------------------------------------------------------===//
2294
2295LogicalResult spirv::SConvertOp::verify() {
2296 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2297}
2298
2299//===----------------------------------------------------------------------===//
2300// spirv.UConvertOp
2301//===----------------------------------------------------------------------===//
2302
2303LogicalResult spirv::UConvertOp::verify() {
2304 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2305}
2306
2307//===----------------------------------------------------------------------===//
2308// spirv.func
2309//===----------------------------------------------------------------------===//
2310
2311ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2312 SmallVector<OpAsmParser::Argument> entryArgs;
2313 SmallVector<DictionaryAttr> resultAttrs;
2314 SmallVector<Type> resultTypes;
2315 auto &builder = parser.getBuilder();
2316
2317 // Parse the name as a symbol.
2318 StringAttr nameAttr;
2319 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2320 result.attributes))
2321 return failure();
2322
2323 // Parse the function signature.
2324 bool isVariadic = false;
2325 if (function_interface_impl::parseFunctionSignature(
2326 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2327 resultAttrs))
2328 return failure();
2329
2330 SmallVector<Type> argTypes;
2331 for (auto &arg : entryArgs)
2332 argTypes.push_back(arg.type);
2333 auto fnType = builder.getFunctionType(argTypes, resultTypes);
2334 result.addAttribute(FunctionOpInterface::getTypeAttrName(),
2335 TypeAttr::get(fnType));
2336
2337 // Parse the optional function control keyword.
2338 spirv::FunctionControl fnControl;
2339 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2340 return failure();
2341
2342 // If additional attributes are present, parse them.
2343 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2344 return failure();
2345
2346 // Add the attributes to the function arguments.
2347 assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes.
size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2347, __extension__
__PRETTY_FUNCTION__))
;
2348 function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
2349 resultAttrs);
2350
2351 // Parse the optional function body.
2352 auto *body = result.addRegion();
2353 OptionalParseResult parseResult =
2354 parser.parseOptionalRegion(*body, entryArgs);
2355 return failure(parseResult.has_value() && failed(*parseResult));
2356}
2357
2358void spirv::FuncOp::print(OpAsmPrinter &printer) {
2359 // Print function name, signature, and control.
2360 printer << " ";
2361 printer.printSymbolName(getSymName());
2362 auto fnType = getFunctionType();
2363 function_interface_impl::printFunctionSignature(
2364 printer, *this, fnType.getInputs(),
2365 /*isVariadic=*/false, fnType.getResults());
2366 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
2367 << "\"";
2368 function_interface_impl::printFunctionAttributes(
2369 printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2370 {spirv::attributeName<spirv::FunctionControl>()});
2371
2372 // Print the body if this is not an external function.
2373 Region &body = this->getBody();
2374 if (!body.empty()) {
2375 printer << ' ';
2376 printer.printRegion(body, /*printEntryBlockArgs=*/false,
2377 /*printBlockTerminators=*/true);
2378 }
2379}
2380
2381LogicalResult spirv::FuncOp::verifyType() {
2382 auto type = getFunctionTypeAttr().getValue();
2383 if (!type.isa<FunctionType>())
2384 return emitOpError("requires '" + getTypeAttrName() +
2385 "' attribute of function type");
2386 if (getFunctionType().getNumResults() > 1)
2387 return emitOpError("cannot have more than one result");
2388 return success();
2389}
2390
2391LogicalResult spirv::FuncOp::verifyBody() {
2392 FunctionType fnType = getFunctionType();
2393
2394 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2395 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2396 if (fnType.getNumResults() != 0)
2397 return retOp.emitOpError("cannot be used in functions returning value");
2398 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2399 if (fnType.getNumResults() != 1)
2400 return retOp.emitOpError(
2401 "returns 1 value but enclosing function requires ")
2402 << fnType.getNumResults() << " results";
2403
2404 auto retOperandType = retOp.getValue().getType();
2405 auto fnResultType = fnType.getResult(0);
2406 if (retOperandType != fnResultType)
2407 return retOp.emitOpError(" return value's type (")
2408 << retOperandType << ") mismatch with function's result type ("
2409 << fnResultType << ")";
2410 }
2411 return WalkResult::advance();
2412 });
2413
2414 // TODO: verify other bits like linkage type.
2415
2416 return failure(walkResult.wasInterrupted());
2417}
2418
2419void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2420 StringRef name, FunctionType type,
2421 spirv::FunctionControl control,
2422 ArrayRef<NamedAttribute> attrs) {
2423 state.addAttribute(SymbolTable::getSymbolAttrName(),
2424 builder.getStringAttr(name));
2425 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2426 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2427 builder.getAttr<spirv::FunctionControlAttr>(control));
2428 state.attributes.append(attrs.begin(), attrs.end());
2429 state.addRegion();
2430}
2431
2432// CallableOpInterface
2433Region *spirv::FuncOp::getCallableRegion() {
2434 return isExternal() ? nullptr : &getBody();
2435}
2436
2437// CallableOpInterface
2438ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2439 return getFunctionType().getResults();
2440}
2441
2442//===----------------------------------------------------------------------===//
2443// spirv.FunctionCall
2444//===----------------------------------------------------------------------===//
2445
2446LogicalResult spirv::FunctionCallOp::verify() {
2447 auto fnName = getCalleeAttr();
2448
2449 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2450 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2451 if (!funcOp) {
2452 return emitOpError("callee function '")
2453 << fnName.getValue() << "' not found in nearest symbol table";
2454 }
2455
2456 auto functionType = funcOp.getFunctionType();
2457
2458 if (getNumResults() > 1) {
2459 return emitOpError(
2460 "expected callee function to have 0 or 1 result, but provided ")
2461 << getNumResults();
2462 }
2463
2464 if (functionType.getNumInputs() != getNumOperands()) {
2465 return emitOpError("has incorrect number of operands for callee: expected ")
2466 << functionType.getNumInputs() << ", but provided "
2467 << getNumOperands();
2468 }
2469
2470 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2471 if (getOperand(i).getType() != functionType.getInput(i)) {
2472 return emitOpError("operand type mismatch: expected operand type ")
2473 << functionType.getInput(i) << ", but provided "
2474 << getOperand(i).getType() << " for operand number " << i;
2475 }
2476 }
2477
2478 if (functionType.getNumResults() != getNumResults()) {
2479 return emitOpError(
2480 "has incorrect number of results has for callee: expected ")
2481 << functionType.getNumResults() << ", but provided "
2482 << getNumResults();
2483 }
2484
2485 if (getNumResults() &&
2486 (getResult(0).getType() != functionType.getResult(0))) {
2487 return emitOpError("result type mismatch: expected ")
2488 << functionType.getResult(0) << ", but provided "
2489 << getResult(0).getType();
2490 }
2491
2492 return success();
2493}
2494
2495CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2496 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2497}
2498
2499Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2500 return getArguments();
2501}
2502
2503//===----------------------------------------------------------------------===//
2504// spirv.GLFClampOp
2505//===----------------------------------------------------------------------===//
2506
2507ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2508 OperationState &result) {
2509 return parseOneResultSameOperandTypeOp(parser, result);
2510}
2511void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2512
2513//===----------------------------------------------------------------------===//
2514// spirv.GLUClampOp
2515//===----------------------------------------------------------------------===//
2516
2517ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2518 OperationState &result) {
2519 return parseOneResultSameOperandTypeOp(parser, result);
2520}
2521void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2522
2523//===----------------------------------------------------------------------===//
2524// spirv.GLSClampOp
2525//===----------------------------------------------------------------------===//
2526
2527ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2528 OperationState &result) {
2529 return parseOneResultSameOperandTypeOp(parser, result);
2530}
2531void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2532
2533//===----------------------------------------------------------------------===//
2534// spirv.GLFmaOp
2535//===----------------------------------------------------------------------===//
2536
2537ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2538 return parseOneResultSameOperandTypeOp(parser, result);
2539}
2540void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2541
2542//===----------------------------------------------------------------------===//
2543// spirv.GlobalVariable
2544//===----------------------------------------------------------------------===//
2545
2546void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2547 Type type, StringRef name,
2548 unsigned descriptorSet, unsigned binding) {
2549 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2550 state.addAttribute(
2551 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2552 builder.getI32IntegerAttr(descriptorSet));
2553 state.addAttribute(
2554 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2555 builder.getI32IntegerAttr(binding));
2556}
2557
2558void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2559 Type type, StringRef name,
2560 spirv::BuiltIn builtin) {
2561 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2562 state.addAttribute(
2563 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2564 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2565}
2566
2567ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2568 OperationState &result) {
2569 // Parse variable name.
2570 StringAttr nameAttr;
2571 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2572 result.attributes)) {
2573 return failure();
2574 }
2575
2576 // Parse optional initializer
2577 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2578 FlatSymbolRefAttr initSymbol;
2579 if (parser.parseLParen() ||
2580 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2581 result.attributes) ||
2582 parser.parseRParen())
2583 return failure();
2584 }
2585
2586 if (parseVariableDecorations(parser, result)) {
2587 return failure();
2588 }
2589
2590 Type type;
2591 auto loc = parser.getCurrentLocation();
2592 if (parser.parseColonType(type)) {
2593 return failure();
2594 }
2595 if (!type.isa<spirv::PointerType>()) {
2596 return parser.emitError(loc, "expected spirv.ptr type");
2597 }
2598 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2599
2600 return success();
2601}
2602
2603void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2604 SmallVector<StringRef, 4> elidedAttrs{
2605 spirv::attributeName<spirv::StorageClass>()};
2606
2607 // Print variable name.
2608 printer << ' ';
2609 printer.printSymbolName(getSymName());
2610 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2611
2612 // Print optional initializer
2613 if (auto initializer = this->getInitializer()) {
2614 printer << " " << kInitializerAttrName << '(';
2615 printer.printSymbolName(*initializer);
2616 printer << ')';
2617 elidedAttrs.push_back(kInitializerAttrName);
2618 }
2619
2620 elidedAttrs.push_back(kTypeAttrName);
2621 printVariableDecorations(*this, printer, elidedAttrs);
2622 printer << " : " << getType();
2623}
2624
2625LogicalResult spirv::GlobalVariableOp::verify() {
2626 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2627 // object. It cannot be Generic. It must be the same as the Storage Class
2628 // operand of the Result Type."
2629 // Also, Function storage class is reserved by spirv.Variable.
2630 auto storageClass = this->storageClass();
2631 if (storageClass == spirv::StorageClass::Generic ||
2632 storageClass == spirv::StorageClass::Function) {
2633 return emitOpError("storage class cannot be '")
2634 << stringifyStorageClass(storageClass) << "'";
2635 }
2636
2637 if (auto init =
2638 (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2639 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2640 (*this)->getParentOp(), init.getAttr());
2641 // TODO: Currently only variable initialization with specialization
2642 // constants and other variables is supported. They could be normal
2643 // constants in the module scope as well.
2644 if (!initOp ||
2645 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2646 return emitOpError("initializer must be result of a "
2647 "spirv.SpecConstant or spirv.GlobalVariable op");
2648 }
2649 }
2650
2651 return success();
2652}
2653
2654//===----------------------------------------------------------------------===//
2655// spirv.GroupBroadcast
2656//===----------------------------------------------------------------------===//
2657
2658LogicalResult spirv::GroupBroadcastOp::verify() {
2659 spirv::Scope scope = getExecutionScope();
2660 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2661 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2662
2663 if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2664 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2665 return emitOpError("localid is a vector and can be with only "
2666 " 2 or 3 components, actual number is ")
2667 << localIdTy.getNumElements();
2668
2669 return success();
2670}
2671
2672//===----------------------------------------------------------------------===//
2673// spirv.GroupNonUniformBallotOp
2674//===----------------------------------------------------------------------===//
2675
2676LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2677 spirv::Scope scope = getExecutionScope();
2678 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2679 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2680
2681 return success();
2682}
2683
2684//===----------------------------------------------------------------------===//
2685// spirv.GroupNonUniformBroadcast
2686//===----------------------------------------------------------------------===//
2687
2688LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2689 spirv::Scope scope = getExecutionScope();
2690 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2691 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2692
2693 // SPIR-V spec: "Before version 1.5, Id must come from a
2694 // constant instruction.
2695 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2696 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2697 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2698
2699 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2700 auto *idOp = getId().getDefiningOp();
2701 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2702 spirv::ReferenceOfOp>(idOp)) // for spec constant
2703 return emitOpError("id must be the result of a constant op");
2704 }
2705
2706 return success();
2707}
2708
2709//===----------------------------------------------------------------------===//
2710// spirv.GroupNonUniformShuffle*
2711//===----------------------------------------------------------------------===//
2712
2713template <typename OpTy>
2714static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
2715 spirv::Scope scope = op.getExecutionScope();
2716 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2717 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2718
2719 if (op.getOperands().back().getType().isSignedInteger())
2720 return op.emitOpError("second operand must be a singless/unsigned integer");
2721
2722 return success();
2723}
2724
2725LogicalResult spirv::GroupNonUniformShuffleOp::verify() {
2726 return verifyGroupNonUniformShuffleOp(*this);
2727}
2728LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() {
2729 return verifyGroupNonUniformShuffleOp(*this);
2730}
2731LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() {
2732 return verifyGroupNonUniformShuffleOp(*this);
2733}
2734LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
2735 return verifyGroupNonUniformShuffleOp(*this);
2736}
2737
2738//===----------------------------------------------------------------------===//
2739// spirv.INTEL.SubgroupBlockRead
2740//===----------------------------------------------------------------------===//
2741
2742ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
2743 OperationState &result) {
2744 // Parse the storage class specification
2745 spirv::StorageClass storageClass;
2746 OpAsmParser::UnresolvedOperand ptrInfo;
2747 Type elementType;
2748 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2749 parser.parseColon() || parser.parseType(elementType)) {
2750 return failure();
2751 }
2752
2753 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2754 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2755 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2756
2757 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2758 return failure();
2759 }
2760
2761 result.addTypes(elementType);
2762 return success();
2763}
2764
2765void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
2766 printer << " " << getPtr() << " : " << getType();
2767}
2768
2769LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
2770 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2771 return failure();
2772
2773 return success();
2774}
2775
2776//===----------------------------------------------------------------------===//
2777// spirv.INTEL.SubgroupBlockWrite
2778//===----------------------------------------------------------------------===//
2779
2780ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
2781 OperationState &result) {
2782 // Parse the storage class specification
2783 spirv::StorageClass storageClass;
2784 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2785 auto loc = parser.getCurrentLocation();
2786 Type elementType;
2787 if (parseEnumStrAttr(storageClass, parser) ||
2788 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2789 parser.parseType(elementType)) {
2790 return failure();
2791 }
2792
2793 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2794 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2795 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2796
2797 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2798 result.operands)) {
2799 return failure();
2800 }
2801 return success();
2802}
2803
2804void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
2805 printer << " " << getPtr() << ", " << getValue() << " : "
2806 << getValue().getType();
2807}
2808
2809LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
2810 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2811 return failure();
2812
2813 return success();
2814}
2815
2816//===----------------------------------------------------------------------===//
2817// spirv.GroupNonUniformElectOp
2818//===----------------------------------------------------------------------===//
2819
2820LogicalResult spirv::GroupNonUniformElectOp::verify() {
2821 spirv::Scope scope = getExecutionScope();
2822 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2823 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2824
2825 return success();
2826}
2827
2828//===----------------------------------------------------------------------===//
2829// spirv.GroupNonUniformFAddOp
2830//===----------------------------------------------------------------------===//
2831
2832LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2833 return verifyGroupNonUniformArithmeticOp(*this);
2834}
2835
2836ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2837 OperationState &result) {
2838 return parseGroupNonUniformArithmeticOp(parser, result);
2839}
2840void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2841 printGroupNonUniformArithmeticOp(*this, p);
2842}
2843
2844//===----------------------------------------------------------------------===//
2845// spirv.GroupNonUniformFMaxOp
2846//===----------------------------------------------------------------------===//
2847
2848LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2849 return verifyGroupNonUniformArithmeticOp(*this);
2850}
2851
2852ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2853 OperationState &result) {
2854 return parseGroupNonUniformArithmeticOp(parser, result);
2855}
2856void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2857 printGroupNonUniformArithmeticOp(*this, p);
2858}
2859
2860//===----------------------------------------------------------------------===//
2861// spirv.GroupNonUniformFMinOp
2862//===----------------------------------------------------------------------===//
2863
2864LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2865 return verifyGroupNonUniformArithmeticOp(*this);
2866}
2867
2868ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2869 OperationState &result) {
2870 return parseGroupNonUniformArithmeticOp(parser, result);
2871}
2872void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2873 printGroupNonUniformArithmeticOp(*this, p);
2874}
2875
2876//===----------------------------------------------------------------------===//
2877// spirv.GroupNonUniformFMulOp
2878//===----------------------------------------------------------------------===//
2879
2880LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2881 return verifyGroupNonUniformArithmeticOp(*this);
2882}
2883
2884ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2885 OperationState &result) {
2886 return parseGroupNonUniformArithmeticOp(parser, result);
2887}
2888void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2889 printGroupNonUniformArithmeticOp(*this, p);
2890}
2891
2892//===----------------------------------------------------------------------===//
2893// spirv.GroupNonUniformIAddOp
2894//===----------------------------------------------------------------------===//
2895
2896LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2897 return verifyGroupNonUniformArithmeticOp(*this);
2898}
2899
2900ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2901 OperationState &result) {
2902 return parseGroupNonUniformArithmeticOp(parser, result);
2903}
2904void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2905 printGroupNonUniformArithmeticOp(*this, p);
2906}
2907
2908//===----------------------------------------------------------------------===//
2909// spirv.GroupNonUniformIMulOp
2910//===----------------------------------------------------------------------===//
2911
2912LogicalResult spirv::GroupNonUniformIMulOp::verify() {
2913 return verifyGroupNonUniformArithmeticOp(*this);
2914}
2915
2916ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2917 OperationState &result) {
2918 return parseGroupNonUniformArithmeticOp(parser, result);
2919}
2920void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
2921 printGroupNonUniformArithmeticOp(*this, p);
2922}
2923
2924//===----------------------------------------------------------------------===//
2925// spirv.GroupNonUniformSMaxOp
2926//===----------------------------------------------------------------------===//
2927
2928LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
2929 return verifyGroupNonUniformArithmeticOp(*this);
2930}
2931
2932ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2933 OperationState &result) {
2934 return parseGroupNonUniformArithmeticOp(parser, result);
2935}
2936void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
2937 printGroupNonUniformArithmeticOp(*this, p);
2938}
2939
2940//===----------------------------------------------------------------------===//
2941// spirv.GroupNonUniformSMinOp
2942//===----------------------------------------------------------------------===//
2943
2944LogicalResult spirv::GroupNonUniformSMinOp::verify() {
2945 return verifyGroupNonUniformArithmeticOp(*this);
2946}
2947
2948ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2949 OperationState &result) {
2950 return parseGroupNonUniformArithmeticOp(parser, result);
2951}
2952void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
2953 printGroupNonUniformArithmeticOp(*this, p);
2954}
2955
2956//===----------------------------------------------------------------------===//
2957// spirv.GroupNonUniformUMaxOp
2958//===----------------------------------------------------------------------===//
2959
2960LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
2961 return verifyGroupNonUniformArithmeticOp(*this);
2962}
2963
2964ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
2965 OperationState &result) {
2966 return parseGroupNonUniformArithmeticOp(parser, result);
2967}
2968void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
2969 printGroupNonUniformArithmeticOp(*this, p);
2970}
2971
2972//===----------------------------------------------------------------------===//
2973// spirv.GroupNonUniformUMinOp
2974//===----------------------------------------------------------------------===//
2975
2976LogicalResult spirv::GroupNonUniformUMinOp::verify() {
2977 return verifyGroupNonUniformArithmeticOp(*this);
2978}
2979
2980ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
2981 OperationState &result) {
2982 return parseGroupNonUniformArithmeticOp(parser, result);
2983}
2984void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
2985 printGroupNonUniformArithmeticOp(*this, p);
2986}
2987
2988//===----------------------------------------------------------------------===//
2989// spirv.IAddCarryOp
2990//===----------------------------------------------------------------------===//
2991
2992LogicalResult spirv::IAddCarryOp::verify() {
2993 auto resultType = getType().cast<spirv::StructType>();
2994 if (resultType.getNumElements() != 2)
2995 return emitOpError("expected result struct type containing two members");
2996
2997 if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
2998 resultType.getElementType(0),
2999 resultType.getElementType(1)}))
3000 return emitOpError(
3001 "expected all operand types and struct member types are the same");
3002
3003 return success();
3004}
3005
3006ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
3007 OperationState &result) {
3008 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
3009 if (parser.parseOptionalAttrDict(result.attributes) ||
3010 parser.parseOperandList(operands) || parser.parseColon())
3011 return failure();
3012
3013 Type resultType;
3014 SMLoc loc = parser.getCurrentLocation();
3015 if (parser.parseType(resultType))
3016 return failure();
3017
3018 auto structType = resultType.dyn_cast<spirv::StructType>();
3019 if (!structType || structType.getNumElements() != 2)
3020 return parser.emitError(loc, "expected spirv.struct type with two members");
3021
3022 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
3023 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
3024 return failure();
3025
3026 result.addTypes(resultType);
3027 return success();
3028}
3029
3030void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
3031 printer << ' ';
3032 printer.printOptionalAttrDict((*this)->getAttrs());
3033 printer.printOperands((*this)->getOperands());
3034 printer << " : " << getType();
3035}
3036
3037//===----------------------------------------------------------------------===//
3038// spirv.ISubBorrowOp
3039//===----------------------------------------------------------------------===//
3040
3041LogicalResult spirv::ISubBorrowOp::verify() {
3042 auto resultType = getType().cast<spirv::StructType>();
3043 if (resultType.getNumElements() != 2)
3044 return emitOpError("expected result struct type containing two members");
3045
3046 if (!llvm::all_equal({getOperand1().getType(), getOperand2().getType(),
3047 resultType.getElementType(0),
3048 resultType.getElementType(1)}))
3049 return emitOpError(
3050 "expected all operand types and struct member types are the same");
3051
3052 return success();
3053}
3054
3055ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
3056 OperationState &result) {
3057 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
3058 if (parser.parseOptionalAttrDict(result.attributes) ||
3059 parser.parseOperandList(operands) || parser.parseColon())
3060 return failure();
3061
3062 Type resultType;
3063 auto loc = parser.getCurrentLocation();
3064 if (parser.parseType(resultType))
3065 return failure();
3066
3067 auto structType = resultType.dyn_cast<spirv::StructType>();
3068 if (!structType || structType.getNumElements() != 2)
3069 return parser.emitError(loc, "expected spirv.struct type with two members");
3070
3071 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
3072 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
3073 return failure();
3074
3075 result.addTypes(resultType);
3076 return success();
3077}
3078
3079void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
3080 printer << ' ';
3081 printer.printOptionalAttrDict((*this)->getAttrs());
3082 printer.printOperands((*this)->getOperands());
3083 printer << " : " << getType();
3084}
3085
3086//===----------------------------------------------------------------------===//
3087// spirv.LoadOp
3088//===----------------------------------------------------------------------===//
3089
3090void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
3091 Value basePtr, MemoryAccessAttr memoryAccess,
3092 IntegerAttr alignment) {
3093 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3094 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3095 alignment);
3096}
3097
3098ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3099 // Parse the storage class specification
3100 spirv::StorageClass storageClass;
3101 OpAsmParser::UnresolvedOperand ptrInfo;
3102 Type elementType;
3103 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3104 parseMemoryAccessAttributes(parser, result) ||
3105 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3106 parser.parseType(elementType)) {
3107 return failure();
3108 }
3109
3110 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3111 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3112 return failure();
3113 }
3114
3115 result.addTypes(elementType);
3116 return success();
3117}
3118
3119void spirv::LoadOp::print(OpAsmPrinter &printer) {
3120 SmallVector<StringRef, 4> elidedAttrs;
3121 StringRef sc = stringifyStorageClass(
3122 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3123 printer << " \"" << sc << "\" " << getPtr();
3124
3125 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3126
3127 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3128 printer << " : " << getType();
3129}
3130
3131LogicalResult spirv::LoadOp::verify() {
3132 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3133 // type with fixed size; i.e., it cannot be, nor include, any
3134 // OpTypeRuntimeArray types."
3135 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
3136 return failure();
3137 }
3138 return verifyMemoryAccessAttribute(*this);
3139}
3140
3141//===----------------------------------------------------------------------===//
3142// spirv.mlir.loop
3143//===----------------------------------------------------------------------===//
3144
3145void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3146 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3147 spirv::LoopControl::None));
3148 state.addRegion();
3149}
3150
3151ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3152 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3153 result))
3154 return failure();
3155 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3156}
3157
3158void spirv::LoopOp::print(OpAsmPrinter &printer) {
3159 auto control = getLoopControl();
3160 if (control != spirv::LoopControl::None)
3161 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3162 printer << ' ';
3163 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3164 /*printBlockTerminators=*/true);
3165}
3166
3167/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
3168/// given `dstBlock`.
3169static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3170 // Check that there is only one op in the `srcBlock`.
3171 if (!llvm::hasSingleElement(srcBlock))
3172 return false;
3173
3174 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3175 return branchOp && branchOp.getSuccessor() == &dstBlock;
3176}
3177
3178LogicalResult spirv::LoopOp::verifyRegions() {
3179 auto *op = getOperation();
3180
3181 // We need to verify that the blocks follow the following layout:
3182 //
3183 // +-------------+
3184 // | entry block |
3185 // +-------------+
3186 // |
3187 // v
3188 // +-------------+
3189 // | loop header | <-----+
3190 // +-------------+ |
3191 // |
3192 // ... |
3193 // \ | / |
3194 // v |
3195 // +---------------+ |
3196 // | loop continue | -----+
3197 // +---------------+
3198 //
3199 // ...
3200 // \ | /
3201 // v
3202 // +-------------+
3203 // | merge block |
3204 // +-------------+
3205
3206 auto &region = op->getRegion(0);
3207 // Allow empty region as a degenerated case, which can come from
3208 // optimizations.
3209 if (region.empty())
3210 return success();
3211
3212 // The last block is the merge block.
3213 Block &merge = region.back();
3214 if (!isMergeBlock(merge))
3215 return emitOpError("last block must be the merge block with only one "
3216 "'spirv.mlir.merge' op");
3217
3218 if (std::next(region.begin()) == region.end())
3219 return emitOpError(
3220 "must have an entry block branching to the loop header block");
3221 // The first block is the entry block.
3222 Block &entry = region.front();
3223
3224 if (std::next(region.begin(), 2) == region.end())
3225 return emitOpError(
3226 "must have a loop header block branched from the entry block");
3227 // The second block is the loop header block.
3228 Block &header = *std::next(region.begin(), 1);
3229
3230 if (!hasOneBranchOpTo(entry, header))
3231 return emitOpError(
3232 "entry block must only have one 'spirv.Branch' op to the second block");
3233
3234 if (std::next(region.begin(), 3) == region.end())
3235 return emitOpError(
3236 "requires a loop continue block branching to the loop header block");
3237 // The second to last block is the loop continue block.
3238 Block &cont = *std::prev(region.end(), 2);
3239
3240 // Make sure that we have a branch from the loop continue block to the loop
3241 // header block.
3242 if (llvm::none_of(
3243 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3244 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3245 return emitOpError("second to last block must be the loop continue "
3246 "block that branches to the loop header block");
3247
3248 // Make sure that no other blocks (except the entry and loop continue block)
3249 // branches to the loop header block.
3250 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3251 std::prev(region.end(), 2))) {
3252 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3253 if (block.getSuccessor(i) == &header) {
3254 return emitOpError("can only have the entry and loop continue "
3255 "block branching to the loop header block");
3256 }
3257 }
3258 }
3259
3260 return success();
3261}
3262
3263Block *spirv::LoopOp::getEntryBlock() {
3264 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3264, __extension__
__PRETTY_FUNCTION__))
;
3265 return &getBody().front();
3266}
3267
3268Block *spirv::LoopOp::getHeaderBlock() {
3269 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3269, __extension__
__PRETTY_FUNCTION__))
;
3270 // The second block is the loop header block.
3271 return &*std::next(getBody().begin());
3272}
3273
3274Block *spirv::LoopOp::getContinueBlock() {
3275 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3275, __extension__
__PRETTY_FUNCTION__))
;
3276 // The second to last block is the loop continue block.
3277 return &*std::prev(getBody().end(), 2);
3278}
3279
3280Block *spirv::LoopOp::getMergeBlock() {
3281 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3281, __extension__
__PRETTY_FUNCTION__))
;
3282 // The last block is the loop merge block.
3283 return &getBody().back();
3284}
3285
3286void spirv::LoopOp::addEntryAndMergeBlock() {
3287 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3287, __extension__
__PRETTY_FUNCTION__))
;
3288 getBody().push_back(new Block());
3289 auto *mergeBlock = new Block();
3290 getBody().push_back(mergeBlock);
3291 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3292
3293 // Add a spirv.mlir.merge op into the merge block.
3294 builder.create<spirv::MergeOp>(getLoc());
3295}
3296
3297//===----------------------------------------------------------------------===//
3298// spirv.MemoryBarrierOp
3299//===----------------------------------------------------------------------===//
3300
3301LogicalResult spirv::MemoryBarrierOp::verify() {
3302 return verifyMemorySemantics(getOperation(), getMemorySemantics());
3303}
3304
3305//===----------------------------------------------------------------------===//
3306// spirv.mlir.merge
3307//===----------------------------------------------------------------------===//
3308
3309LogicalResult spirv::MergeOp::verify() {
3310 auto *parentOp = (*this)->getParentOp();
3311 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3312 return emitOpError(
3313 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3314
3315 // TODO: This check should be done in `verifyRegions` of parent op.
3316 Block &parentLastBlock = (*this)->getParentRegion()->back();
3317 if (getOperation() != parentLastBlock.getTerminator())
3318 return emitOpError("can only be used in the last block of "
3319 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3320 return success();
3321}
3322
3323//===----------------------------------------------------------------------===//
3324// spirv.module
3325//===----------------------------------------------------------------------===//
3326
3327void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3328 Optional<StringRef> name) {
3329 OpBuilder::InsertionGuard guard(builder);
3330 builder.createBlock(state.addRegion());
3331 if (name) {
3332 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3333 builder.getStringAttr(*name));
3334 }
3335}
3336
3337void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3338 spirv::AddressingModel addressingModel,
3339 spirv::MemoryModel memoryModel,
3340 Optional<VerCapExtAttr> vceTriple,
3341 Optional<StringRef> name) {
3342 state.addAttribute(
3343 "addressing_model",
3344 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3345 state.addAttribute("memory_model",
3346 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3347 OpBuilder::InsertionGuard guard(builder);
3348 builder.createBlock(state.addRegion());
3349 if (vceTriple)
3350 state.addAttribute(getVCETripleAttrName(), *vceTriple);
3351 if (name)
3352 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3353 builder.getStringAttr(*name));
3354}
3355
3356ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3357 OperationState &result) {
3358 Region *body = result.addRegion();
3359
3360 // If the name is present, parse it.
3361 StringAttr nameAttr;
3362 (void)parser.parseOptionalSymbolName(
3363 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3364
3365 // Parse attributes
3366 spirv::AddressingModel addrModel;
3367 spirv::MemoryModel memoryModel;
3368 if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3369 result) ||
3370 ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3371 result))
3372 return failure();
3373
3374 if (succeeded(parser.parseOptionalKeyword("requires"))) {
3375 spirv::VerCapExtAttr vceTriple;
3376 if (parser.parseAttribute(vceTriple,
3377 spirv::ModuleOp::getVCETripleAttrName(),
3378 result.attributes))
3379 return failure();
3380 }
3381
3382 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3383 parser.parseRegion(*body, /*arguments=*/{}))
3384 return failure();
3385
3386 // Make sure we have at least one block.
3387 if (body->empty())
3388 body->push_back(new Block());
3389
3390 return success();
3391}
3392
3393void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3394 if (Optional<StringRef> name = getName()) {
3395 printer << ' ';
3396 printer.printSymbolName(*name);
3397 }
3398
3399 SmallVector<StringRef, 2> elidedAttrs;
3400
3401 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
3402 << spirv::stringifyMemoryModel(getMemoryModel());
3403 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3404 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3405 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3406 mlir::SymbolTable::getSymbolAttrName()});
3407
3408 if (Optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3409 printer << " requires " << *triple;
3410 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3411 }
3412
3413 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3414 printer << ' ';
3415 printer.printRegion(getRegion());
3416}
3417
3418LogicalResult spirv::ModuleOp::verifyRegions() {
3419 Dialect *dialect = (*this)->getDialect();
3420 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3421 entryPoints;
3422 mlir::SymbolTable table(*this);
3423
3424 for (auto &op : *getBody()) {
3425 if (op.getDialect() != dialect)
3426 return op.emitError("'spirv.module' can only contain spirv.* ops");
3427
3428 // For EntryPoint op, check that the function and execution model is not
3429 // duplicated in EntryPointOps. Also verify that the interface specified
3430 // comes from globalVariables here to make this check cheaper.
3431 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3432 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3433 if (!funcOp) {
3434 return entryPointOp.emitError("function '")
3435 << entryPointOp.getFn() << "' not found in 'spirv.module'";
3436 }
3437 if (auto interface = entryPointOp.getInterface()) {
3438 for (Attribute varRef : interface) {
3439 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3440 if (!varSymRef) {
3441 return entryPointOp.emitError(
3442 "expected symbol reference for interface "
3443 "specification instead of '")
3444 << varRef;
3445 }
3446 auto variableOp =
3447 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3448 if (!variableOp) {
3449 return entryPointOp.emitError("expected spirv.GlobalVariable "
3450 "symbol reference instead of'")
3451 << varSymRef << "'";
3452 }
3453 }
3454 }
3455
3456 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3457 funcOp, entryPointOp.getExecutionModel());
3458 auto entryPtIt = entryPoints.find(key);
3459 if (entryPtIt != entryPoints.end()) {
3460 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3461 }
3462 entryPoints[key] = entryPointOp;
3463 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3464 if (funcOp.isExternal())
3465 return op.emitError("'spirv.module' cannot contain external functions");
3466
3467 // TODO: move this check to spirv.func.
3468 for (auto &block : funcOp)
3469 for (auto &op : block) {
3470 if (op.getDialect() != dialect)
3471 return op.emitError(
3472 "functions in 'spirv.module' can only contain spirv.* ops");
3473 }
3474 }
3475 }
3476
3477 return success();
3478}
3479
3480//===----------------------------------------------------------------------===//
3481// spirv.mlir.referenceof
3482//===----------------------------------------------------------------------===//
3483
3484LogicalResult spirv::ReferenceOfOp::verify() {
3485 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3486 (*this)->getParentOp(), getSpecConstAttr());
3487 Type constType;
3488
3489 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3490 if (specConstOp)
3491 constType = specConstOp.getDefaultValue().getType();
3492
3493 auto specConstCompositeOp =
3494 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3495 if (specConstCompositeOp)
3496 constType = specConstCompositeOp.getType();
3497
3498 if (!specConstOp && !specConstCompositeOp)
3499 return emitOpError(
3500 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3501
3502 if (getReference().getType() != constType)
3503 return emitOpError("result type mismatch with the referenced "
3504 "specialization constant's type");
3505
3506 return success();
3507}
3508
3509//===----------------------------------------------------------------------===//
3510// spirv.Return
3511//===----------------------------------------------------------------------===//
3512
3513LogicalResult spirv::ReturnOp::verify() {
3514 // Verification is performed in spirv.func op.
3515 return success();
3516}
3517
3518//===----------------------------------------------------------------------===//
3519// spirv.ReturnValue
3520//===----------------------------------------------------------------------===//
3521
3522LogicalResult spirv::ReturnValueOp::verify() {
3523 // Verification is performed in spirv.func op.
3524 return success();
3525}
3526
3527//===----------------------------------------------------------------------===//
3528// spirv.Select
3529//===----------------------------------------------------------------------===//
3530
3531LogicalResult spirv::SelectOp::verify() {
3532 if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3533 auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3534 if (!resultVectorTy) {
3535 return emitOpError("result expected to be of vector type when "
3536 "condition is of vector type");
3537 }
3538 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3539 return emitOpError("result should have the same number of elements as "
3540 "the condition when condition is of vector type");
3541 }
3542 }
3543 return success();
3544}
3545
3546//===----------------------------------------------------------------------===//
3547// spirv.mlir.selection
3548//===----------------------------------------------------------------------===//
3549
3550ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3551 OperationState &result) {
3552 if (parseControlAttribute<spirv::SelectionControlAttr,
3553 spirv::SelectionControl>(parser, result))
3554 return failure();
3555 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3556}
3557
3558void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3559 auto control = getSelectionControl();
3560 if (control != spirv::SelectionControl::None)
3561 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3562 printer << ' ';
3563 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3564 /*printBlockTerminators=*/true);
3565}
3566
3567LogicalResult spirv::SelectionOp::verifyRegions() {
3568 auto *op = getOperation();
3569
3570 // We need to verify that the blocks follow the following layout:
3571 //
3572 // +--------------+
3573 // | header block |
3574 // +--------------+
3575 // / | \
3576 // ...
3577 //
3578 //
3579 // +---------+ +---------+ +---------+
3580 // | case #0 | | case #1 | | case #2 | ...
3581 // +---------+ +---------+ +---------+
3582 //
3583 //
3584 // ...
3585 // \ | /
3586 // v
3587 // +-------------+
3588 // | merge block |
3589 // +-------------+
3590
3591 auto &region = op->getRegion(0);
3592 // Allow empty region as a degenerated case, which can come from
3593 // optimizations.
3594 if (region.empty())
3595 return success();
3596
3597 // The last block is the merge block.
3598 if (!isMergeBlock(region.back()))
3599 return emitOpError("last block must be the merge block with only one "
3600 "'spirv.mlir.merge' op");
3601
3602 if (std::next(region.begin()) == region.end())
3603 return emitOpError("must have a selection header block");
3604
3605 return success();
3606}
3607
3608Block *spirv::SelectionOp::getHeaderBlock() {
3609 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3609, __extension__
__PRETTY_FUNCTION__))
;
3610 // The first block is the loop header block.
3611 return &getBody().front();
3612}
3613
3614Block *spirv::SelectionOp::getMergeBlock() {
3615 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3615, __extension__
__PRETTY_FUNCTION__))
;
3616 // The last block is the loop merge block.
3617 return &getBody().back();
3618}
3619
3620void spirv::SelectionOp::addMergeBlock() {
3621 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3621, __extension__
__PRETTY_FUNCTION__))
;
3622 auto *mergeBlock = new Block();
3623 getBody().push_back(mergeBlock);
3624 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3625
3626 // Add a spirv.mlir.merge op into the merge block.
3627 builder.create<spirv::MergeOp>(getLoc());
3628}
3629
3630spirv::SelectionOp spirv::SelectionOp::createIfThen(
3631 Location loc, Value condition,
3632 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3633 auto selectionOp =
3634 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3635
3636 selectionOp.addMergeBlock();
3637 Block *mergeBlock = selectionOp.getMergeBlock();
3638 Block *thenBlock = nullptr;
3639
3640 // Build the "then" block.
3641 {
3642 OpBuilder::InsertionGuard guard(builder);
3643 thenBlock = builder.createBlock(mergeBlock);
3644 thenBody(builder);
3645 builder.create<spirv::BranchOp>(loc, mergeBlock);
3646 }
3647
3648 // Build the header block.
3649 {
3650 OpBuilder::InsertionGuard guard(builder);
3651 builder.createBlock(thenBlock);
3652 builder.create<spirv::BranchConditionalOp>(
3653 loc, condition, thenBlock,
3654 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3655 /*falseArguments=*/ArrayRef<Value>());
3656 }
3657
3658 return selectionOp;
3659}
3660
3661//===----------------------------------------------------------------------===//
3662// spirv.SpecConstant
3663//===----------------------------------------------------------------------===//
3664
3665ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3666 OperationState &result) {
3667 StringAttr nameAttr;
3668 Attribute valueAttr;
3669
3670 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3671 result.attributes))
3672 return failure();
3673
3674 // Parse optional spec_id.
3675 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3676 IntegerAttr specIdAttr;
3677 if (parser.parseLParen() ||
3678 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3679 parser.parseRParen())
3680 return failure();
3681 }
3682
3683 if (parser.parseEqual() ||
3684 parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3685 result.attributes))
3686 return failure();
3687
3688 return success();
3689}
3690
3691void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3692 printer << ' ';
3693 printer.printSymbolName(getSymName());
3694 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3695 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3696 printer << " = " << getDefaultValue();
3697}
3698
3699LogicalResult spirv::SpecConstantOp::verify() {
3700 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3701 if (specID.getValue().isNegative())
3702 return emitOpError("SpecId cannot be negative");
3703
3704 auto value = getDefaultValue();
3705 if (value.isa<IntegerAttr, FloatAttr>()) {
3706 // Make sure bitwidth is allowed.
3707 if (!value.getType().isa<spirv::SPIRVType>())
3708 return emitOpError("default value bitwidth disallowed");
3709 return success();
3710 }
3711 return emitOpError(
3712 "default value can only be a bool, integer, or float scalar");
3713}
3714
3715//===----------------------------------------------------------------------===//
3716// spirv.StoreOp
3717//===----------------------------------------------------------------------===//
3718
3719ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3720 // Parse the storage class specification
3721 spirv::StorageClass storageClass;
1
'storageClass' declared without an initial value
3722 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3723 auto loc = parser.getCurrentLocation();
3724 Type elementType;
3725 if (parseEnumStrAttr(storageClass, parser) ||
2
Calling 'parseEnumStrAttr<mlir::spirv::StorageClass>'
6
Returning from 'parseEnumStrAttr<mlir::spirv::StorageClass>'
7
Taking false branch
3726 parser.parseOperandList(operandInfo, 2) ||
3727 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3728 parser.parseType(elementType)) {
3729 return failure();
3730 }
3731
3732 auto ptrType = spirv::PointerType::get(elementType, storageClass);
8
2nd function call argument is an uninitialized value
3733 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3734 result.operands)) {
3735 return failure();
3736 }
3737 return success();
3738}
3739
3740void spirv::StoreOp::print(OpAsmPrinter &printer) {
3741 SmallVector<StringRef, 4> elidedAttrs;
3742 StringRef sc = stringifyStorageClass(
3743 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3744 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
3745
3746 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3747
3748 printer << " : " << getValue().getType();
3749 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3750}
3751
3752LogicalResult spirv::StoreOp::verify() {
3753 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3754 // OpTypePointer whose Type operand is the same as the type of Object."
3755 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
3756 return failure();
3757 return verifyMemoryAccessAttribute(*this);
3758}
3759
3760//===----------------------------------------------------------------------===//
3761// spirv.Unreachable
3762//===----------------------------------------------------------------------===//
3763
3764LogicalResult spirv::UnreachableOp::verify() {
3765 auto *block = (*this)->getBlock();
3766 // Fast track: if this is in entry block, its invalid. Otherwise, if no
3767 // predecessors, it's valid.
3768 if (block->isEntryBlock())
3769 return emitOpError("cannot be used in reachable block");
3770 if (block->hasNoPredecessors())
3771 return success();
3772
3773 // TODO: further verification needs to analyze reachability from
3774 // the entry block.
3775
3776 return success();
3777}
3778
3779//===----------------------------------------------------------------------===//
3780// spirv.Variable
3781//===----------------------------------------------------------------------===//
3782
3783ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3784 OperationState &result) {
3785 // Parse optional initializer
3786 Optional<OpAsmParser::UnresolvedOperand> initInfo;
3787 if (succeeded(parser.parseOptionalKeyword("init"))) {
3788 initInfo = OpAsmParser::UnresolvedOperand();
3789 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3790 parser.parseRParen())
3791 return failure();
3792 }
3793
3794 if (parseVariableDecorations(parser, result)) {
3795 return failure();
3796 }
3797
3798 // Parse result pointer type
3799 Type type;
3800 if (parser.parseColon())
3801 return failure();
3802 auto loc = parser.getCurrentLocation();
3803 if (parser.parseType(type))
3804 return failure();
3805
3806 auto ptrType = type.dyn_cast<spirv::PointerType>();
3807 if (!ptrType)
3808 return parser.emitError(loc, "expected spirv.ptr type");
3809 result.addTypes(ptrType);
3810
3811 // Resolve the initializer operand
3812 if (initInfo) {
3813 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3814 result.operands))
3815 return failure();
3816 }
3817
3818 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3819 ptrType.getStorageClass());
3820 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3821
3822 return success();
3823}
3824
3825void spirv::VariableOp::print(OpAsmPrinter &printer) {
3826 SmallVector<StringRef, 4> elidedAttrs{
3827 spirv::attributeName<spirv::StorageClass>()};
3828 // Print optional initializer
3829 if (getNumOperands() != 0)
3830 printer << " init(" << getInitializer() << ")";
3831
3832 printVariableDecorations(*this, printer, elidedAttrs);
3833 printer << " : " << getType();
3834}
3835
3836LogicalResult spirv::VariableOp::verify() {
3837 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3838 // object. It cannot be Generic. It must be the same as the Storage Class
3839 // operand of the Result Type."
3840 if (getStorageClass() != spirv::StorageClass::Function) {
3841 return emitOpError(
3842 "can only be used to model function-level variables. Use "
3843 "spirv.GlobalVariable for module-level variables.");
3844 }
3845
3846 auto pointerType = getPointer().getType().cast<spirv::PointerType>();
3847 if (getStorageClass() != pointerType.getStorageClass())
3848 return emitOpError(
3849 "storage class must match result pointer's storage class");
3850
3851 if (getNumOperands() != 0) {
3852 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3853 // a global (module scope) OpVariable instruction".
3854 auto *initOp = getOperand(0).getDefiningOp();
3855 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3856 spirv::ReferenceOfOp, // for spec constant
3857 spirv::AddressOfOp>(initOp))
3858 return emitOpError("initializer must be the result of a "
3859 "constant or spirv.GlobalVariable op");
3860 }
3861
3862 // TODO: generate these strings using ODS.
3863 auto *op = getOperation();
3864 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3865 stringifyDecoration(spirv::Decoration::DescriptorSet));
3866 auto bindingName = llvm::convertToSnakeFromCamelCase(
3867 stringifyDecoration(spirv::Decoration::Binding));
3868 auto builtInName = llvm::convertToSnakeFromCamelCase(
3869 stringifyDecoration(spirv::Decoration::BuiltIn));
3870
3871 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3872 if (op->getAttr(attr))
3873 return emitOpError("cannot have '")
3874 << attr << "' attribute (only allowed in spirv.GlobalVariable)";
3875 }
3876
3877 return success();
3878}
3879
3880//===----------------------------------------------------------------------===//
3881// spirv.VectorShuffle
3882//===----------------------------------------------------------------------===//
3883
3884LogicalResult spirv::VectorShuffleOp::verify() {
3885 VectorType resultType = getType().cast<VectorType>();
3886
3887 size_t numResultElements = resultType.getNumElements();
3888 if (numResultElements != getComponents().size())
3889 return emitOpError("result type element count (")
3890 << numResultElements
3891 << ") mismatch with the number of component selectors ("
3892 << getComponents().size() << ")";
3893
3894 size_t totalSrcElements =
3895 getVector1().getType().cast<VectorType>().getNumElements() +
3896 getVector2().getType().cast<VectorType>().getNumElements();
3897
3898 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3899 uint32_t index = selector.getZExtValue();
3900 if (index >= totalSrcElements &&
3901 index != std::numeric_limits<uint32_t>().max())
3902 return emitOpError("component selector ")
3903 << index << " out of range: expected to be in [0, "
3904 << totalSrcElements << ") or 0xffffffff";
3905 }
3906 return success();
3907}
3908
3909//===----------------------------------------------------------------------===//
3910// spirv.NV.CooperativeMatrixLoad
3911//===----------------------------------------------------------------------===//
3912
3913ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
3914 OperationState &result) {
3915 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3916 Type strideType = parser.getBuilder().getIntegerType(32);
3917 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3918 Type ptrType;
3919 Type elementType;
3920 if (parser.parseOperandList(operandInfo, 3) ||
3921 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3922 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3923 return failure();
3924 }
3925 if (parser.resolveOperands(operandInfo,
3926 {ptrType, strideType, columnMajorType},
3927 parser.getNameLoc(), result.operands)) {
3928 return failure();
3929 }
3930
3931 result.addTypes(elementType);
3932 return success();
3933}
3934
3935void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
3936 printer << " " << getPointer() << ", " << getStride() << ", "
3937 << getColumnmajor();
3938 // Print optional memory access attribute.
3939 if (auto memAccess = getMemoryAccess())
3940 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3941 printer << " : " << getPointer().getType() << " as " << getType();
3942}
3943
3944static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3945 Type coopMatrix) {
3946 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3947 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3948 return op->emitError(
3949 "Pointer must point to a scalar or vector type but provided ")
3950 << pointeeType;
3951 spirv::StorageClass storage =
3952 pointer.cast<spirv::PointerType>().getStorageClass();
3953 if (storage != spirv::StorageClass::Workgroup &&
3954 storage != spirv::StorageClass::StorageBuffer &&
3955 storage != spirv::StorageClass::PhysicalStorageBuffer)
3956 return op->emitError(
3957 "Pointer storage class must be Workgroup, StorageBuffer or "
3958 "PhysicalStorageBufferEXT but provided ")
3959 << stringifyStorageClass(storage);
3960 return success();
3961}
3962
3963LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
3964 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
3965 getResult().getType());
3966}
3967
3968//===----------------------------------------------------------------------===//
3969// spirv.NV.CooperativeMatrixStore
3970//===----------------------------------------------------------------------===//
3971
3972ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
3973 OperationState &result) {
3974 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
3975 Type strideType = parser.getBuilder().getIntegerType(32);
3976 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3977 Type ptrType;
3978 Type elementType;
3979 if (parser.parseOperandList(operandInfo, 4) ||
3980 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3981 parser.parseType(ptrType) || parser.parseComma() ||
3982 parser.parseType(elementType)) {
3983 return failure();
3984 }
3985 if (parser.resolveOperands(
3986 operandInfo, {ptrType, elementType, strideType, columnMajorType},
3987 parser.getNameLoc(), result.operands)) {
3988 return failure();
3989 }
3990
3991 return success();
3992}
3993
3994void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
3995 printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
3996 << ", " << getColumnmajor();
3997 // Print optional memory access attribute.
3998 if (auto memAccess = getMemoryAccess())
3999 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
4000 printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
4001}
4002
4003LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
4004 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4005 getObject().getType());
4006}
4007
4008//===----------------------------------------------------------------------===//
4009// spirv.NV.CooperativeMatrixMulAdd
4010//===----------------------------------------------------------------------===//
4011
4012static LogicalResult
4013verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
4014 if (op.getC().getType() != op.getResult().getType())
4015 return op.emitOpError("result and third operand must have the same type");
4016 auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
4017 auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
4018 auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
4019 auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
4020 if (typeA.getRows() != typeR.getRows() ||
4021 typeA.getColumns() != typeB.getRows() ||
4022 typeB.getColumns() != typeR.getColumns())
4023 return op.emitOpError("matrix size must match");
4024 if (typeR.getScope() != typeA.getScope() ||
4025 typeR.getScope() != typeB.getScope() ||
4026 typeR.getScope() != typeC.getScope())
4027 return op.emitOpError("matrix scope must match");
4028 if (typeA.getElementType() != typeB.getElementType() ||
4029 typeR.getElementType() != typeC.getElementType())
4030 return op.emitOpError("matrix element type must match");
4031 return success();
4032}
4033
4034LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
4035 return verifyCoopMatrixMulAdd(*this);
4036}
4037
4038static LogicalResult
4039verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
4040 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4041 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4042 return op->emitError(
4043 "Pointer must point to a scalar or vector type but provided ")
4044 << pointeeType;
4045 spirv::StorageClass storage =
4046 pointer.cast<spirv::PointerType>().getStorageClass();
4047 if (storage != spirv::StorageClass::Workgroup &&
4048 storage != spirv::StorageClass::CrossWorkgroup &&
4049 storage != spirv::StorageClass::UniformConstant &&
4050 storage != spirv::StorageClass::Generic)
4051 return op->emitError("Pointer storage class must be Workgroup or "
4052 "CrossWorkgroup but provided ")
4053 << stringifyStorageClass(storage);
4054 return success();
4055}
4056
4057//===----------------------------------------------------------------------===//
4058// spirv.INTEL.JointMatrixLoad
4059//===----------------------------------------------------------------------===//
4060
4061LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
4062 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4063 getResult().getType());
4064}
4065
4066//===----------------------------------------------------------------------===//
4067// spirv.INTEL.JointMatrixStore
4068//===----------------------------------------------------------------------===//
4069
4070LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
4071 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4072 getObject().getType());
4073}
4074
4075//===----------------------------------------------------------------------===//
4076// spirv.INTEL.JointMatrixMad
4077//===----------------------------------------------------------------------===//
4078
4079static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
4080 if (op.getC().getType() != op.getResult().getType())
4081 return op.emitOpError("result and third operand must have the same type");
4082 auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
4083 auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
4084 auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
4085 auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
4086 if (typeA.getRows() != typeR.getRows() ||
4087 typeA.getColumns() != typeB.getRows() ||
4088 typeB.getColumns() != typeR.getColumns())
4089 return op.emitOpError("matrix size must match");
4090 if (typeR.getScope() != typeA.getScope() ||
4091 typeR.getScope() != typeB.getScope() ||
4092 typeR.getScope() != typeC.getScope())
4093 return op.emitOpError("matrix scope must match");
4094 if (typeA.getElementType() != typeB.getElementType() ||
4095 typeR.getElementType() != typeC.getElementType())
4096 return op.emitOpError("matrix element type must match");
4097 return success();
4098}
4099
4100LogicalResult spirv::INTELJointMatrixMadOp::verify() {
4101 return verifyJointMatrixMad(*this);
4102}
4103
4104//===----------------------------------------------------------------------===//
4105// spirv.MatrixTimesScalar
4106//===----------------------------------------------------------------------===//
4107
4108LogicalResult spirv::MatrixTimesScalarOp::verify() {
4109 // We already checked that result and matrix are both of matrix type in the
4110 // auto-generated verify method.
4111
4112 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4113 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4114
4115 // Check that the scalar type is the same as the matrix element type.
4116 if (getScalar().getType() != inputMatrix.getElementType())
4117 return emitError("input matrix components' type and scaling value must "
4118 "have the same type");
4119
4120 // Note that the next three checks could be done using the AllTypesMatch
4121 // trait in the Op definition file but it generates a vague error message.
4122
4123 // Check that the input and result matrices have the same columns' count
4124 if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
4125 return emitError("input and result matrices must have the same "
4126 "number of columns");
4127
4128 // Check that the input and result matrices' have the same rows count
4129 if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
4130 return emitError("input and result matrices' columns must have "
4131 "the same size");
4132
4133 // Check that the input and result matrices' have the same component type
4134 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4135 return emitError("input and result matrices' columns must have "
4136 "the same component type");
4137
4138 return success();
4139}
4140
4141//===----------------------------------------------------------------------===//
4142// spirv.CopyMemory
4143//===----------------------------------------------------------------------===//
4144
4145void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
4146 printer << ' ';
4147
4148 StringRef targetStorageClass = stringifyStorageClass(
4149 getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4150 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
4151
4152 StringRef sourceStorageClass = stringifyStorageClass(
4153 getSource().getType().cast<spirv::PointerType>().getStorageClass());
4154 printer << " \"" << sourceStorageClass << "\" " << getSource();
4155
4156 SmallVector<StringRef, 4> elidedAttrs;
4157 printMemoryAccessAttribute(*this, printer, elidedAttrs);
4158 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4159 getSourceMemoryAccess(),
4160 getSourceAlignment());
4161
4162 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4163
4164 Type pointeeType =
4165 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4166 printer << " : " << pointeeType;
4167}
4168
4169ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4170 OperationState &result) {
4171 spirv::StorageClass targetStorageClass;
4172 OpAsmParser::UnresolvedOperand targetPtrInfo;
4173
4174 spirv::StorageClass sourceStorageClass;
4175 OpAsmParser::UnresolvedOperand sourcePtrInfo;
4176
4177 Type elementType;
4178
4179 if (parseEnumStrAttr(targetStorageClass, parser) ||
4180 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4181 parseEnumStrAttr(sourceStorageClass, parser) ||
4182 parser.parseOperand(sourcePtrInfo) ||
4183 parseMemoryAccessAttributes(parser, result)) {
4184 return failure();
4185 }
4186
4187 if (!parser.parseOptionalComma()) {
4188 // Parse 2nd memory access attributes.
4189 if (parseSourceMemoryAccessAttributes(parser, result)) {
4190 return failure();
4191 }
4192 }
4193
4194 if (parser.parseColon() || parser.parseType(elementType))
4195 return failure();
4196
4197 if (parser.parseOptionalAttrDict(result.attributes))
4198 return failure();
4199
4200 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4201 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
4202
4203 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4204 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4205 return failure();
4206 }
4207
4208 return success();
4209}
4210
4211LogicalResult spirv::CopyMemoryOp::verify() {
4212 Type targetType =
4213 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4214
4215 Type sourceType =
4216 getSource().getType().cast<spirv::PointerType>().getPointeeType();
4217
4218 if (targetType != sourceType)
4219 return emitOpError("both operands must be pointers to the same type");
4220
4221 if (failed(verifyMemoryAccessAttribute(*this)))
4222 return failure();
4223
4224 // TODO - According to the spec:
4225 //
4226 // If two masks are present, the first applies to Target and cannot include
4227 // MakePointerVisible, and the second applies to Source and cannot include
4228 // MakePointerAvailable.
4229 //
4230 // Add such verification here.
4231
4232 return verifySourceMemoryAccessAttribute(*this);
4233}
4234
4235//===----------------------------------------------------------------------===//
4236// spirv.Transpose
4237//===----------------------------------------------------------------------===//
4238
4239LogicalResult spirv::TransposeOp::verify() {
4240 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4241 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4242
4243 // Verify that the input and output matrices have correct shapes.
4244 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4245 return emitError("input matrix rows count must be equal to "
4246 "output matrix columns count");
4247
4248 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4249 return emitError("input matrix columns count must be equal to "
4250 "output matrix rows count");
4251
4252 // Verify that the input and output matrices have the same component type
4253 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4254 return emitError("input and output matrices must have the same "
4255 "component type");
4256
4257 return success();
4258}
4259
4260//===----------------------------------------------------------------------===//
4261// spirv.MatrixTimesMatrix
4262//===----------------------------------------------------------------------===//
4263
4264LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4265 auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
4266 auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
4267 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4268
4269 // left matrix columns' count and right matrix rows' count must be equal
4270 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4271 return emitError("left matrix columns' count must be equal to "
4272 "the right matrix rows' count");
4273
4274 // right and result matrices columns' count must be the same
4275 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4276 return emitError(
4277 "right and result matrices must have equal columns' count");
4278
4279 // right and result matrices component type must be the same
4280 if (rightMatrix.getElementType() != resultMatrix.getElementType())
4281 return emitError("right and result matrices' component type must"
4282 " be the same");
4283
4284 // left and result matrices component type must be the same
4285 if (leftMatrix.getElementType() != resultMatrix.getElementType())
4286 return emitError("left and result matrices' component type"
4287 " must be the same");
4288
4289 // left and result matrices rows count must be the same
4290 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4291 return emitError("left and result matrices must have equal rows' count");
4292
4293 return success();
4294}
4295
4296//===----------------------------------------------------------------------===//
4297// spirv.SpecConstantComposite
4298//===----------------------------------------------------------------------===//
4299
4300ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4301 OperationState &result) {
4302
4303 StringAttr compositeName;
4304 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4305 result.attributes))
4306 return failure();
4307
4308 if (parser.parseLParen())
4309 return failure();
4310
4311 SmallVector<Attribute, 4> constituents;
4312
4313 do {
4314 // The name of the constituent attribute isn't important
4315 const char *attrName = "spec_const";
4316 FlatSymbolRefAttr specConstRef;
4317 NamedAttrList attrs;
4318
4319 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4320 return failure();
4321
4322 constituents.push_back(specConstRef);
4323 } while (!parser.parseOptionalComma());
4324
4325 if (parser.parseRParen())
4326 return failure();
4327
4328 result.addAttribute(kCompositeSpecConstituentsName,
4329 parser.getBuilder().getArrayAttr(constituents));
4330
4331 Type type;
4332 if (parser.parseColonType(type))
4333 return failure();
4334
4335 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4336
4337 return success();
4338}
4339
4340void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4341 printer << " ";
4342 printer.printSymbolName(getSymName());
4343 printer << " (";
4344 auto constituents = this->getConstituents().getValue();
4345
4346 if (!constituents.empty())
4347 llvm::interleaveComma(constituents, printer);
4348
4349 printer << ") : " << getType();
4350}
4351
4352LogicalResult spirv::SpecConstantCompositeOp::verify() {
4353 auto cType = getType().dyn_cast<spirv::CompositeType>();
4354 auto constituents = this->getConstituents().getValue();
4355
4356 if (!cType)
4357 return emitError("result type must be a composite type, but provided ")
4358 << getType();
4359
4360 if (cType.isa<spirv::CooperativeMatrixNVType>())
4361 return emitError("unsupported composite type ") << cType;
4362 if (cType.isa<spirv::JointMatrixINTELType>())
4363 return emitError("unsupported composite type ") << cType;
4364 if (constituents.size() != cType.getNumElements())
4365 return emitError("has incorrect number of operands: expected ")
4366 << cType.getNumElements() << ", but provided "
4367 << constituents.size();
4368
4369 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4370 auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4371
4372 auto constituentSpecConstOp =
4373 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4374 (*this)->getParentOp(), constituent.getAttr()));
4375
4376 if (constituentSpecConstOp.getDefaultValue().getType() !=
4377 cType.getElementType(index))
4378 return emitError("has incorrect types of operands: expected ")
4379 << cType.getElementType(index) << ", but provided "
4380 << constituentSpecConstOp.getDefaultValue().getType();
4381 }
4382
4383 return success();
4384}
4385
4386//===----------------------------------------------------------------------===//
4387// spirv.SpecConstantOperation
4388//===----------------------------------------------------------------------===//
4389
4390ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4391 OperationState &result) {
4392 Region *body = result.addRegion();
4393
4394 if (parser.parseKeyword("wraps"))
4395 return failure();
4396
4397 body->push_back(new Block);
4398 Block &block = body->back();
4399 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4400
4401 if (!wrappedOp)
4402 return failure();
4403
4404 OpBuilder builder(parser.getContext());
4405 builder.setInsertionPointToEnd(&block);
4406 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4407 result.location = wrappedOp->getLoc();
4408
4409 result.addTypes(wrappedOp->getResult(0).getType());
4410
4411 if (parser.parseOptionalAttrDict(result.attributes))
4412 return failure();
4413
4414 return success();
4415}
4416
4417void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4418 printer << " wraps ";
4419 printer.printGenericOp(&getBody().front().front());
4420}
4421
4422LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4423 Block &block = getRegion().getBlocks().front();
4424
4425 if (block.getOperations().size() != 2)
4426 return emitOpError("expected exactly 2 nested ops");
4427
4428 Operation &enclosedOp = block.getOperations().front();
4429
4430 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4431 return emitOpError("invalid enclosed op");
4432
4433 for (auto operand : enclosedOp.getOperands())
4434 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4435 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4436 return emitOpError(
4437 "invalid operand, must be defined by a constant operation");
4438
4439 return success();
4440}
4441
4442//===----------------------------------------------------------------------===//
4443// spirv.GL.FrexpStruct
4444//===----------------------------------------------------------------------===//
4445
4446LogicalResult spirv::GLFrexpStructOp::verify() {
4447 spirv::StructType structTy =
4448 getResult().getType().dyn_cast<spirv::StructType>();
4449
4450 if (structTy.getNumElements() != 2)
4451 return emitError("result type must be a struct type with two memebers");
4452
4453 Type significandTy = structTy.getElementType(0);
4454 Type exponentTy = structTy.getElementType(1);
4455 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4456 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4457
4458 Type operandTy = getOperand().getType();
4459 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4460 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4461
4462 if (significandTy != operandTy)
4463 return emitError("member zero of the resulting struct type must be the "
4464 "same type as the operand");
4465
4466 if (exponentVecTy) {
4467 IntegerType componentIntTy =
4468 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4469 if (!componentIntTy || componentIntTy.getWidth() != 32)
4470 return emitError("member one of the resulting struct type must"
4471 "be a scalar or vector of 32 bit integer type");
4472 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4473 return emitError("member one of the resulting struct type "
4474 "must be a scalar or vector of 32 bit integer type");
4475 }
4476
4477 // Check that the two member types have the same number of components
4478 if (operandVecTy && exponentVecTy &&
4479 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4480 return success();
4481
4482 if (operandFTy && exponentIntTy)
4483 return success();
4484
4485 return emitError("member one of the resulting struct type must have the same "
4486 "number of components as the operand type");
4487}
4488
4489//===----------------------------------------------------------------------===//
4490// spirv.GL.Ldexp
4491//===----------------------------------------------------------------------===//
4492
4493LogicalResult spirv::GLLdexpOp::verify() {
4494 Type significandType = getX().getType();
4495 Type exponentType = getExp().getType();
4496
4497 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4498 return emitOpError("operands must both be scalars or vectors");
4499
4500 auto getNumElements = [](Type type) -> unsigned {
4501 if (auto vectorType = type.dyn_cast<VectorType>())
4502 return vectorType.getNumElements();
4503 return 1;
4504 };
4505
4506 if (getNumElements(significandType) != getNumElements(exponentType))
4507 return emitOpError("operands must have the same number of elements");
4508
4509 return success();
4510}
4511
4512//===----------------------------------------------------------------------===//
4513// spirv.ImageDrefGather
4514//===----------------------------------------------------------------------===//
4515
4516LogicalResult spirv::ImageDrefGatherOp::verify() {
4517 VectorType resultType = getResult().getType().cast<VectorType>();
4518 auto sampledImageType =
4519 getSampledimage().getType().cast<spirv::SampledImageType>();
4520 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4521
4522 if (resultType.getNumElements() != 4)
4523 return emitOpError("result type must be a vector of four components");
4524
4525 Type elementType = resultType.getElementType();
4526 Type sampledElementType = imageType.getElementType();
4527 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4528 return emitOpError(
4529 "the component type of result must be the same as sampled type of the "
4530 "underlying image type");
4531
4532 spirv::Dim imageDim = imageType.getDim();
4533 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4534
4535 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4536 imageDim != spirv::Dim::Rect)
4537 return emitOpError(
4538 "the Dim operand of the underlying image type must be 2D, Cube, or "
4539 "Rect");
4540
4541 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4542 return emitOpError("the MS operand of the underlying image type must be 0");
4543
4544 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4545 auto operandArguments = getOperandArguments();
4546
4547 return verifyImageOperands(*this, attr, operandArguments);
4548}
4549
4550//===----------------------------------------------------------------------===//
4551// spirv.ShiftLeftLogicalOp
4552//===----------------------------------------------------------------------===//
4553
4554LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4555 return verifyShiftOp(*this);
4556}
4557
4558//===----------------------------------------------------------------------===//
4559// spirv.ShiftRightArithmeticOp
4560//===----------------------------------------------------------------------===//
4561
4562LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4563 return verifyShiftOp(*this);
4564}
4565
4566//===----------------------------------------------------------------------===//
4567// spirv.ShiftRightLogicalOp
4568//===----------------------------------------------------------------------===//
4569
4570LogicalResult spirv::ShiftRightLogicalOp::verify() {
4571 return verifyShiftOp(*this);
4572}
4573
4574//===----------------------------------------------------------------------===//
4575// spirv.ImageQuerySize
4576//===----------------------------------------------------------------------===//
4577
4578LogicalResult spirv::ImageQuerySizeOp::verify() {
4579 spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
4580 Type resultType = getResult().getType();
4581
4582 spirv::Dim dim = imageType.getDim();
4583 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4584 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4585 switch (dim) {
4586 case spirv::Dim::Dim1D:
4587 case spirv::Dim::Dim2D:
4588 case spirv::Dim::Dim3D:
4589 case spirv::Dim::Cube:
4590 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4591 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4592 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4593 return emitError(
4594 "if Dim is 1D, 2D, 3D, or Cube, "
4595 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4596 break;
4597 case spirv::Dim::Buffer:
4598 case spirv::Dim::Rect:
4599 break;
4600 default:
4601 return emitError("the Dim operand of the image type must "
4602 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4603 }
4604
4605 unsigned componentNumber = 0;
4606 switch (dim) {
4607 case spirv::Dim::Dim1D:
4608 case spirv::Dim::Buffer:
4609 componentNumber = 1;
4610 break;
4611 case spirv::Dim::Dim2D:
4612 case spirv::Dim::Cube:
4613 case spirv::Dim::Rect:
4614 componentNumber = 2;
4615 break;
4616 case spirv::Dim::Dim3D:
4617 componentNumber = 3;
4618 break;
4619 default:
4620 break;
4621 }
4622
4623 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4624 componentNumber += 1;
4625
4626 unsigned resultComponentNumber = 1;
4627 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4628 resultComponentNumber = resultVectorType.getNumElements();
4629
4630 if (componentNumber != resultComponentNumber)
4631 return emitError("expected the result to have ")
4632 << componentNumber << " component(s), but found "
4633 << resultComponentNumber << " component(s)";
4634
4635 return success();
4636}
4637
4638static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4639 OpAsmParser &parser,
4640 OperationState &state) {
4641 OpAsmParser::UnresolvedOperand ptrInfo;
4642 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4643 Type type;
4644 auto loc = parser.getCurrentLocation();
4645 SmallVector<Type, 4> indicesTypes;
4646
4647 if (parser.parseOperand(ptrInfo) ||
4648 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4649 parser.parseColonType(type) ||
4650 parser.resolveOperand(ptrInfo, type, state.operands))
4651 return failure();
4652
4653 // Check that the provided indices list is not empty before parsing their
4654 // type list.
4655 if (indicesInfo.empty())
4656 return emitError(state.location) << opName << " expected element";
4657
4658 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4659 return failure();
4660
4661 // Check that the indices types list is not empty and that it has a one-to-one
4662 // mapping to the provided indices.
4663 if (indicesTypes.size() != indicesInfo.size())
4664 return emitError(state.location)
4665 << opName
4666 << " indices types' count must be equal to indices info count";
4667
4668 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4669 return failure();
4670
4671 auto resultType = getElementPtrType(
4672 type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4673 if (!resultType)
4674 return failure();
4675
4676 state.addTypes(resultType);
4677 return success();
4678}
4679
4680template <typename Op>
4681static auto concatElemAndIndices(Op op) {
4682 SmallVector<Value> ret(op.getIndices().size() + 1);
4683 ret[0] = op.getElement();
4684 llvm::copy(op.getIndices(), ret.begin() + 1);
4685 return ret;
4686}
4687
4688//===----------------------------------------------------------------------===//
4689// spirv.InBoundsPtrAccessChainOp
4690//===----------------------------------------------------------------------===//
4691
4692void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4693 OperationState &state,
4694 Value basePtr, Value element,
4695 ValueRange indices) {
4696 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4697 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4697, __extension__
__PRETTY_FUNCTION__))
;
4698 build(builder, state, type, basePtr, element, indices);
4699}
4700
4701ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4702 OperationState &result) {
4703 return parsePtrAccessChainOpImpl(
4704 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4705}
4706
4707void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4708 printAccessChain(*this, concatElemAndIndices(*this), printer);
4709}
4710
4711LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4712 return verifyAccessChain(*this, getIndices());
4713}
4714
4715//===----------------------------------------------------------------------===//
4716// spirv.PtrAccessChainOp
4717//===----------------------------------------------------------------------===//
4718
4719void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4720 Value basePtr, Value element,
4721 ValueRange indices) {
4722 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4723 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4723, __extension__
__PRETTY_FUNCTION__))
;
4724 build(builder, state, type, basePtr, element, indices);
4725}
4726
4727ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4728 OperationState &result) {
4729 return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4730 parser, result);
4731}
4732
4733void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4734 printAccessChain(*this, concatElemAndIndices(*this), printer);
4735}
4736
4737LogicalResult spirv::PtrAccessChainOp::verify() {
4738 return verifyAccessChain(*this, getIndices());
4739}
4740
4741//===----------------------------------------------------------------------===//
4742// spirv.VectorTimesScalarOp
4743//===----------------------------------------------------------------------===//
4744
4745LogicalResult spirv::VectorTimesScalarOp::verify() {
4746 if (getVector().getType() != getType())
4747 return emitOpError("vector operand and result type mismatch");
4748 auto scalarType = getType().cast<VectorType>().getElementType();
4749 if (getScalar().getType() != scalarType)
4750 return emitOpError("scalar operand and result element type match");
4751 return success();
4752}
4753
4754// TableGen'erated operation interfaces for querying versions, extensions, and
4755// capabilities.
4756#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4757
4758// TablenGen'erated operation definitions.
4759#define GET_OP_CLASSES
4760#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4761
4762namespace mlir {
4763namespace spirv {
4764// TableGen'erated operation availability interface implementations.
4765#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4766} // namespace spirv
4767} // namespace mlir