Bug Summary

File: build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Warning: line 2660, column 18
2nd function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SPIRVOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/SPIRV/IR -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/Dialect/SPIRV/IR -I include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem 
/usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-09-04-125545-48738-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/FunctionImplementation.h"
25#include "mlir/IR/OpDefinition.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/TypeUtilities.h"
28#include "mlir/Interfaces/CallInterfaces.h"
29#include "llvm/ADT/APFloat.h"
30#include "llvm/ADT/APInt.h"
31#include "llvm/ADT/STLExtras.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/bit.h"
34#include <numeric>
35
36using namespace mlir;
37
// TODO: generate these strings using ODS.

// Memory access attributes (target side and source side).
constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
constexpr char kAlignmentAttrName[] = "alignment";
constexpr char kSourceAlignmentAttrName[] = "source_alignment";

// Control flow and call attributes.
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kControl[] = "control";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kInterfaceAttrName[] = "interface";

// Group, scope and memory-semantics attributes.
constexpr char kClusterSize[] = "cluster_size";
constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kSemanticsAttrName[] = "semantics";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";

// Constant, variable and composite attributes.
constexpr char kDefaultValueAttrName[] = "default_value";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kValueAttrName[] = "value";
constexpr char kValuesAttrName[] = "values";
constexpr char kCompositeSpecConstituentsName[] = "constituents";
63
64//===----------------------------------------------------------------------===//
65// Common utility functions
66//===----------------------------------------------------------------------===//
67
68static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
69 OperationState &result) {
70 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
71 Type type;
72 // If the operand list is in-between parentheses, then we have a generic form.
73 // (see the fallback in `printOneResultOp`).
74 SMLoc loc = parser.getCurrentLocation();
75 if (!parser.parseOptionalLParen()) {
76 if (parser.parseOperandList(ops) || parser.parseRParen() ||
77 parser.parseOptionalAttrDict(result.attributes) ||
78 parser.parseColon() || parser.parseType(type))
79 return failure();
80 auto fnType = type.dyn_cast<FunctionType>();
81 if (!fnType) {
82 parser.emitError(loc, "expected function type");
83 return failure();
84 }
85 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
86 return failure();
87 result.addTypes(fnType.getResults());
88 return success();
89 }
90 return failure(parser.parseOperandList(ops) ||
91 parser.parseOptionalAttrDict(result.attributes) ||
92 parser.parseColonType(type) ||
93 parser.resolveOperands(ops, type, result.operands) ||
94 parser.addTypeToList(type, result.types));
95}
96
97static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
98 assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 &&
"op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 98, __extension__
__PRETTY_FUNCTION__))
;
99
100 // If not all the operand and result types are the same, just use the
101 // generic assembly form to avoid omitting information in printing.
102 auto resultType = op->getResult(0).getType();
103 if (llvm::any_of(op->getOperandTypes(),
104 [&](Type type) { return type != resultType; })) {
105 p.printGenericOp(op, /*printOpName=*/false);
106 return;
107 }
108
109 p << ' ';
110 p.printOperands(op->getOperands());
111 p.printOptionalAttrDict(op->getAttrs());
112 // Now we can output only one type for all operands and the result.
113 p << " : " << resultType;
114}
115
116/// Returns true if the given op is a function-like op or nested in a
117/// function-like op without a module-like op in the middle.
118static bool isNestedInFunctionOpInterface(Operation *op) {
119 if (!op)
120 return false;
121 if (op->hasTrait<OpTrait::SymbolTable>())
122 return false;
123 if (isa<FunctionOpInterface>(op))
124 return true;
125 return isNestedInFunctionOpInterface(op->getParentOp());
126}
127
128/// Returns true if the given op is an module-like op that maintains a symbol
129/// table.
130static bool isDirectInModuleLikeOp(Operation *op) {
131 return op && op->hasTrait<OpTrait::SymbolTable>();
132}
133
134static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
135 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
136 if (!constOp) {
137 return failure();
138 }
139 auto valueAttr = constOp.value();
140 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
141 if (!integerValueAttr) {
142 return failure();
143 }
144
145 if (integerValueAttr.getType().isSignlessInteger())
146 value = integerValueAttr.getInt();
147 else
148 value = integerValueAttr.getSInt();
149
150 return success();
151}
152
153template <typename Ty>
154static ArrayAttr
155getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
156 function_ref<StringRef(Ty)> stringifyFn) {
157 if (enumValues.empty()) {
158 return nullptr;
159 }
160 SmallVector<StringRef, 1> enumValStrs;
161 enumValStrs.reserve(enumValues.size());
162 for (auto val : enumValues) {
163 enumValStrs.emplace_back(stringifyFn(val));
164 }
165 return builder.getStrArrayAttr(enumValStrs);
166}
167
168/// Parses the next string attribute in `parser` as an enumerant of the given
169/// `EnumClass`.
170template <typename EnumClass>
171static ParseResult
172parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
173 StringRef attrName = spirv::attributeName<EnumClass>()) {
174 Attribute attrVal;
175 NamedAttrList attr;
176 auto loc = parser.getCurrentLocation();
177 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
3
Taking false branch
178 attrName, attr))
179 return failure();
180 if (!attrVal.isa<StringAttr>())
4
Assuming the condition is true
5
Taking true branch
181 return parser.emitError(loc, "expected ")
6
Returning without writing to 'value'
182 << attrName << " attribute specified as string";
183 auto attrOptional =
184 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
185 if (!attrOptional)
186 return parser.emitError(loc, "invalid ")
187 << attrName << " attribute specification: " << attrVal;
188 value = *attrOptional;
189 return success();
190}
191
192/// Parses the next string attribute in `parser` as an enumerant of the given
193/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
194/// attribute with the enum class's name as attribute name.
195template <typename EnumAttrClass,
196 typename EnumClass = typename EnumAttrClass::ValueType>
197static ParseResult
198parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
199 StringRef attrName = spirv::attributeName<EnumClass>()) {
200 if (parseEnumStrAttr(value, parser))
201 return failure();
202 state.addAttribute(attrName,
203 parser.getBuilder().getAttr<EnumAttrClass>(value));
204 return success();
205}
206
207/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
208/// and inserts the enumerant into `state` as an 32-bit integer attribute with
209/// the enum class's name as attribute name.
210template <typename EnumAttrClass,
211 typename EnumClass = typename EnumAttrClass::ValueType>
212static ParseResult
213parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
214 OperationState &state,
215 StringRef attrName = spirv::attributeName<EnumClass>()) {
216 if (parseEnumKeywordAttr(value, parser))
217 return failure();
218 state.addAttribute(attrName,
219 parser.getBuilder().getAttr<EnumAttrClass>(value));
220 return success();
221}
222
223/// Parses Function, Selection and Loop control attributes. If no control is
224/// specified, "None" is used as a default.
225template <typename EnumAttrClass, typename EnumClass>
226static ParseResult
227parseControlAttribute(OpAsmParser &parser, OperationState &state,
228 StringRef attrName = spirv::attributeName<EnumClass>()) {
229 if (succeeded(parser.parseOptionalKeyword(kControl))) {
230 EnumClass control;
231 if (parser.parseLParen() ||
232 parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
233 parser.parseRParen())
234 return failure();
235 return success();
236 }
237 // Set control to "None" otherwise.
238 Builder builder = parser.getBuilder();
239 state.addAttribute(attrName,
240 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
241 return success();
242}
243
244/// Parses optional memory access attributes attached to a memory access
245/// operand/pointer. Specifically, parses the following syntax:
246/// (`[` memory-access `]`)?
247/// where:
248/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
249/// integer-literal | `"NonTemporal"`
250static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
251 OperationState &state) {
252 // Parse an optional list of attributes staring with '['
253 if (parser.parseOptionalLSquare()) {
254 // Nothing to do
255 return success();
256 }
257
258 spirv::MemoryAccess memoryAccessAttr;
259 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
260 kMemoryAccessAttrName))
261 return failure();
262
263 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
264 // Parse integer attribute for alignment.
265 Attribute alignmentAttr;
266 Type i32Type = parser.getBuilder().getIntegerType(32);
267 if (parser.parseComma() ||
268 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
269 state.attributes)) {
270 return failure();
271 }
272 }
273 return parser.parseRSquare();
274}
275
276// TODO Make sure to merge this and the previous function into one template
277// parameterized by memory access attribute name and alignment. Doing so now
278// results in VS2017 in producing an internal error (at the call site) that's
279// not detailed enough to understand what is happening.
280static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
281 OperationState &state) {
282 // Parse an optional list of attributes staring with '['
283 if (parser.parseOptionalLSquare()) {
284 // Nothing to do
285 return success();
286 }
287
288 spirv::MemoryAccess memoryAccessAttr;
289 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
290 kSourceMemoryAccessAttrName))
291 return failure();
292
293 if (spirv::bitEnumContains(memoryAccessAttr, spirv::MemoryAccess::Aligned)) {
294 // Parse integer attribute for alignment.
295 Attribute alignmentAttr;
296 Type i32Type = parser.getBuilder().getIntegerType(32);
297 if (parser.parseComma() ||
298 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
299 state.attributes)) {
300 return failure();
301 }
302 }
303 return parser.parseRSquare();
304}
305
306template <typename MemoryOpTy>
307static void printMemoryAccessAttribute(
308 MemoryOpTy memoryOp, OpAsmPrinter &printer,
309 SmallVectorImpl<StringRef> &elidedAttrs,
310 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
311 Optional<uint32_t> alignmentAttrValue = None) {
312 // Print optional memory access attribute.
313 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
314 : memoryOp.memory_access())) {
315 elidedAttrs.push_back(kMemoryAccessAttrName);
316
317 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
318
319 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
320 // Print integer alignment attribute.
321 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
322 : memoryOp.alignment())) {
323 elidedAttrs.push_back(kAlignmentAttrName);
324 printer << ", " << alignment;
325 }
326 }
327 printer << "]";
328 }
329 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
330}
331
332// TODO Make sure to merge this and the previous function into one template
333// parameterized by memory access attribute name and alignment. Doing so now
334// results in VS2017 in producing an internal error (at the call site) that's
335// not detailed enough to understand what is happening.
336template <typename MemoryOpTy>
337static void printSourceMemoryAccessAttribute(
338 MemoryOpTy memoryOp, OpAsmPrinter &printer,
339 SmallVectorImpl<StringRef> &elidedAttrs,
340 Optional<spirv::MemoryAccess> memoryAccessAtrrValue = None,
341 Optional<uint32_t> alignmentAttrValue = None) {
342
343 printer << ", ";
344
345 // Print optional memory access attribute.
346 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
347 : memoryOp.memory_access())) {
348 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
349
350 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
351
352 if (spirv::bitEnumContains(*memAccess, spirv::MemoryAccess::Aligned)) {
353 // Print integer alignment attribute.
354 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
355 : memoryOp.alignment())) {
356 elidedAttrs.push_back(kSourceAlignmentAttrName);
357 printer << ", " << alignment;
358 }
359 }
360 printer << "]";
361 }
362 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
363}
364
365static ParseResult parseImageOperands(OpAsmParser &parser,
366 spirv::ImageOperandsAttr &attr) {
367 // Expect image operands
368 if (parser.parseOptionalLSquare())
369 return success();
370
371 spirv::ImageOperands imageOperands;
372 if (parseEnumStrAttr(imageOperands, parser))
373 return failure();
374
375 attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
376
377 return parser.parseRSquare();
378}
379
380static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
381 spirv::ImageOperandsAttr attr) {
382 if (attr) {
383 auto strImageOperands = stringifyImageOperands(attr.getValue());
384 printer << "[\"" << strImageOperands << "\"]";
385 }
386}
387
388template <typename Op>
389static LogicalResult verifyImageOperands(Op imageOp,
390 spirv::ImageOperandsAttr attr,
391 Operation::operand_range operands) {
392 if (!attr) {
393 if (operands.empty())
394 return success();
395
396 return imageOp.emitError("the Image Operands should encode what operands "
397 "follow, as per Image Operands");
398 }
399
400 // TODO: Add the validation rules for the following Image Operands.
401 spirv::ImageOperands noSupportOperands =
402 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
403 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
404 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
405 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
406 spirv::ImageOperands::MakeTexelAvailable |
407 spirv::ImageOperands::MakeTexelVisible |
408 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
409
410 if (spirv::bitEnumContains(attr.getValue(), noSupportOperands))
411 llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 411)
;
412
413 return success();
414}
415
416static LogicalResult verifyCastOp(Operation *op,
417 bool requireSameBitWidth = true,
418 bool skipBitWidthCheck = false) {
419 // Some CastOps have no limit on bit widths for result and operand type.
420 if (skipBitWidthCheck)
421 return success();
422
423 Type operandType = op->getOperand(0).getType();
424 Type resultType = op->getResult(0).getType();
425
426 // ODS checks that result type and operand type have the same shape.
427 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
428 operandType = vectorType.getElementType();
429 resultType = resultType.cast<VectorType>().getElementType();
430 }
431
432 if (auto coopMatrixType =
433 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
434 operandType = coopMatrixType.getElementType();
435 resultType =
436 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
437 }
438
439 if (auto jointMatrixType =
440 operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
441 operandType = jointMatrixType.getElementType();
442 resultType =
443 resultType.cast<spirv::JointMatrixINTELType>().getElementType();
444 }
445
446 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
447 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
448 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
449
450 if (requireSameBitWidth) {
451 if (!isSameBitWidth) {
452 return op->emitOpError(
453 "expected the same bit widths for operand type and result "
454 "type, but provided ")
455 << operandType << " and " << resultType;
456 }
457 return success();
458 }
459
460 if (isSameBitWidth) {
461 return op->emitOpError(
462 "expected the different bit widths for operand type and result "
463 "type, but provided ")
464 << operandType << " and " << resultType;
465 }
466 return success();
467}
468
469template <typename MemoryOpTy>
470static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
471 // ODS checks for attributes values. Just need to verify that if the
472 // memory-access attribute is Aligned, then the alignment attribute must be
473 // present.
474 auto *op = memoryOp.getOperation();
475 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
476 if (!memAccessAttr) {
477 // Alignment attribute shouldn't be present if memory access attribute is
478 // not present.
479 if (op->getAttr(kAlignmentAttrName)) {
480 return memoryOp.emitOpError(
481 "invalid alignment specification without aligned memory access "
482 "specification");
483 }
484 return success();
485 }
486
487 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
488
489 if (!memAccess) {
490 return memoryOp.emitOpError("invalid memory access specifier: ")
491 << memAccessAttr;
492 }
493
494 if (spirv::bitEnumContains(memAccess.getValue(),
495 spirv::MemoryAccess::Aligned)) {
496 if (!op->getAttr(kAlignmentAttrName)) {
497 return memoryOp.emitOpError("missing alignment value");
498 }
499 } else {
500 if (op->getAttr(kAlignmentAttrName)) {
501 return memoryOp.emitOpError(
502 "invalid alignment specification with non-aligned memory access "
503 "specification");
504 }
505 }
506 return success();
507}
508
509// TODO Make sure to merge this and the previous function into one template
510// parameterized by memory access attribute name and alignment. Doing so now
511// results in VS2017 in producing an internal error (at the call site) that's
512// not detailed enough to understand what is happening.
513template <typename MemoryOpTy>
514static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
515 // ODS checks for attributes values. Just need to verify that if the
516 // memory-access attribute is Aligned, then the alignment attribute must be
517 // present.
518 auto *op = memoryOp.getOperation();
519 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
520 if (!memAccessAttr) {
521 // Alignment attribute shouldn't be present if memory access attribute is
522 // not present.
523 if (op->getAttr(kSourceAlignmentAttrName)) {
524 return memoryOp.emitOpError(
525 "invalid alignment specification without aligned memory access "
526 "specification");
527 }
528 return success();
529 }
530
531 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
532
533 if (!memAccess) {
534 return memoryOp.emitOpError("invalid memory access specifier: ")
535 << memAccess;
536 }
537
538 if (spirv::bitEnumContains(memAccess.getValue(),
539 spirv::MemoryAccess::Aligned)) {
540 if (!op->getAttr(kSourceAlignmentAttrName)) {
541 return memoryOp.emitOpError("missing alignment value");
542 }
543 } else {
544 if (op->getAttr(kSourceAlignmentAttrName)) {
545 return memoryOp.emitOpError(
546 "invalid alignment specification with non-aligned memory access "
547 "specification");
548 }
549 }
550 return success();
551}
552
553static LogicalResult
554verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
555 // According to the SPIR-V specification:
556 // "Despite being a mask and allowing multiple bits to be combined, it is
557 // invalid for more than one of these four bits to be set: Acquire, Release,
558 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
559 // Release semantics is done by setting the AcquireRelease bit, not by setting
560 // two bits."
561 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
562 spirv::MemorySemantics::Release |
563 spirv::MemorySemantics::AcquireRelease |
564 spirv::MemorySemantics::SequentiallyConsistent;
565
566 auto bitCount = llvm::countPopulation(
567 static_cast<uint32_t>(memorySemantics & atMostOneInSet));
568 if (bitCount > 1) {
569 return op->emitError(
570 "expected at most one of these four memory constraints "
571 "to be set: `Acquire`, `Release`,"
572 "`AcquireRelease` or `SequentiallyConsistent`");
573 }
574 return success();
575}
576
577template <typename LoadStoreOpTy>
578static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
579 Value val) {
580 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
581 // type of the pointer and the type of the value are the same
582 //
583 // TODO: Check that the value type satisfies restrictions of
584 // SPIR-V OpLoad/OpStore operations
585 if (val.getType() !=
586 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
587 return op.emitOpError("mismatch in result type and pointer type");
588 }
589 return success();
590}
591
592template <typename BlockReadWriteOpTy>
593static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
594 Value ptr, Value val) {
595 auto valType = val.getType();
596 if (auto valVecTy = valType.dyn_cast<VectorType>())
597 valType = valVecTy.getElementType();
598
599 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
600 return op.emitOpError("mismatch in result type and pointer type");
601 }
602 return success();
603}
604
605static ParseResult parseVariableDecorations(OpAsmParser &parser,
606 OperationState &state) {
607 auto builtInName = llvm::convertToSnakeFromCamelCase(
608 stringifyDecoration(spirv::Decoration::BuiltIn));
609 if (succeeded(parser.parseOptionalKeyword("bind"))) {
610 Attribute set, binding;
611 // Parse optional descriptor binding
612 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
613 stringifyDecoration(spirv::Decoration::DescriptorSet));
614 auto bindingName = llvm::convertToSnakeFromCamelCase(
615 stringifyDecoration(spirv::Decoration::Binding));
616 Type i32Type = parser.getBuilder().getIntegerType(32);
617 if (parser.parseLParen() ||
618 parser.parseAttribute(set, i32Type, descriptorSetName,
619 state.attributes) ||
620 parser.parseComma() ||
621 parser.parseAttribute(binding, i32Type, bindingName,
622 state.attributes) ||
623 parser.parseRParen()) {
624 return failure();
625 }
626 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
627 StringAttr builtIn;
628 if (parser.parseLParen() ||
629 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
630 parser.parseRParen()) {
631 return failure();
632 }
633 }
634
635 // Parse other attributes
636 if (parser.parseOptionalAttrDict(state.attributes))
637 return failure();
638
639 return success();
640}
641
642static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
643 SmallVectorImpl<StringRef> &elidedAttrs) {
644 // Print optional descriptor binding
645 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
646 stringifyDecoration(spirv::Decoration::DescriptorSet));
647 auto bindingName = llvm::convertToSnakeFromCamelCase(
648 stringifyDecoration(spirv::Decoration::Binding));
649 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
650 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
651 if (descriptorSet && binding) {
652 elidedAttrs.push_back(descriptorSetName);
653 elidedAttrs.push_back(bindingName);
654 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
655 << ")";
656 }
657
658 // Print BuiltIn attribute if present
659 auto builtInName = llvm::convertToSnakeFromCamelCase(
660 stringifyDecoration(spirv::Decoration::BuiltIn));
661 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
662 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
663 elidedAttrs.push_back(builtInName);
664 }
665
666 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
667}
668
669// Get bit width of types.
670static unsigned getBitWidth(Type type) {
671 if (type.isa<spirv::PointerType>()) {
672 // Just return 64 bits for pointer types for now.
673 // TODO: Make sure not caller relies on the actual pointer width value.
674 return 64;
675 }
676
677 if (type.isIntOrFloat())
678 return type.getIntOrFloatBitWidth();
679
680 if (auto vectorType = type.dyn_cast<VectorType>()) {
681 assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 681, __extension__
__PRETTY_FUNCTION__))
;
682 return vectorType.getNumElements() *
683 vectorType.getElementType().getIntOrFloatBitWidth();
684 }
685 llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 685)
;
686}
687
688/// Walks the given type hierarchy with the given indices, potentially down
689/// to component granularity, to select an element type. Returns null type and
690/// emits errors with the given loc on failure.
691static Type
692getElementType(Type type, ArrayRef<int32_t> indices,
693 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
694 if (indices.empty()) {
695 emitErrorFn("expected at least one index for spv.CompositeExtract");
696 return nullptr;
697 }
698
699 for (auto index : indices) {
700 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
701 if (cType.hasCompileTimeKnownNumElements() &&
702 (index < 0 ||
703 static_cast<uint64_t>(index) >= cType.getNumElements())) {
704 emitErrorFn("index ") << index << " out of bounds for " << type;
705 return nullptr;
706 }
707 type = cType.getElementType(index);
708 } else {
709 emitErrorFn("cannot extract from non-composite type ")
710 << type << " with index " << index;
711 return nullptr;
712 }
713 }
714 return type;
715}
716
717static Type
718getElementType(Type type, Attribute indices,
719 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
720 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
721 if (!indicesArrayAttr) {
722 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
723 return nullptr;
724 }
725 if (indicesArrayAttr.empty()) {
726 emitErrorFn("expected at least one index for spv.CompositeExtract");
727 return nullptr;
728 }
729
730 SmallVector<int32_t, 2> indexVals;
731 for (auto indexAttr : indicesArrayAttr) {
732 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
733 if (!indexIntAttr) {
734 emitErrorFn("expected an 32-bit integer for index, but found '")
735 << indexAttr << "'";
736 return nullptr;
737 }
738 indexVals.push_back(indexIntAttr.getInt());
739 }
740 return getElementType(type, indexVals, emitErrorFn);
741}
742
743static Type getElementType(Type type, Attribute indices, Location loc) {
744 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
745 return ::mlir::emitError(loc, err);
746 };
747 return getElementType(type, indices, errorFn);
748}
749
750static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
751 SMLoc loc) {
752 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
753 return parser.emitError(loc, err);
754 };
755 return getElementType(type, indices, errorFn);
756}
757
758/// Returns true if the given `block` only contains one `spv.mlir.merge` op.
759static inline bool isMergeBlock(Block &block) {
760 return !block.empty() && std::next(block.begin()) == block.end() &&
761 isa<spirv::MergeOp>(block.front());
762}
763
764//===----------------------------------------------------------------------===//
765// Common parsers and printers
766//===----------------------------------------------------------------------===//
767
768// Parses an atomic update op. If the update op does not take a value (like
769// AtomicIIncrement) `hasValue` must be false.
770static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
771 OperationState &state, bool hasValue) {
772 spirv::Scope scope;
773 spirv::MemorySemantics memoryScope;
774 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
775 OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
776 Type type;
777 SMLoc loc;
778 if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
779 kMemoryScopeAttrName) ||
780 parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
781 kSemanticsAttrName) ||
782 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
783 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
784 return failure();
785
786 auto ptrType = type.dyn_cast<spirv::PointerType>();
787 if (!ptrType)
788 return parser.emitError(loc, "expected pointer type");
789
790 SmallVector<Type, 2> operandTypes;
791 operandTypes.push_back(ptrType);
792 if (hasValue)
793 operandTypes.push_back(ptrType.getPointeeType());
794 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
795 state.operands))
796 return failure();
797 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
798}
799
800// Prints an atomic update op.
801static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
802 printer << " \"";
803 auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
804 printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
805 auto memorySemanticsAttr =
806 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
807 printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
808 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
809}
810
811template <typename T>
812static StringRef stringifyTypeName();
813
814template <>
815StringRef stringifyTypeName<IntegerType>() {
816 return "integer";
817}
818
819template <>
820StringRef stringifyTypeName<FloatType>() {
821 return "float";
822}
823
824// Verifies an atomic update op.
825template <typename ExpectedElementType>
826static LogicalResult verifyAtomicUpdateOp(Operation *op) {
827 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
828 auto elementType = ptrType.getPointeeType();
829 if (!elementType.isa<ExpectedElementType>())
830 return op->emitOpError() << "pointer operand must point to an "
831 << stringifyTypeName<ExpectedElementType>()
832 << " value, found " << elementType;
833
834 if (op->getNumOperands() > 1) {
835 auto valueType = op->getOperand(1).getType();
836 if (valueType != elementType)
837 return op->emitOpError("expected value to have the same type as the "
838 "pointer operand's pointee type ")
839 << elementType << ", but found " << valueType;
840 }
841 auto memorySemantics =
842 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
843 .getValue();
844 if (failed(verifyMemorySemantics(op, memorySemantics))) {
845 return failure();
846 }
847 return success();
848}
849
850static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
851 OperationState &state) {
852 spirv::Scope executionScope;
853 spirv::GroupOperation groupOperation;
854 OpAsmParser::UnresolvedOperand valueInfo;
855 if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
856 kExecutionScopeAttrName) ||
857 parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
858 kGroupOperationAttrName) ||
859 parser.parseOperand(valueInfo))
860 return failure();
861
862 Optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
863 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
864 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
865 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
866 parser.parseRParen())
867 return failure();
868 }
869
870 Type resultType;
871 if (parser.parseColonType(resultType))
872 return failure();
873
874 if (parser.resolveOperand(valueInfo, resultType, state.operands))
875 return failure();
876
877 if (clusterSizeInfo) {
878 Type i32Type = parser.getBuilder().getIntegerType(32);
879 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
880 return failure();
881 }
882
883 return parser.addTypeToList(resultType, state.types);
884}
885
886static void printGroupNonUniformArithmeticOp(Operation *groupOp,
887 OpAsmPrinter &printer) {
888 printer
889 << " \""
890 << stringifyScope(
891 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
892 .getValue())
893 << "\" \""
894 << stringifyGroupOperation(groupOp
895 ->getAttrOfType<spirv::GroupOperationAttr>(
896 kGroupOperationAttrName)
897 .getValue())
898 << "\" " << groupOp->getOperand(0);
899
900 if (groupOp->getNumOperands() > 1)
901 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
902 printer << " : " << groupOp->getResult(0).getType();
903}
904
905static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
906 spirv::Scope scope =
907 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
908 .getValue();
909 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
910 return groupOp->emitOpError(
911 "execution scope must be 'Workgroup' or 'Subgroup'");
912
913 spirv::GroupOperation operation =
914 groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
915 .getValue();
916 if (operation == spirv::GroupOperation::ClusteredReduce &&
917 groupOp->getNumOperands() == 1)
918 return groupOp->emitOpError("cluster size operand must be provided for "
919 "'ClusteredReduce' group operation");
920 if (groupOp->getNumOperands() > 1) {
921 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
922 int32_t clusterSize = 0;
923
924 // TODO: support specialization constant here.
925 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
926 return groupOp->emitOpError(
927 "cluster size operand must come from a constant op");
928
929 if (!llvm::isPowerOf2_32(clusterSize))
930 return groupOp->emitOpError(
931 "cluster size operand must be a power of two");
932 }
933 return success();
934}
935
936/// Result of a logical op must be a scalar or vector of boolean type.
937static Type getUnaryOpResultType(Type operandType) {
938 Builder builder(operandType.getContext());
939 Type resultType = builder.getIntegerType(1);
940 if (auto vecType = operandType.dyn_cast<VectorType>())
941 return VectorType::get(vecType.getNumElements(), resultType);
942 return resultType;
943}
944
945static LogicalResult verifyShiftOp(Operation *op) {
946 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
947 return op->emitError("expected the same type for the first operand and "
948 "result, but provided ")
949 << op->getOperand(0).getType() << " and "
950 << op->getResult(0).getType();
951 }
952 return success();
953}
954
955static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state,
956 Value lhs, Value rhs) {
957 assert(lhs.getType() == rhs.getType())(static_cast <bool> (lhs.getType() == rhs.getType()) ? void
(0) : __assert_fail ("lhs.getType() == rhs.getType()", "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp"
, 957, __extension__ __PRETTY_FUNCTION__))
;
958
959 Type boolType = builder.getI1Type();
960 if (auto vecType = lhs.getType().dyn_cast<VectorType>())
961 boolType = VectorType::get(vecType.getShape(), boolType);
962 state.addTypes(boolType);
963
964 state.addOperands({lhs, rhs});
965}
966
967static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state,
968 Value value) {
969 Type boolType = builder.getI1Type();
970 if (auto vecType = value.getType().dyn_cast<VectorType>())
971 boolType = VectorType::get(vecType.getShape(), boolType);
972 state.addTypes(boolType);
973
974 state.addOperands(value);
975}
976
977//===----------------------------------------------------------------------===//
978// spv.AccessChainOp
979//===----------------------------------------------------------------------===//
980
/// Computes the result type of a `spv.AccessChain` whose base pointer has type
/// `type` and which is indexed by `indices`.
///
/// Returns a `spirv::PointerType` to the selected element, in the same storage
/// class as the base pointer. On failure, emits an error at `baseLoc` and
/// returns a null type. Failure cases:
///  - `type` is not a `spirv::PointerType`;
///  - an index is applied to a non-composite type;
///  - a struct is indexed by a value that has no defining op, is not accepted
///    by `extractValueFromConstOp`, or is out of bounds.
/// Index values into non-struct composites are not inspected: `index` stays 0
/// for them, so `getElementType(0)` is used. NOTE(review): this presumably
/// relies on such composites having a uniform element type; confirm.
static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
  auto ptrType = type.dyn_cast<spirv::PointerType>();
  if (!ptrType) {
    emitError(baseLoc, "'spv.AccessChain' op expected a pointer "
                       "to composite type, but provided ")
        << type;
    return nullptr;
  }

  auto resultType = ptrType.getPointeeType();
  auto resultStorageClass = ptrType.getStorageClass();
  int32_t index = 0;

  for (auto indexSSA : indices) {
    auto cType = resultType.dyn_cast<spirv::CompositeType>();
    if (!cType) {
      // NOTE(review): `index` still holds the value from the previous
      // iteration (0 on the first one), not the index being applied now.
      emitError(baseLoc,
                "'spv.AccessChain' op cannot extract from non-composite type ")
          << resultType << " with index " << index;
      return nullptr;
    }
    index = 0;
    if (resultType.isa<spirv::StructType>()) {
      // Struct members are selected by a constant index. A value with no
      // defining op (e.g. a block argument) is rejected.
      Operation *op = indexSSA.getDefiningOp();
      if (!op) {
        emitError(baseLoc, "'spv.AccessChain' op index must be an "
                           "integer spv.Constant to access "
                           "element of spv.struct");
        return nullptr;
      }

      // TODO: this should be relaxed to allow
      // integer literals of other bitwidths.
      if (failed(extractValueFromConstOp(op, index))) {
        emitError(baseLoc,
                  "'spv.AccessChain' index must be an integer spv.Constant to "
                  "access element of spv.struct, but provided ")
            << op->getName();
        return nullptr;
      }
      if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
        emitError(baseLoc, "'spv.AccessChain' op index ")
            << index << " out of bounds for " << resultType;
        return nullptr;
      }
    }
    resultType = cType.getElementType(index);
  }
  return spirv::PointerType::get(resultType, resultStorageClass);
}
1031
1032void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1033 Value basePtr, ValueRange indices) {
1034 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1035 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1035, __extension__
__PRETTY_FUNCTION__))
;
1036 build(builder, state, type, basePtr, indices);
1037}
1038
1039ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1040 OperationState &result) {
1041 OpAsmParser::UnresolvedOperand ptrInfo;
1042 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
1043 Type type;
1044 auto loc = parser.getCurrentLocation();
1045 SmallVector<Type, 4> indicesTypes;
1046
1047 if (parser.parseOperand(ptrInfo) ||
1048 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1049 parser.parseColonType(type) ||
1050 parser.resolveOperand(ptrInfo, type, result.operands)) {
1051 return failure();
1052 }
1053
1054 // Check that the provided indices list is not empty before parsing their
1055 // type list.
1056 if (indicesInfo.empty()) {
1057 return mlir::emitError(result.location, "'spv.AccessChain' op expected at "
1058 "least one index ");
1059 }
1060
1061 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1062 return failure();
1063
1064 // Check that the indices types list is not empty and that it has a one-to-one
1065 // mapping to the provided indices.
1066 if (indicesTypes.size() != indicesInfo.size()) {
1067 return mlir::emitError(result.location,
1068 "'spv.AccessChain' op indices types' count must be "
1069 "equal to indices info count");
1070 }
1071
1072 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1073 return failure();
1074
1075 auto resultType = getElementPtrType(
1076 type, llvm::makeArrayRef(result.operands).drop_front(), result.location);
1077 if (!resultType) {
1078 return failure();
1079 }
1080
1081 result.addTypes(resultType);
1082 return success();
1083}
1084
1085template <typename Op>
1086static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1087 printer << ' ' << op.base_ptr() << '[' << indices
1088 << "] : " << op.base_ptr().getType() << ", " << indices.getTypes();
1089}
1090
1091void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1092 printAccessChain(*this, indices(), printer);
1093}
1094
1095template <typename Op>
1096static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1097 auto resultType = getElementPtrType(accessChainOp.base_ptr().getType(),
1098 indices, accessChainOp.getLoc());
1099 if (!resultType)
1100 return failure();
1101
1102 auto providedResultType =
1103 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1104 if (!providedResultType)
1105 return accessChainOp.emitOpError(
1106 "result type must be a pointer, but provided")
1107 << providedResultType;
1108
1109 if (resultType != providedResultType)
1110 return accessChainOp.emitOpError("invalid result type: expected ")
1111 << resultType << ", but provided " << providedResultType;
1112
1113 return success();
1114}
1115
1116LogicalResult spirv::AccessChainOp::verify() {
1117 return verifyAccessChain(*this, indices());
1118}
1119
1120//===----------------------------------------------------------------------===//
1121// spv.mlir.addressof
1122//===----------------------------------------------------------------------===//
1123
1124void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1125 spirv::GlobalVariableOp var) {
1126 build(builder, state, var.type(), SymbolRefAttr::get(var));
1127}
1128
1129LogicalResult spirv::AddressOfOp::verify() {
1130 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1131 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1132 variableAttr()));
1133 if (!varOp) {
1134 return emitOpError("expected spv.GlobalVariable symbol");
1135 }
1136 if (pointer().getType() != varOp.type()) {
1137 return emitOpError(
1138 "result type mismatch with the referenced global variable's type");
1139 }
1140 return success();
1141}
1142
1143template <typename T>
1144static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1145 printer << " \"" << stringifyScope(atomOp.memory_scope()) << "\" \""
1146 << stringifyMemorySemantics(atomOp.equal_semantics()) << "\" \""
1147 << stringifyMemorySemantics(atomOp.unequal_semantics()) << "\" "
1148 << atomOp.getOperands() << " : " << atomOp.pointer().getType();
1149}
1150
1151static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1152 OperationState &state) {
1153 spirv::Scope memoryScope;
1154 spirv::MemorySemantics equalSemantics, unequalSemantics;
1155 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1156 Type type;
1157 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1158 kMemoryScopeAttrName) ||
1159 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1160 equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1161 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1162 unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1163 parser.parseOperandList(operandInfo, 3))
1164 return failure();
1165
1166 auto loc = parser.getCurrentLocation();
1167 if (parser.parseColonType(type))
1168 return failure();
1169
1170 auto ptrType = type.dyn_cast<spirv::PointerType>();
1171 if (!ptrType)
1172 return parser.emitError(loc, "expected pointer type");
1173
1174 if (parser.resolveOperands(
1175 operandInfo,
1176 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1177 parser.getNameLoc(), state.operands))
1178 return failure();
1179
1180 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1181}
1182
1183template <typename T>
1184static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1185 // According to the spec:
1186 // "The type of Value must be the same as Result Type. The type of the value
1187 // pointed to by Pointer must be the same as Result Type. This type must also
1188 // match the type of Comparator."
1189 if (atomOp.getType() != atomOp.value().getType())
1190 return atomOp.emitOpError("value operand must have the same type as the op "
1191 "result, but found ")
1192 << atomOp.value().getType() << " vs " << atomOp.getType();
1193
1194 if (atomOp.getType() != atomOp.comparator().getType())
1195 return atomOp.emitOpError(
1196 "comparator operand must have the same type as the op "
1197 "result, but found ")
1198 << atomOp.comparator().getType() << " vs " << atomOp.getType();
1199
1200 Type pointeeType = atomOp.pointer()
1201 .getType()
1202 .template cast<spirv::PointerType>()
1203 .getPointeeType();
1204 if (atomOp.getType() != pointeeType)
1205 return atomOp.emitOpError(
1206 "pointer operand's pointee type must have the same "
1207 "as the op result type, but found ")
1208 << pointeeType << " vs " << atomOp.getType();
1209
1210 // TODO: Unequal cannot be set to Release or Acquire and Release.
1211 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1212
1213 return success();
1214}
1215
1216//===----------------------------------------------------------------------===//
1217// spv.AtomicAndOp
1218//===----------------------------------------------------------------------===//
1219
1220LogicalResult spirv::AtomicAndOp::verify() {
1221 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1222}
1223
1224ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1225 OperationState &result) {
1226 return ::parseAtomicUpdateOp(parser, result, true);
1227}
1228void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1229 ::printAtomicUpdateOp(*this, p);
1230}
1231
1232//===----------------------------------------------------------------------===//
1233// spv.AtomicCompareExchangeOp
1234//===----------------------------------------------------------------------===//
1235
1236LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1237 return ::verifyAtomicCompareExchangeImpl(*this);
1238}
1239
1240ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1241 OperationState &result) {
1242 return ::parseAtomicCompareExchangeImpl(parser, result);
1243}
1244void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1245 ::printAtomicCompareExchangeImpl(*this, p);
1246}
1247
1248//===----------------------------------------------------------------------===//
1249// spv.AtomicCompareExchangeWeakOp
1250//===----------------------------------------------------------------------===//
1251
1252LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1253 return ::verifyAtomicCompareExchangeImpl(*this);
1254}
1255
1256ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1257 OperationState &result) {
1258 return ::parseAtomicCompareExchangeImpl(parser, result);
1259}
1260void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1261 ::printAtomicCompareExchangeImpl(*this, p);
1262}
1263
1264//===----------------------------------------------------------------------===//
1265// spv.AtomicExchange
1266//===----------------------------------------------------------------------===//
1267
1268void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1269 printer << " \"" << stringifyScope(memory_scope()) << "\" \""
1270 << stringifyMemorySemantics(semantics()) << "\" " << getOperands()
1271 << " : " << pointer().getType();
1272}
1273
1274ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1275 OperationState &result) {
1276 spirv::Scope memoryScope;
1277 spirv::MemorySemantics semantics;
1278 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1279 Type type;
1280 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1281 kMemoryScopeAttrName) ||
1282 parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1283 kSemanticsAttrName) ||
1284 parser.parseOperandList(operandInfo, 2))
1285 return failure();
1286
1287 auto loc = parser.getCurrentLocation();
1288 if (parser.parseColonType(type))
1289 return failure();
1290
1291 auto ptrType = type.dyn_cast<spirv::PointerType>();
1292 if (!ptrType)
1293 return parser.emitError(loc, "expected pointer type");
1294
1295 if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1296 parser.getNameLoc(), result.operands))
1297 return failure();
1298
1299 return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1300}
1301
1302LogicalResult spirv::AtomicExchangeOp::verify() {
1303 if (getType() != value().getType())
1304 return emitOpError("value operand must have the same type as the op "
1305 "result, but found ")
1306 << value().getType() << " vs " << getType();
1307
1308 Type pointeeType =
1309 pointer().getType().cast<spirv::PointerType>().getPointeeType();
1310 if (getType() != pointeeType)
1311 return emitOpError("pointer operand's pointee type must have the same "
1312 "as the op result type, but found ")
1313 << pointeeType << " vs " << getType();
1314
1315 return success();
1316}
1317
1318//===----------------------------------------------------------------------===//
1319// spv.AtomicIAddOp
1320//===----------------------------------------------------------------------===//
1321
1322LogicalResult spirv::AtomicIAddOp::verify() {
1323 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1324}
1325
1326ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1327 OperationState &result) {
1328 return ::parseAtomicUpdateOp(parser, result, true);
1329}
1330void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1331 ::printAtomicUpdateOp(*this, p);
1332}
1333
1334//===----------------------------------------------------------------------===//
1335// spv.AtomicFAddEXTOp
1336//===----------------------------------------------------------------------===//
1337
1338LogicalResult spirv::AtomicFAddEXTOp::verify() {
1339 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1340}
1341
1342ParseResult spirv::AtomicFAddEXTOp::parse(OpAsmParser &parser,
1343 OperationState &result) {
1344 return ::parseAtomicUpdateOp(parser, result, true);
1345}
1346void spirv::AtomicFAddEXTOp::print(OpAsmPrinter &p) {
1347 ::printAtomicUpdateOp(*this, p);
1348}
1349
1350//===----------------------------------------------------------------------===//
1351// spv.AtomicIDecrementOp
1352//===----------------------------------------------------------------------===//
1353
1354LogicalResult spirv::AtomicIDecrementOp::verify() {
1355 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1356}
1357
1358ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1359 OperationState &result) {
1360 return ::parseAtomicUpdateOp(parser, result, false);
1361}
1362void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1363 ::printAtomicUpdateOp(*this, p);
1364}
1365
1366//===----------------------------------------------------------------------===//
1367// spv.AtomicIIncrementOp
1368//===----------------------------------------------------------------------===//
1369
1370LogicalResult spirv::AtomicIIncrementOp::verify() {
1371 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1372}
1373
1374ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1375 OperationState &result) {
1376 return ::parseAtomicUpdateOp(parser, result, false);
1377}
1378void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1379 ::printAtomicUpdateOp(*this, p);
1380}
1381
1382//===----------------------------------------------------------------------===//
1383// spv.AtomicISubOp
1384//===----------------------------------------------------------------------===//
1385
1386LogicalResult spirv::AtomicISubOp::verify() {
1387 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1388}
1389
1390ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1391 OperationState &result) {
1392 return ::parseAtomicUpdateOp(parser, result, true);
1393}
1394void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1395 ::printAtomicUpdateOp(*this, p);
1396}
1397
1398//===----------------------------------------------------------------------===//
1399// spv.AtomicOrOp
1400//===----------------------------------------------------------------------===//
1401
1402LogicalResult spirv::AtomicOrOp::verify() {
1403 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1404}
1405
1406ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1407 OperationState &result) {
1408 return ::parseAtomicUpdateOp(parser, result, true);
1409}
1410void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1411 ::printAtomicUpdateOp(*this, p);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// spv.AtomicSMaxOp
1416//===----------------------------------------------------------------------===//
1417
1418LogicalResult spirv::AtomicSMaxOp::verify() {
1419 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1420}
1421
1422ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1423 OperationState &result) {
1424 return ::parseAtomicUpdateOp(parser, result, true);
1425}
1426void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1427 ::printAtomicUpdateOp(*this, p);
1428}
1429
1430//===----------------------------------------------------------------------===//
1431// spv.AtomicSMinOp
1432//===----------------------------------------------------------------------===//
1433
1434LogicalResult spirv::AtomicSMinOp::verify() {
1435 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1436}
1437
1438ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1439 OperationState &result) {
1440 return ::parseAtomicUpdateOp(parser, result, true);
1441}
1442void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1443 ::printAtomicUpdateOp(*this, p);
1444}
1445
1446//===----------------------------------------------------------------------===//
1447// spv.AtomicUMaxOp
1448//===----------------------------------------------------------------------===//
1449
1450LogicalResult spirv::AtomicUMaxOp::verify() {
1451 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1452}
1453
1454ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1455 OperationState &result) {
1456 return ::parseAtomicUpdateOp(parser, result, true);
1457}
1458void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1459 ::printAtomicUpdateOp(*this, p);
1460}
1461
1462//===----------------------------------------------------------------------===//
1463// spv.AtomicUMinOp
1464//===----------------------------------------------------------------------===//
1465
1466LogicalResult spirv::AtomicUMinOp::verify() {
1467 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1468}
1469
1470ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1471 OperationState &result) {
1472 return ::parseAtomicUpdateOp(parser, result, true);
1473}
1474void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1475 ::printAtomicUpdateOp(*this, p);
1476}
1477
1478//===----------------------------------------------------------------------===//
1479// spv.AtomicXorOp
1480//===----------------------------------------------------------------------===//
1481
1482LogicalResult spirv::AtomicXorOp::verify() {
1483 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1484}
1485
1486ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1487 OperationState &result) {
1488 return ::parseAtomicUpdateOp(parser, result, true);
1489}
1490void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1491 ::printAtomicUpdateOp(*this, p);
1492}
1493
1494//===----------------------------------------------------------------------===//
1495// spv.BitcastOp
1496//===----------------------------------------------------------------------===//
1497
1498LogicalResult spirv::BitcastOp::verify() {
1499 // TODO: The SPIR-V spec validation rules are different for different
1500 // versions.
1501 auto operandType = operand().getType();
1502 auto resultType = result().getType();
1503 if (operandType == resultType) {
1504 return emitError("result type must be different from operand type");
1505 }
1506 if (operandType.isa<spirv::PointerType>() &&
1507 !resultType.isa<spirv::PointerType>()) {
1508 return emitError(
1509 "unhandled bit cast conversion from pointer type to non-pointer type");
1510 }
1511 if (!operandType.isa<spirv::PointerType>() &&
1512 resultType.isa<spirv::PointerType>()) {
1513 return emitError(
1514 "unhandled bit cast conversion from non-pointer type to pointer type");
1515 }
1516 auto operandBitWidth = getBitWidth(operandType);
1517 auto resultBitWidth = getBitWidth(resultType);
1518 if (operandBitWidth != resultBitWidth) {
1519 return emitOpError("mismatch in result type bitwidth ")
1520 << resultBitWidth << " and operand type bitwidth "
1521 << operandBitWidth;
1522 }
1523 return success();
1524}
1525
1526//===----------------------------------------------------------------------===//
1527// spv.BranchOp
1528//===----------------------------------------------------------------------===//
1529
1530SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1531 assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index"
) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1531, __extension__
__PRETTY_FUNCTION__))
;
1532 return SuccessorOperands(0, targetOperandsMutable());
1533}
1534
1535//===----------------------------------------------------------------------===//
1536// spv.BranchConditionalOp
1537//===----------------------------------------------------------------------===//
1538
1539SuccessorOperands
1540spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1541 assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index"
) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1541, __extension__
__PRETTY_FUNCTION__))
;
1542 return SuccessorOperands(index == kTrueIndex ? trueTargetOperandsMutable()
1543 : falseTargetOperandsMutable());
1544}
1545
1546ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1547 OperationState &result) {
1548 auto &builder = parser.getBuilder();
1549 OpAsmParser::UnresolvedOperand condInfo;
1550 Block *dest;
1551
1552 // Parse the condition.
1553 Type boolTy = builder.getI1Type();
1554 if (parser.parseOperand(condInfo) ||
1555 parser.resolveOperand(condInfo, boolTy, result.operands))
1556 return failure();
1557
1558 // Parse the optional branch weights.
1559 if (succeeded(parser.parseOptionalLSquare())) {
1560 IntegerAttr trueWeight, falseWeight;
1561 NamedAttrList weights;
1562
1563 auto i32Type = builder.getIntegerType(32);
1564 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1565 parser.parseComma() ||
1566 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1567 parser.parseRSquare())
1568 return failure();
1569
1570 result.addAttribute(kBranchWeightAttrName,
1571 builder.getArrayAttr({trueWeight, falseWeight}));
1572 }
1573
1574 // Parse the true branch.
1575 SmallVector<Value, 4> trueOperands;
1576 if (parser.parseComma() ||
1577 parser.parseSuccessorAndUseList(dest, trueOperands))
1578 return failure();
1579 result.addSuccessors(dest);
1580 result.addOperands(trueOperands);
1581
1582 // Parse the false branch.
1583 SmallVector<Value, 4> falseOperands;
1584 if (parser.parseComma() ||
1585 parser.parseSuccessorAndUseList(dest, falseOperands))
1586 return failure();
1587 result.addSuccessors(dest);
1588 result.addOperands(falseOperands);
1589 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1590 builder.getDenseI32ArrayAttr(
1591 {1, static_cast<int32_t>(trueOperands.size()),
1592 static_cast<int32_t>(falseOperands.size())}));
1593
1594 return success();
1595}
1596
1597void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1598 printer << ' ' << condition();
1599
1600 if (auto weights = branch_weights()) {
1601 printer << " [";
1602 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1603 printer << a.cast<IntegerAttr>().getInt();
1604 });
1605 printer << "]";
1606 }
1607
1608 printer << ", ";
1609 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1610 printer << ", ";
1611 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1612}
1613
1614LogicalResult spirv::BranchConditionalOp::verify() {
1615 if (auto weights = branch_weights()) {
1616 if (weights->getValue().size() != 2) {
1617 return emitOpError("must have exactly two branch weights");
1618 }
1619 if (llvm::all_of(*weights, [](Attribute attr) {
1620 return attr.cast<IntegerAttr>().getValue().isNullValue();
1621 }))
1622 return emitOpError("branch weights cannot both be zero");
1623 }
1624
1625 return success();
1626}
1627
1628//===----------------------------------------------------------------------===//
1629// spv.CompositeConstruct
1630//===----------------------------------------------------------------------===//
1631
1632LogicalResult spirv::CompositeConstructOp::verify() {
1633 auto cType = getType().cast<spirv::CompositeType>();
1634 operand_range constituents = this->constituents();
1635
1636 if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1637 if (constituents.size() != 1)
1638 return emitOpError("has incorrect number of operands: expected ")
1639 << "1, but provided " << constituents.size();
1640 if (coopType.getElementType() != constituents.front().getType())
1641 return emitOpError("operand type mismatch: expected operand type ")
1642 << coopType.getElementType() << ", but provided "
1643 << constituents.front().getType();
1644 return success();
1645 }
1646
1647 if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1648 if (constituents.size() != 1)
1649 return emitOpError("has incorrect number of operands: expected ")
1650 << "1, but provided " << constituents.size();
1651 if (jointType.getElementType() != constituents.front().getType())
1652 return emitOpError("operand type mismatch: expected operand type ")
1653 << jointType.getElementType() << ", but provided "
1654 << constituents.front().getType();
1655 return success();
1656 }
1657
1658 if (constituents.size() == cType.getNumElements()) {
1659 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1660 if (constituents[index].getType() != cType.getElementType(index)) {
1661 return emitOpError("operand type mismatch: expected operand type ")
1662 << cType.getElementType(index) << ", but provided "
1663 << constituents[index].getType();
1664 }
1665 }
1666 return success();
1667 }
1668
1669 // If not constructing a cooperative matrix type, then we must be constructing
1670 // a vector type.
1671 auto resultType = cType.dyn_cast<VectorType>();
1672 if (!resultType)
1673 return emitOpError(
1674 "expected to return a vector or cooperative matrix when the number of "
1675 "constituents is less than what the result needs");
1676
1677 SmallVector<unsigned> sizes;
1678 for (Value component : constituents) {
1679 if (!component.getType().isa<VectorType>() &&
1680 !component.getType().isIntOrFloat())
1681 return emitOpError("operand type mismatch: expected operand to have "
1682 "a scalar or vector type, but provided ")
1683 << component.getType();
1684
1685 Type elementType = component.getType();
1686 if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1687 sizes.push_back(vectorType.getNumElements());
1688 elementType = vectorType.getElementType();
1689 } else {
1690 sizes.push_back(1);
1691 }
1692
1693 if (elementType != resultType.getElementType())
1694 return emitOpError("operand element type mismatch: expected to be ")
1695 << resultType.getElementType() << ", but provided " << elementType;
1696 }
1697 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1698 if (totalCount != cType.getNumElements())
1699 return emitOpError("has incorrect number of operands: expected ")
1700 << cType.getNumElements() << ", but provided " << totalCount;
1701 return success();
1702}
1703
1704//===----------------------------------------------------------------------===//
1705// spv.CompositeExtractOp
1706//===----------------------------------------------------------------------===//
1707
1708void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1709 Value composite,
1710 ArrayRef<int32_t> indices) {
1711 auto indexAttr = builder.getI32ArrayAttr(indices);
1712 auto elementType =
1713 getElementType(composite.getType(), indexAttr, state.location);
1714 if (!elementType) {
1715 return;
1716 }
1717 build(builder, state, elementType, composite, indexAttr);
1718}
1719
1720ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1721 OperationState &result) {
1722 OpAsmParser::UnresolvedOperand compositeInfo;
1723 Attribute indicesAttr;
1724 Type compositeType;
1725 SMLoc attrLocation;
1726
1727 if (parser.parseOperand(compositeInfo) ||
1728 parser.getCurrentLocation(&attrLocation) ||
1729 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1730 parser.parseColonType(compositeType) ||
1731 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1732 return failure();
1733 }
1734
1735 Type resultType =
1736 getElementType(compositeType, indicesAttr, parser, attrLocation);
1737 if (!resultType) {
1738 return failure();
1739 }
1740 result.addTypes(resultType);
1741 return success();
1742}
1743
1744void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1745 printer << ' ' << composite() << indices() << " : " << composite().getType();
1746}
1747
1748LogicalResult spirv::CompositeExtractOp::verify() {
1749 auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1750 auto resultType =
1751 getElementType(composite().getType(), indicesArrayAttr, getLoc());
1752 if (!resultType)
1753 return failure();
1754
1755 if (resultType != getType()) {
1756 return emitOpError("invalid result type: expected ")
1757 << resultType << " but provided " << getType();
1758 }
1759
1760 return success();
1761}
1762
1763//===----------------------------------------------------------------------===//
1764// spv.CompositeInsert
1765//===----------------------------------------------------------------------===//
1766
1767void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1768 Value object, Value composite,
1769 ArrayRef<int32_t> indices) {
1770 auto indexAttr = builder.getI32ArrayAttr(indices);
1771 build(builder, state, composite.getType(), object, composite, indexAttr);
1772}
1773
1774ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1775 OperationState &result) {
1776 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1777 Type objectType, compositeType;
1778 Attribute indicesAttr;
1779 auto loc = parser.getCurrentLocation();
1780
1781 return failure(
1782 parser.parseOperandList(operands, 2) ||
1783 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1784 parser.parseColonType(objectType) ||
1785 parser.parseKeywordType("into", compositeType) ||
1786 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1787 result.operands) ||
1788 parser.addTypesToList(compositeType, result.types));
1789}
1790
1791LogicalResult spirv::CompositeInsertOp::verify() {
1792 auto indicesArrayAttr = indices().dyn_cast<ArrayAttr>();
1793 auto objectType =
1794 getElementType(composite().getType(), indicesArrayAttr, getLoc());
1795 if (!objectType)
1796 return failure();
1797
1798 if (objectType != object().getType()) {
1799 return emitOpError("object operand type should be ")
1800 << objectType << ", but found " << object().getType();
1801 }
1802
1803 if (composite().getType() != getType()) {
1804 return emitOpError("result type should be the same as "
1805 "the composite type, but found ")
1806 << composite().getType() << " vs " << getType();
1807 }
1808
1809 return success();
1810}
1811
1812void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1813 printer << " " << object() << ", " << composite() << indices() << " : "
1814 << object().getType() << " into " << composite().getType();
1815}
1816
1817//===----------------------------------------------------------------------===//
1818// spv.Constant
1819//===----------------------------------------------------------------------===//
1820
1821ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1822 OperationState &result) {
1823 Attribute value;
1824 if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1825 return failure();
1826
1827 Type type = NoneType::get(parser.getContext());
1828 if (auto typedAttr = value.dyn_cast<TypedAttr>())
1829 type = typedAttr.getType();
1830 if (type.isa<NoneType, TensorType>()) {
1831 if (parser.parseColonType(type))
1832 return failure();
1833 }
1834
1835 return parser.addTypeToList(type, result.types);
1836}
1837
1838void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1839 printer << ' ' << value();
1840 if (getType().isa<spirv::ArrayType>())
1841 printer << " : " << getType();
1842}
1843
1844static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1845 Type opType) {
1846 if (value.isa<IntegerAttr, FloatAttr>()) {
1847 auto valueType = value.cast<TypedAttr>().getType();
1848 if (valueType != opType)
1849 return op.emitOpError("result type (")
1850 << opType << ") does not match value type (" << valueType << ")";
1851 return success();
1852 }
1853 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1854 auto valueType = value.cast<TypedAttr>().getType();
1855 if (valueType == opType)
1856 return success();
1857 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1858 auto shapedType = valueType.dyn_cast<ShapedType>();
1859 if (!arrayType)
1860 return op.emitOpError("result or element type (")
1861 << opType << ") does not match value type (" << valueType
1862 << "), must be the same or spv.array";
1863
1864 int numElements = arrayType.getNumElements();
1865 auto opElemType = arrayType.getElementType();
1866 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1867 numElements *= t.getNumElements();
1868 opElemType = t.getElementType();
1869 }
1870 if (!opElemType.isIntOrFloat())
1871 return op.emitOpError("only support nested array result type");
1872
1873 auto valueElemType = shapedType.getElementType();
1874 if (valueElemType != opElemType) {
1875 return op.emitOpError("result element type (")
1876 << opElemType << ") does not match value element type ("
1877 << valueElemType << ")";
1878 }
1879
1880 if (numElements != shapedType.getNumElements()) {
1881 return op.emitOpError("result number of elements (")
1882 << numElements << ") does not match value number of elements ("
1883 << shapedType.getNumElements() << ")";
1884 }
1885 return success();
1886 }
1887 if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
1888 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1889 if (!arrayType)
1890 return op.emitOpError("must have spv.array result type for array value");
1891 Type elemType = arrayType.getElementType();
1892 for (Attribute element : arrayAttr.getValue()) {
1893 // Verify array elements recursively.
1894 if (failed(verifyConstantType(op, element, elemType)))
1895 return failure();
1896 }
1897 return success();
1898 }
1899 return op.emitOpError("cannot have attribute: ") << value;
1900}
1901
1902LogicalResult spirv::ConstantOp::verify() {
1903 // ODS already generates checks to make sure the result type is valid. We just
1904 // need to additionally check that the value's attribute type is consistent
1905 // with the result type.
1906 return verifyConstantType(*this, valueAttr(), getType());
1907}
1908
1909bool spirv::ConstantOp::isBuildableWith(Type type) {
1910 // Must be valid SPIR-V type first.
1911 if (!type.isa<spirv::SPIRVType>())
1912 return false;
1913
1914 if (isa<SPIRVDialect>(type.getDialect())) {
1915 // TODO: support constant struct
1916 return type.isa<spirv::ArrayType>();
1917 }
1918
1919 return true;
1920}
1921
1922spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
1923 OpBuilder &builder) {
1924 if (auto intType = type.dyn_cast<IntegerType>()) {
1925 unsigned width = intType.getWidth();
1926 if (width == 1)
1927 return builder.create<spirv::ConstantOp>(loc, type,
1928 builder.getBoolAttr(false));
1929 return builder.create<spirv::ConstantOp>(
1930 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
1931 }
1932 if (auto floatType = type.dyn_cast<FloatType>()) {
1933 return builder.create<spirv::ConstantOp>(
1934 loc, type, builder.getFloatAttr(floatType, 0.0));
1935 }
1936 if (auto vectorType = type.dyn_cast<VectorType>()) {
1937 Type elemType = vectorType.getElementType();
1938 if (elemType.isa<IntegerType>()) {
1939 return builder.create<spirv::ConstantOp>(
1940 loc, type,
1941 DenseElementsAttr::get(vectorType,
1942 IntegerAttr::get(elemType, 0.0).getValue()));
1943 }
1944 if (elemType.isa<FloatType>()) {
1945 return builder.create<spirv::ConstantOp>(
1946 loc, type,
1947 DenseFPElementsAttr::get(vectorType,
1948 FloatAttr::get(elemType, 0.0).getValue()));
1949 }
1950 }
1951
1952 llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1952)
;
1953}
1954
1955spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
1956 OpBuilder &builder) {
1957 if (auto intType = type.dyn_cast<IntegerType>()) {
1958 unsigned width = intType.getWidth();
1959 if (width == 1)
1960 return builder.create<spirv::ConstantOp>(loc, type,
1961 builder.getBoolAttr(true));
1962 return builder.create<spirv::ConstantOp>(
1963 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
1964 }
1965 if (auto floatType = type.dyn_cast<FloatType>()) {
1966 return builder.create<spirv::ConstantOp>(
1967 loc, type, builder.getFloatAttr(floatType, 1.0));
1968 }
1969 if (auto vectorType = type.dyn_cast<VectorType>()) {
1970 Type elemType = vectorType.getElementType();
1971 if (elemType.isa<IntegerType>()) {
1972 return builder.create<spirv::ConstantOp>(
1973 loc, type,
1974 DenseElementsAttr::get(vectorType,
1975 IntegerAttr::get(elemType, 1.0).getValue()));
1976 }
1977 if (elemType.isa<FloatType>()) {
1978 return builder.create<spirv::ConstantOp>(
1979 loc, type,
1980 DenseFPElementsAttr::get(vectorType,
1981 FloatAttr::get(elemType, 1.0).getValue()));
1982 }
1983 }
1984
1985 llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1985)
;
1986}
1987
1988void mlir::spirv::ConstantOp::getAsmResultNames(
1989 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
1990 Type type = getType();
1991
1992 SmallString<32> specialNameBuffer;
1993 llvm::raw_svector_ostream specialName(specialNameBuffer);
1994 specialName << "cst";
1995
1996 IntegerType intTy = type.dyn_cast<IntegerType>();
1997
1998 if (IntegerAttr intCst = value().dyn_cast<IntegerAttr>()) {
1999 if (intTy && intTy.getWidth() == 1) {
2000 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2001 }
2002
2003 if (intTy.isSignless()) {
2004 specialName << intCst.getInt();
2005 } else {
2006 specialName << intCst.getSInt();
2007 }
2008 }
2009
2010 if (intTy || type.isa<FloatType>()) {
2011 specialName << '_' << type;
2012 }
2013
2014 if (auto vecType = type.dyn_cast<VectorType>()) {
2015 specialName << "_vec_";
2016 specialName << vecType.getDimSize(0);
2017
2018 Type elementType = vecType.getElementType();
2019
2020 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2021 specialName << "x" << elementType;
2022 }
2023 }
2024
2025 setNameFn(getResult(), specialName.str());
2026}
2027
2028void mlir::spirv::AddressOfOp::getAsmResultNames(
2029 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2030 SmallString<32> specialNameBuffer;
2031 llvm::raw_svector_ostream specialName(specialNameBuffer);
2032 specialName << variable() << "_addr";
2033 setNameFn(getResult(), specialName.str());
2034}
2035
2036//===----------------------------------------------------------------------===//
2037// spv.ControlBarrierOp
2038//===----------------------------------------------------------------------===//
2039
2040LogicalResult spirv::ControlBarrierOp::verify() {
2041 return verifyMemorySemantics(getOperation(), memory_semantics());
2042}
2043
2044//===----------------------------------------------------------------------===//
2045// spv.ConvertFToSOp
2046//===----------------------------------------------------------------------===//
2047
2048LogicalResult spirv::ConvertFToSOp::verify() {
2049 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2050 /*skipBitWidthCheck=*/true);
2051}
2052
2053//===----------------------------------------------------------------------===//
2054// spv.ConvertFToUOp
2055//===----------------------------------------------------------------------===//
2056
2057LogicalResult spirv::ConvertFToUOp::verify() {
2058 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2059 /*skipBitWidthCheck=*/true);
2060}
2061
2062//===----------------------------------------------------------------------===//
2063// spv.ConvertSToFOp
2064//===----------------------------------------------------------------------===//
2065
2066LogicalResult spirv::ConvertSToFOp::verify() {
2067 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2068 /*skipBitWidthCheck=*/true);
2069}
2070
2071//===----------------------------------------------------------------------===//
2072// spv.ConvertUToFOp
2073//===----------------------------------------------------------------------===//
2074
2075LogicalResult spirv::ConvertUToFOp::verify() {
2076 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2077 /*skipBitWidthCheck=*/true);
2078}
2079
2080//===----------------------------------------------------------------------===//
2081// spv.EntryPoint
2082//===----------------------------------------------------------------------===//
2083
2084void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2085 spirv::ExecutionModel executionModel,
2086 spirv::FuncOp function,
2087 ArrayRef<Attribute> interfaceVars) {
2088 build(builder, state,
2089 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2090 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2091}
2092
2093ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2094 OperationState &result) {
2095 spirv::ExecutionModel execModel;
2096 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2097 SmallVector<Type, 0> idTypes;
2098 SmallVector<Attribute, 4> interfaceVars;
2099
2100 FlatSymbolRefAttr fn;
2101 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2102 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2103 return failure();
2104 }
2105
2106 if (!parser.parseOptionalComma()) {
2107 // Parse the interface variables
2108 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2109 // The name of the interface variable attribute isnt important
2110 FlatSymbolRefAttr var;
2111 NamedAttrList attrs;
2112 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2113 return failure();
2114 interfaceVars.push_back(var);
2115 return success();
2116 }))
2117 return failure();
2118 }
2119 result.addAttribute(kInterfaceAttrName,
2120 parser.getBuilder().getArrayAttr(interfaceVars));
2121 return success();
2122}
2123
2124void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2125 printer << " \"" << stringifyExecutionModel(execution_model()) << "\" ";
2126 printer.printSymbolName(fn());
2127 auto interfaceVars = interface().getValue();
2128 if (!interfaceVars.empty()) {
2129 printer << ", ";
2130 llvm::interleaveComma(interfaceVars, printer);
2131 }
2132}
2133
2134LogicalResult spirv::EntryPointOp::verify() {
2135 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2136 // verification.
2137 return success();
2138}
2139
2140//===----------------------------------------------------------------------===//
2141// spv.ExecutionMode
2142//===----------------------------------------------------------------------===//
2143
2144void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2145 spirv::FuncOp function,
2146 spirv::ExecutionMode executionMode,
2147 ArrayRef<int32_t> params) {
2148 build(builder, state, SymbolRefAttr::get(function),
2149 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2150 builder.getI32ArrayAttr(params));
2151}
2152
2153ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2154 OperationState &result) {
2155 spirv::ExecutionMode execMode;
2156 Attribute fn;
2157 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2158 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2159 return failure();
2160 }
2161
2162 SmallVector<int32_t, 4> values;
2163 Type i32Type = parser.getBuilder().getIntegerType(32);
2164 while (!parser.parseOptionalComma()) {
2165 NamedAttrList attr;
2166 Attribute value;
2167 if (parser.parseAttribute(value, i32Type, "value", attr)) {
2168 return failure();
2169 }
2170 values.push_back(value.cast<IntegerAttr>().getInt());
2171 }
2172 result.addAttribute(kValuesAttrName,
2173 parser.getBuilder().getI32ArrayAttr(values));
2174 return success();
2175}
2176
2177void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2178 printer << " ";
2179 printer.printSymbolName(fn());
2180 printer << " \"" << stringifyExecutionMode(execution_mode()) << "\"";
2181 auto values = this->values();
2182 if (values.empty())
2183 return;
2184 printer << ", ";
2185 llvm::interleaveComma(values, printer, [&](Attribute a) {
2186 printer << a.cast<IntegerAttr>().getInt();
2187 });
2188}
2189
2190//===----------------------------------------------------------------------===//
2191// spv.FConvertOp
2192//===----------------------------------------------------------------------===//
2193
2194LogicalResult spirv::FConvertOp::verify() {
2195 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2196}
2197
2198//===----------------------------------------------------------------------===//
2199// spv.SConvertOp
2200//===----------------------------------------------------------------------===//
2201
2202LogicalResult spirv::SConvertOp::verify() {
2203 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2204}
2205
2206//===----------------------------------------------------------------------===//
2207// spv.UConvertOp
2208//===----------------------------------------------------------------------===//
2209
2210LogicalResult spirv::UConvertOp::verify() {
2211 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2212}
2213
2214//===----------------------------------------------------------------------===//
2215// spv.func
2216//===----------------------------------------------------------------------===//
2217
2218ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2219 SmallVector<OpAsmParser::Argument> entryArgs;
2220 SmallVector<DictionaryAttr> resultAttrs;
2221 SmallVector<Type> resultTypes;
2222 auto &builder = parser.getBuilder();
2223
2224 // Parse the name as a symbol.
2225 StringAttr nameAttr;
2226 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2227 result.attributes))
2228 return failure();
2229
2230 // Parse the function signature.
2231 bool isVariadic = false;
2232 if (function_interface_impl::parseFunctionSignature(
2233 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2234 resultAttrs))
2235 return failure();
2236
2237 SmallVector<Type> argTypes;
2238 for (auto &arg : entryArgs)
2239 argTypes.push_back(arg.type);
2240 auto fnType = builder.getFunctionType(argTypes, resultTypes);
2241 result.addAttribute(FunctionOpInterface::getTypeAttrName(),
2242 TypeAttr::get(fnType));
2243
2244 // Parse the optional function control keyword.
2245 spirv::FunctionControl fnControl;
2246 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2247 return failure();
2248
2249 // If additional attributes are present, parse them.
2250 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2251 return failure();
2252
2253 // Add the attributes to the function arguments.
2254 assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes.
size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2254, __extension__
__PRETTY_FUNCTION__))
;
2255 function_interface_impl::addArgAndResultAttrs(builder, result, entryArgs,
2256 resultAttrs);
2257
2258 // Parse the optional function body.
2259 auto *body = result.addRegion();
2260 OptionalParseResult parseResult =
2261 parser.parseOptionalRegion(*body, entryArgs);
2262 return failure(parseResult.has_value() && failed(*parseResult));
2263}
2264
2265void spirv::FuncOp::print(OpAsmPrinter &printer) {
2266 // Print function name, signature, and control.
2267 printer << " ";
2268 printer.printSymbolName(sym_name());
2269 auto fnType = getFunctionType();
2270 function_interface_impl::printFunctionSignature(
2271 printer, *this, fnType.getInputs(),
2272 /*isVariadic=*/false, fnType.getResults());
2273 printer << " \"" << spirv::stringifyFunctionControl(function_control())
2274 << "\"";
2275 function_interface_impl::printFunctionAttributes(
2276 printer, *this, fnType.getNumInputs(), fnType.getNumResults(),
2277 {spirv::attributeName<spirv::FunctionControl>()});
2278
2279 // Print the body if this is not an external function.
2280 Region &body = this->body();
2281 if (!body.empty()) {
2282 printer << ' ';
2283 printer.printRegion(body, /*printEntryBlockArgs=*/false,
2284 /*printBlockTerminators=*/true);
2285 }
2286}
2287
2288LogicalResult spirv::FuncOp::verifyType() {
2289 auto type = getFunctionTypeAttr().getValue();
2290 if (!type.isa<FunctionType>())
2291 return emitOpError("requires '" + getTypeAttrName() +
2292 "' attribute of function type");
2293 if (getFunctionType().getNumResults() > 1)
2294 return emitOpError("cannot have more than one result");
2295 return success();
2296}
2297
2298LogicalResult spirv::FuncOp::verifyBody() {
2299 FunctionType fnType = getFunctionType();
2300
2301 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2302 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2303 if (fnType.getNumResults() != 0)
2304 return retOp.emitOpError("cannot be used in functions returning value");
2305 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2306 if (fnType.getNumResults() != 1)
2307 return retOp.emitOpError(
2308 "returns 1 value but enclosing function requires ")
2309 << fnType.getNumResults() << " results";
2310
2311 auto retOperandType = retOp.value().getType();
2312 auto fnResultType = fnType.getResult(0);
2313 if (retOperandType != fnResultType)
2314 return retOp.emitOpError(" return value's type (")
2315 << retOperandType << ") mismatch with function's result type ("
2316 << fnResultType << ")";
2317 }
2318 return WalkResult::advance();
2319 });
2320
2321 // TODO: verify other bits like linkage type.
2322
2323 return failure(walkResult.wasInterrupted());
2324}
2325
2326void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2327 StringRef name, FunctionType type,
2328 spirv::FunctionControl control,
2329 ArrayRef<NamedAttribute> attrs) {
2330 state.addAttribute(SymbolTable::getSymbolAttrName(),
2331 builder.getStringAttr(name));
2332 state.addAttribute(getTypeAttrName(), TypeAttr::get(type));
2333 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2334 builder.getAttr<spirv::FunctionControlAttr>(control));
2335 state.attributes.append(attrs.begin(), attrs.end());
2336 state.addRegion();
2337}
2338
2339// CallableOpInterface
2340Region *spirv::FuncOp::getCallableRegion() {
2341 return isExternal() ? nullptr : &body();
2342}
2343
2344// CallableOpInterface
2345ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2346 return getFunctionType().getResults();
2347}
2348
2349//===----------------------------------------------------------------------===//
2350// spv.FunctionCall
2351//===----------------------------------------------------------------------===//
2352
2353LogicalResult spirv::FunctionCallOp::verify() {
2354 auto fnName = calleeAttr();
2355
2356 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2357 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2358 if (!funcOp) {
2359 return emitOpError("callee function '")
2360 << fnName.getValue() << "' not found in nearest symbol table";
2361 }
2362
2363 auto functionType = funcOp.getFunctionType();
2364
2365 if (getNumResults() > 1) {
2366 return emitOpError(
2367 "expected callee function to have 0 or 1 result, but provided ")
2368 << getNumResults();
2369 }
2370
2371 if (functionType.getNumInputs() != getNumOperands()) {
2372 return emitOpError("has incorrect number of operands for callee: expected ")
2373 << functionType.getNumInputs() << ", but provided "
2374 << getNumOperands();
2375 }
2376
2377 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2378 if (getOperand(i).getType() != functionType.getInput(i)) {
2379 return emitOpError("operand type mismatch: expected operand type ")
2380 << functionType.getInput(i) << ", but provided "
2381 << getOperand(i).getType() << " for operand number " << i;
2382 }
2383 }
2384
2385 if (functionType.getNumResults() != getNumResults()) {
2386 return emitOpError(
2387 "has incorrect number of results has for callee: expected ")
2388 << functionType.getNumResults() << ", but provided "
2389 << getNumResults();
2390 }
2391
2392 if (getNumResults() &&
2393 (getResult(0).getType() != functionType.getResult(0))) {
2394 return emitOpError("result type mismatch: expected ")
2395 << functionType.getResult(0) << ", but provided "
2396 << getResult(0).getType();
2397 }
2398
2399 return success();
2400}
2401
2402CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2403 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2404}
2405
2406Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2407 return arguments();
2408}
2409
2410//===----------------------------------------------------------------------===//
2411// spv.GLFClampOp
2412//===----------------------------------------------------------------------===//
2413
2414ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2415 OperationState &result) {
2416 return parseOneResultSameOperandTypeOp(parser, result);
2417}
2418void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2419
2420//===----------------------------------------------------------------------===//
2421// spv.GLUClampOp
2422//===----------------------------------------------------------------------===//
2423
2424ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2425 OperationState &result) {
2426 return parseOneResultSameOperandTypeOp(parser, result);
2427}
2428void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2429
2430//===----------------------------------------------------------------------===//
2431// spv.GLSClampOp
2432//===----------------------------------------------------------------------===//
2433
2434ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2435 OperationState &result) {
2436 return parseOneResultSameOperandTypeOp(parser, result);
2437}
2438void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2439
2440//===----------------------------------------------------------------------===//
2441// spv.GLFmaOp
2442//===----------------------------------------------------------------------===//
2443
2444ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2445 return parseOneResultSameOperandTypeOp(parser, result);
2446}
2447void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2448
2449//===----------------------------------------------------------------------===//
2450// spv.GlobalVariable
2451//===----------------------------------------------------------------------===//
2452
2453void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2454 Type type, StringRef name,
2455 unsigned descriptorSet, unsigned binding) {
2456 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2457 state.addAttribute(
2458 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2459 builder.getI32IntegerAttr(descriptorSet));
2460 state.addAttribute(
2461 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2462 builder.getI32IntegerAttr(binding));
2463}
2464
2465void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2466 Type type, StringRef name,
2467 spirv::BuiltIn builtin) {
2468 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2469 state.addAttribute(
2470 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2471 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2472}
2473
2474ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2475 OperationState &result) {
2476 // Parse variable name.
2477 StringAttr nameAttr;
2478 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2479 result.attributes)) {
2480 return failure();
2481 }
2482
2483 // Parse optional initializer
2484 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2485 FlatSymbolRefAttr initSymbol;
2486 if (parser.parseLParen() ||
2487 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2488 result.attributes) ||
2489 parser.parseRParen())
2490 return failure();
2491 }
2492
2493 if (parseVariableDecorations(parser, result)) {
2494 return failure();
2495 }
2496
2497 Type type;
2498 auto loc = parser.getCurrentLocation();
2499 if (parser.parseColonType(type)) {
2500 return failure();
2501 }
2502 if (!type.isa<spirv::PointerType>()) {
2503 return parser.emitError(loc, "expected spv.ptr type");
2504 }
2505 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2506
2507 return success();
2508}
2509
2510void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2511 SmallVector<StringRef, 4> elidedAttrs{
2512 spirv::attributeName<spirv::StorageClass>()};
2513
2514 // Print variable name.
2515 printer << ' ';
2516 printer.printSymbolName(sym_name());
2517 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2518
2519 // Print optional initializer
2520 if (auto initializer = this->initializer()) {
2521 printer << " " << kInitializerAttrName << '(';
2522 printer.printSymbolName(*initializer);
2523 printer << ')';
2524 elidedAttrs.push_back(kInitializerAttrName);
2525 }
2526
2527 elidedAttrs.push_back(kTypeAttrName);
2528 printVariableDecorations(*this, printer, elidedAttrs);
2529 printer << " : " << type();
2530}
2531
2532LogicalResult spirv::GlobalVariableOp::verify() {
2533 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2534 // object. It cannot be Generic. It must be the same as the Storage Class
2535 // operand of the Result Type."
2536 // Also, Function storage class is reserved by spv.Variable.
2537 auto storageClass = this->storageClass();
2538 if (storageClass == spirv::StorageClass::Generic ||
2539 storageClass == spirv::StorageClass::Function) {
2540 return emitOpError("storage class cannot be '")
2541 << stringifyStorageClass(storageClass) << "'";
2542 }
2543
2544 if (auto init =
2545 (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2546 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2547 (*this)->getParentOp(), init.getAttr());
2548 // TODO: Currently only variable initialization with specialization
2549 // constants and other variables is supported. They could be normal
2550 // constants in the module scope as well.
2551 if (!initOp ||
2552 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2553 return emitOpError("initializer must be result of a "
2554 "spv.SpecConstant or spv.GlobalVariable op");
2555 }
2556 }
2557
2558 return success();
2559}
2560
2561//===----------------------------------------------------------------------===//
2562// spv.GroupBroadcast
2563//===----------------------------------------------------------------------===//
2564
2565LogicalResult spirv::GroupBroadcastOp::verify() {
2566 spirv::Scope scope = execution_scope();
2567 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2568 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2569
2570 if (auto localIdTy = localid().getType().dyn_cast<VectorType>())
2571 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2572 return emitOpError("localid is a vector and can be with only "
2573 " 2 or 3 components, actual number is ")
2574 << localIdTy.getNumElements();
2575
2576 return success();
2577}
2578
2579//===----------------------------------------------------------------------===//
2580// spv.GroupNonUniformBallotOp
2581//===----------------------------------------------------------------------===//
2582
2583LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2584 spirv::Scope scope = execution_scope();
2585 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2586 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2587
2588 return success();
2589}
2590
2591//===----------------------------------------------------------------------===//
2592// spv.GroupNonUniformBroadcast
2593//===----------------------------------------------------------------------===//
2594
2595LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2596 spirv::Scope scope = execution_scope();
2597 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2598 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2599
2600 // SPIR-V spec: "Before version 1.5, Id must come from a
2601 // constant instruction.
2602 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2603 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2604 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2605
2606 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2607 auto *idOp = id().getDefiningOp();
2608 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2609 spirv::ReferenceOfOp>(idOp)) // for spec constant
2610 return emitOpError("id must be the result of a constant op");
2611 }
2612
2613 return success();
2614}
2615
2616//===----------------------------------------------------------------------===//
2617// spv.GroupNonUniformShuffle*
2618//===----------------------------------------------------------------------===//
2619
2620template <typename OpTy>
2621static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
2622 spirv::Scope scope = op.execution_scope();
2623 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2624 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2625
2626 if (op.getOperands().back().getType().isSignedInteger())
2627 return op.emitOpError("second operand must be a singless/unsigned integer");
2628
2629 return success();
2630}
2631
2632LogicalResult spirv::GroupNonUniformShuffleOp::verify() {
2633 return verifyGroupNonUniformShuffleOp(*this);
2634}
2635LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() {
2636 return verifyGroupNonUniformShuffleOp(*this);
2637}
2638LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() {
2639 return verifyGroupNonUniformShuffleOp(*this);
2640}
2641LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
2642 return verifyGroupNonUniformShuffleOp(*this);
2643}
2644
2645//===----------------------------------------------------------------------===//
2646// spv.SubgroupBlockReadINTEL
2647//===----------------------------------------------------------------------===//
2648
2649ParseResult spirv::SubgroupBlockReadINTELOp::parse(OpAsmParser &parser,
2650 OperationState &result) {
2651 // Parse the storage class specification
2652 spirv::StorageClass storageClass;
1
'storageClass' declared without an initial value
2653 OpAsmParser::UnresolvedOperand ptrInfo;
2654 Type elementType;
2655 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2
Calling 'parseEnumStrAttr<mlir::spirv::StorageClass>'
7
Returning from 'parseEnumStrAttr<mlir::spirv::StorageClass>'
8
Taking false branch
2656 parser.parseColon() || parser.parseType(elementType)) {
2657 return failure();
2658 }
2659
2660 auto ptrType = spirv::PointerType::get(elementType, storageClass);
9
2nd function call argument is an uninitialized value
2661 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2662 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2663
2664 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2665 return failure();
2666 }
2667
2668 result.addTypes(elementType);
2669 return success();
2670}
2671
2672void spirv::SubgroupBlockReadINTELOp::print(OpAsmPrinter &printer) {
2673 printer << " " << ptr() << " : " << getType();
2674}
2675
2676LogicalResult spirv::SubgroupBlockReadINTELOp::verify() {
2677 if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2678 return failure();
2679
2680 return success();
2681}
2682
2683//===----------------------------------------------------------------------===//
2684// spv.SubgroupBlockWriteINTEL
2685//===----------------------------------------------------------------------===//
2686
2687ParseResult spirv::SubgroupBlockWriteINTELOp::parse(OpAsmParser &parser,
2688 OperationState &result) {
2689 // Parse the storage class specification
2690 spirv::StorageClass storageClass;
2691 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2692 auto loc = parser.getCurrentLocation();
2693 Type elementType;
2694 if (parseEnumStrAttr(storageClass, parser) ||
2695 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2696 parser.parseType(elementType)) {
2697 return failure();
2698 }
2699
2700 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2701 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2702 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2703
2704 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2705 result.operands)) {
2706 return failure();
2707 }
2708 return success();
2709}
2710
2711void spirv::SubgroupBlockWriteINTELOp::print(OpAsmPrinter &printer) {
2712 printer << " " << ptr() << ", " << value() << " : " << value().getType();
2713}
2714
2715LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() {
2716 if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value())))
2717 return failure();
2718
2719 return success();
2720}
2721
2722//===----------------------------------------------------------------------===//
2723// spv.GroupNonUniformElectOp
2724//===----------------------------------------------------------------------===//
2725
2726LogicalResult spirv::GroupNonUniformElectOp::verify() {
2727 spirv::Scope scope = execution_scope();
2728 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2729 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2730
2731 return success();
2732}
2733
2734//===----------------------------------------------------------------------===//
2735// spv.GroupNonUniformFAddOp
2736//===----------------------------------------------------------------------===//
2737
2738LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2739 return verifyGroupNonUniformArithmeticOp(*this);
2740}
2741
2742ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2743 OperationState &result) {
2744 return parseGroupNonUniformArithmeticOp(parser, result);
2745}
2746void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2747 printGroupNonUniformArithmeticOp(*this, p);
2748}
2749
2750//===----------------------------------------------------------------------===//
2751// spv.GroupNonUniformFMaxOp
2752//===----------------------------------------------------------------------===//
2753
2754LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2755 return verifyGroupNonUniformArithmeticOp(*this);
2756}
2757
2758ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2759 OperationState &result) {
2760 return parseGroupNonUniformArithmeticOp(parser, result);
2761}
2762void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2763 printGroupNonUniformArithmeticOp(*this, p);
2764}
2765
2766//===----------------------------------------------------------------------===//
2767// spv.GroupNonUniformFMinOp
2768//===----------------------------------------------------------------------===//
2769
2770LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2771 return verifyGroupNonUniformArithmeticOp(*this);
2772}
2773
2774ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2775 OperationState &result) {
2776 return parseGroupNonUniformArithmeticOp(parser, result);
2777}
2778void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2779 printGroupNonUniformArithmeticOp(*this, p);
2780}
2781
2782//===----------------------------------------------------------------------===//
2783// spv.GroupNonUniformFMulOp
2784//===----------------------------------------------------------------------===//
2785
2786LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2787 return verifyGroupNonUniformArithmeticOp(*this);
2788}
2789
2790ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2791 OperationState &result) {
2792 return parseGroupNonUniformArithmeticOp(parser, result);
2793}
2794void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2795 printGroupNonUniformArithmeticOp(*this, p);
2796}
2797
2798//===----------------------------------------------------------------------===//
2799// spv.GroupNonUniformIAddOp
2800//===----------------------------------------------------------------------===//
2801
2802LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2803 return verifyGroupNonUniformArithmeticOp(*this);
2804}
2805
2806ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2807 OperationState &result) {
2808 return parseGroupNonUniformArithmeticOp(parser, result);
2809}
2810void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2811 printGroupNonUniformArithmeticOp(*this, p);
2812}
2813
2814//===----------------------------------------------------------------------===//
2815// spv.GroupNonUniformIMulOp
2816//===----------------------------------------------------------------------===//
2817
2818LogicalResult spirv::GroupNonUniformIMulOp::verify() {
2819 return verifyGroupNonUniformArithmeticOp(*this);
2820}
2821
2822ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
2823 OperationState &result) {
2824 return parseGroupNonUniformArithmeticOp(parser, result);
2825}
2826void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
2827 printGroupNonUniformArithmeticOp(*this, p);
2828}
2829
2830//===----------------------------------------------------------------------===//
2831// spv.GroupNonUniformSMaxOp
2832//===----------------------------------------------------------------------===//
2833
2834LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
2835 return verifyGroupNonUniformArithmeticOp(*this);
2836}
2837
2838ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
2839 OperationState &result) {
2840 return parseGroupNonUniformArithmeticOp(parser, result);
2841}
2842void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
2843 printGroupNonUniformArithmeticOp(*this, p);
2844}
2845
2846//===----------------------------------------------------------------------===//
2847// spv.GroupNonUniformSMinOp
2848//===----------------------------------------------------------------------===//
2849
2850LogicalResult spirv::GroupNonUniformSMinOp::verify() {
2851 return verifyGroupNonUniformArithmeticOp(*this);
2852}
2853
2854ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
2855 OperationState &result) {
2856 return parseGroupNonUniformArithmeticOp(parser, result);
2857}
2858void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
2859 printGroupNonUniformArithmeticOp(*this, p);
2860}
2861
2862//===----------------------------------------------------------------------===//
2863// spv.GroupNonUniformUMaxOp
2864//===----------------------------------------------------------------------===//
2865
2866LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
2867 return verifyGroupNonUniformArithmeticOp(*this);
2868}
2869
2870ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
2871 OperationState &result) {
2872 return parseGroupNonUniformArithmeticOp(parser, result);
2873}
2874void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
2875 printGroupNonUniformArithmeticOp(*this, p);
2876}
2877
2878//===----------------------------------------------------------------------===//
2879// spv.GroupNonUniformUMinOp
2880//===----------------------------------------------------------------------===//
2881
2882LogicalResult spirv::GroupNonUniformUMinOp::verify() {
2883 return verifyGroupNonUniformArithmeticOp(*this);
2884}
2885
2886ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
2887 OperationState &result) {
2888 return parseGroupNonUniformArithmeticOp(parser, result);
2889}
2890void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
2891 printGroupNonUniformArithmeticOp(*this, p);
2892}
2893
2894//===----------------------------------------------------------------------===//
2895// spv.IAddCarryOp
2896//===----------------------------------------------------------------------===//
2897
2898LogicalResult spirv::IAddCarryOp::verify() {
2899 auto resultType = getType().cast<spirv::StructType>();
2900 if (resultType.getNumElements() != 2)
2901 return emitOpError("expected result struct type containing two members");
2902
2903 if (!llvm::all_equal({operand1().getType(), operand2().getType(),
2904 resultType.getElementType(0),
2905 resultType.getElementType(1)}))
2906 return emitOpError(
2907 "expected all operand types and struct member types are the same");
2908
2909 return success();
2910}
2911
2912ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
2913 OperationState &result) {
2914 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2915 if (parser.parseOptionalAttrDict(result.attributes) ||
2916 parser.parseOperandList(operands) || parser.parseColon())
2917 return failure();
2918
2919 Type resultType;
2920 SMLoc loc = parser.getCurrentLocation();
2921 if (parser.parseType(resultType))
2922 return failure();
2923
2924 auto structType = resultType.dyn_cast<spirv::StructType>();
2925 if (!structType || structType.getNumElements() != 2)
2926 return parser.emitError(loc, "expected spv.struct type with two members");
2927
2928 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2929 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
2930 return failure();
2931
2932 result.addTypes(resultType);
2933 return success();
2934}
2935
2936void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
2937 printer << ' ';
2938 printer.printOptionalAttrDict((*this)->getAttrs());
2939 printer.printOperands((*this)->getOperands());
2940 printer << " : " << getType();
2941}
2942
2943//===----------------------------------------------------------------------===//
2944// spv.ISubBorrowOp
2945//===----------------------------------------------------------------------===//
2946
2947LogicalResult spirv::ISubBorrowOp::verify() {
2948 auto resultType = getType().cast<spirv::StructType>();
2949 if (resultType.getNumElements() != 2)
2950 return emitOpError("expected result struct type containing two members");
2951
2952 if (!llvm::all_equal({operand1().getType(), operand2().getType(),
2953 resultType.getElementType(0),
2954 resultType.getElementType(1)}))
2955 return emitOpError(
2956 "expected all operand types and struct member types are the same");
2957
2958 return success();
2959}
2960
2961ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
2962 OperationState &result) {
2963 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
2964 if (parser.parseOptionalAttrDict(result.attributes) ||
2965 parser.parseOperandList(operands) || parser.parseColon())
2966 return failure();
2967
2968 Type resultType;
2969 auto loc = parser.getCurrentLocation();
2970 if (parser.parseType(resultType))
2971 return failure();
2972
2973 auto structType = resultType.dyn_cast<spirv::StructType>();
2974 if (!structType || structType.getNumElements() != 2)
2975 return parser.emitError(loc, "expected spv.struct type with two members");
2976
2977 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
2978 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
2979 return failure();
2980
2981 result.addTypes(resultType);
2982 return success();
2983}
2984
2985void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
2986 printer << ' ';
2987 printer.printOptionalAttrDict((*this)->getAttrs());
2988 printer.printOperands((*this)->getOperands());
2989 printer << " : " << getType();
2990}
2991
2992//===----------------------------------------------------------------------===//
2993// spv.LoadOp
2994//===----------------------------------------------------------------------===//
2995
2996void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
2997 Value basePtr, MemoryAccessAttr memoryAccess,
2998 IntegerAttr alignment) {
2999 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3000 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3001 alignment);
3002}
3003
3004ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3005 // Parse the storage class specification
3006 spirv::StorageClass storageClass;
3007 OpAsmParser::UnresolvedOperand ptrInfo;
3008 Type elementType;
3009 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3010 parseMemoryAccessAttributes(parser, result) ||
3011 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3012 parser.parseType(elementType)) {
3013 return failure();
3014 }
3015
3016 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3017 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3018 return failure();
3019 }
3020
3021 result.addTypes(elementType);
3022 return success();
3023}
3024
3025void spirv::LoadOp::print(OpAsmPrinter &printer) {
3026 SmallVector<StringRef, 4> elidedAttrs;
3027 StringRef sc = stringifyStorageClass(
3028 ptr().getType().cast<spirv::PointerType>().getStorageClass());
3029 printer << " \"" << sc << "\" " << ptr();
3030
3031 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3032
3033 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3034 printer << " : " << getType();
3035}
3036
3037LogicalResult spirv::LoadOp::verify() {
3038 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3039 // type with fixed size; i.e., it cannot be, nor include, any
3040 // OpTypeRuntimeArray types."
3041 if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) {
3042 return failure();
3043 }
3044 return verifyMemoryAccessAttribute(*this);
3045}
3046
3047//===----------------------------------------------------------------------===//
3048// spv.mlir.loop
3049//===----------------------------------------------------------------------===//
3050
3051void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3052 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3053 spirv::LoopControl::None));
3054 state.addRegion();
3055}
3056
3057ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3058 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3059 result))
3060 return failure();
3061 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3062}
3063
3064void spirv::LoopOp::print(OpAsmPrinter &printer) {
3065 auto control = loop_control();
3066 if (control != spirv::LoopControl::None)
3067 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3068 printer << ' ';
3069 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3070 /*printBlockTerminators=*/true);
3071}
3072
3073/// Returns true if the given `srcBlock` contains only one `spv.Branch` to the
3074/// given `dstBlock`.
3075static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3076 // Check that there is only one op in the `srcBlock`.
3077 if (!llvm::hasSingleElement(srcBlock))
3078 return false;
3079
3080 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3081 return branchOp && branchOp.getSuccessor() == &dstBlock;
3082}
3083
3084LogicalResult spirv::LoopOp::verifyRegions() {
3085 auto *op = getOperation();
3086
3087 // We need to verify that the blocks follow the following layout:
3088 //
3089 // +-------------+
3090 // | entry block |
3091 // +-------------+
3092 // |
3093 // v
3094 // +-------------+
3095 // | loop header | <-----+
3096 // +-------------+ |
3097 // |
3098 // ... |
3099 // \ | / |
3100 // v |
3101 // +---------------+ |
3102 // | loop continue | -----+
3103 // +---------------+
3104 //
3105 // ...
3106 // \ | /
3107 // v
3108 // +-------------+
3109 // | merge block |
3110 // +-------------+
3111
3112 auto &region = op->getRegion(0);
3113 // Allow empty region as a degenerated case, which can come from
3114 // optimizations.
3115 if (region.empty())
3116 return success();
3117
3118 // The last block is the merge block.
3119 Block &merge = region.back();
3120 if (!isMergeBlock(merge))
3121 return emitOpError(
3122 "last block must be the merge block with only one 'spv.mlir.merge' op");
3123
3124 if (std::next(region.begin()) == region.end())
3125 return emitOpError(
3126 "must have an entry block branching to the loop header block");
3127 // The first block is the entry block.
3128 Block &entry = region.front();
3129
3130 if (std::next(region.begin(), 2) == region.end())
3131 return emitOpError(
3132 "must have a loop header block branched from the entry block");
3133 // The second block is the loop header block.
3134 Block &header = *std::next(region.begin(), 1);
3135
3136 if (!hasOneBranchOpTo(entry, header))
3137 return emitOpError(
3138 "entry block must only have one 'spv.Branch' op to the second block");
3139
3140 if (std::next(region.begin(), 3) == region.end())
3141 return emitOpError(
3142 "requires a loop continue block branching to the loop header block");
3143 // The second to last block is the loop continue block.
3144 Block &cont = *std::prev(region.end(), 2);
3145
3146 // Make sure that we have a branch from the loop continue block to the loop
3147 // header block.
3148 if (llvm::none_of(
3149 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3150 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3151 return emitOpError("second to last block must be the loop continue "
3152 "block that branches to the loop header block");
3153
3154 // Make sure that no other blocks (except the entry and loop continue block)
3155 // branches to the loop header block.
3156 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3157 std::prev(region.end(), 2))) {
3158 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3159 if (block.getSuccessor(i) == &header) {
3160 return emitOpError("can only have the entry and loop continue "
3161 "block branching to the loop header block");
3162 }
3163 }
3164 }
3165
3166 return success();
3167}
3168
3169Block *spirv::LoopOp::getEntryBlock() {
3170 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3170, __extension__
__PRETTY_FUNCTION__))
;
3171 return &body().front();
3172}
3173
3174Block *spirv::LoopOp::getHeaderBlock() {
3175 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3175, __extension__
__PRETTY_FUNCTION__))
;
3176 // The second block is the loop header block.
3177 return &*std::next(body().begin());
3178}
3179
3180Block *spirv::LoopOp::getContinueBlock() {
3181 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3181, __extension__
__PRETTY_FUNCTION__))
;
3182 // The second to last block is the loop continue block.
3183 return &*std::prev(body().end(), 2);
3184}
3185
3186Block *spirv::LoopOp::getMergeBlock() {
3187 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3187, __extension__
__PRETTY_FUNCTION__))
;
3188 // The last block is the loop merge block.
3189 return &body().back();
3190}
3191
3192void spirv::LoopOp::addEntryAndMergeBlock() {
3193 assert(body().empty() && "entry and merge block already exist")(static_cast <bool> (body().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("body().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3193, __extension__
__PRETTY_FUNCTION__))
;
3194 body().push_back(new Block());
3195 auto *mergeBlock = new Block();
3196 body().push_back(mergeBlock);
3197 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3198
3199 // Add a spv.mlir.merge op into the merge block.
3200 builder.create<spirv::MergeOp>(getLoc());
3201}
3202
3203//===----------------------------------------------------------------------===//
3204// spv.MemoryBarrierOp
3205//===----------------------------------------------------------------------===//
3206
3207LogicalResult spirv::MemoryBarrierOp::verify() {
3208 return verifyMemorySemantics(getOperation(), memory_semantics());
3209}
3210
3211//===----------------------------------------------------------------------===//
3212// spv.mlir.merge
3213//===----------------------------------------------------------------------===//
3214
3215LogicalResult spirv::MergeOp::verify() {
3216 auto *parentOp = (*this)->getParentOp();
3217 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3218 return emitOpError(
3219 "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'");
3220
3221 // TODO: This check should be done in `verifyRegions` of parent op.
3222 Block &parentLastBlock = (*this)->getParentRegion()->back();
3223 if (getOperation() != parentLastBlock.getTerminator())
3224 return emitOpError("can only be used in the last block of "
3225 "'spv.mlir.selection' or 'spv.mlir.loop'");
3226 return success();
3227}
3228
3229//===----------------------------------------------------------------------===//
3230// spv.module
3231//===----------------------------------------------------------------------===//
3232
3233void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3234 Optional<StringRef> name) {
3235 OpBuilder::InsertionGuard guard(builder);
3236 builder.createBlock(state.addRegion());
3237 if (name) {
3238 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3239 builder.getStringAttr(*name));
3240 }
3241}
3242
3243void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3244 spirv::AddressingModel addressingModel,
3245 spirv::MemoryModel memoryModel,
3246 Optional<VerCapExtAttr> vceTriple,
3247 Optional<StringRef> name) {
3248 state.addAttribute(
3249 "addressing_model",
3250 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3251 state.addAttribute("memory_model",
3252 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3253 OpBuilder::InsertionGuard guard(builder);
3254 builder.createBlock(state.addRegion());
3255 if (vceTriple)
3256 state.addAttribute(getVCETripleAttrName(), *vceTriple);
3257 if (name)
3258 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3259 builder.getStringAttr(*name));
3260}
3261
3262ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3263 OperationState &result) {
3264 Region *body = result.addRegion();
3265
3266 // If the name is present, parse it.
3267 StringAttr nameAttr;
3268 (void)parser.parseOptionalSymbolName(
3269 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3270
3271 // Parse attributes
3272 spirv::AddressingModel addrModel;
3273 spirv::MemoryModel memoryModel;
3274 if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3275 result) ||
3276 ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3277 result))
3278 return failure();
3279
3280 if (succeeded(parser.parseOptionalKeyword("requires"))) {
3281 spirv::VerCapExtAttr vceTriple;
3282 if (parser.parseAttribute(vceTriple,
3283 spirv::ModuleOp::getVCETripleAttrName(),
3284 result.attributes))
3285 return failure();
3286 }
3287
3288 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3289 parser.parseRegion(*body, /*arguments=*/{}))
3290 return failure();
3291
3292 // Make sure we have at least one block.
3293 if (body->empty())
3294 body->push_back(new Block());
3295
3296 return success();
3297}
3298
3299void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3300 if (Optional<StringRef> name = getName()) {
3301 printer << ' ';
3302 printer.printSymbolName(*name);
3303 }
3304
3305 SmallVector<StringRef, 2> elidedAttrs;
3306
3307 printer << " " << spirv::stringifyAddressingModel(addressing_model()) << " "
3308 << spirv::stringifyMemoryModel(memory_model());
3309 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3310 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3311 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3312 mlir::SymbolTable::getSymbolAttrName()});
3313
3314 if (Optional<spirv::VerCapExtAttr> triple = vce_triple()) {
3315 printer << " requires " << *triple;
3316 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3317 }
3318
3319 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3320 printer << ' ';
3321 printer.printRegion(getRegion());
3322}
3323
3324LogicalResult spirv::ModuleOp::verifyRegions() {
3325 Dialect *dialect = (*this)->getDialect();
3326 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3327 entryPoints;
3328 mlir::SymbolTable table(*this);
3329
3330 for (auto &op : *getBody()) {
3331 if (op.getDialect() != dialect)
3332 return op.emitError("'spv.module' can only contain spv.* ops");
3333
3334 // For EntryPoint op, check that the function and execution model is not
3335 // duplicated in EntryPointOps. Also verify that the interface specified
3336 // comes from globalVariables here to make this check cheaper.
3337 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3338 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.fn());
3339 if (!funcOp) {
3340 return entryPointOp.emitError("function '")
3341 << entryPointOp.fn() << "' not found in 'spv.module'";
3342 }
3343 if (auto interface = entryPointOp.interface()) {
3344 for (Attribute varRef : interface) {
3345 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3346 if (!varSymRef) {
3347 return entryPointOp.emitError(
3348 "expected symbol reference for interface "
3349 "specification instead of '")
3350 << varRef;
3351 }
3352 auto variableOp =
3353 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3354 if (!variableOp) {
3355 return entryPointOp.emitError("expected spv.GlobalVariable "
3356 "symbol reference instead of'")
3357 << varSymRef << "'";
3358 }
3359 }
3360 }
3361
3362 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3363 funcOp, entryPointOp.execution_model());
3364 auto entryPtIt = entryPoints.find(key);
3365 if (entryPtIt != entryPoints.end()) {
3366 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3367 }
3368 entryPoints[key] = entryPointOp;
3369 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3370 if (funcOp.isExternal())
3371 return op.emitError("'spv.module' cannot contain external functions");
3372
3373 // TODO: move this check to spv.func.
3374 for (auto &block : funcOp)
3375 for (auto &op : block) {
3376 if (op.getDialect() != dialect)
3377 return op.emitError(
3378 "functions in 'spv.module' can only contain spv.* ops");
3379 }
3380 }
3381 }
3382
3383 return success();
3384}
3385
3386//===----------------------------------------------------------------------===//
3387// spv.mlir.referenceof
3388//===----------------------------------------------------------------------===//
3389
3390LogicalResult spirv::ReferenceOfOp::verify() {
3391 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3392 (*this)->getParentOp(), spec_constAttr());
3393 Type constType;
3394
3395 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3396 if (specConstOp)
3397 constType = specConstOp.default_value().getType();
3398
3399 auto specConstCompositeOp =
3400 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3401 if (specConstCompositeOp)
3402 constType = specConstCompositeOp.type();
3403
3404 if (!specConstOp && !specConstCompositeOp)
3405 return emitOpError(
3406 "expected spv.SpecConstant or spv.SpecConstantComposite symbol");
3407
3408 if (reference().getType() != constType)
3409 return emitOpError("result type mismatch with the referenced "
3410 "specialization constant's type");
3411
3412 return success();
3413}
3414
3415//===----------------------------------------------------------------------===//
3416// spv.Return
3417//===----------------------------------------------------------------------===//
3418
3419LogicalResult spirv::ReturnOp::verify() {
3420 // Verification is performed in spv.func op.
3421 return success();
3422}
3423
3424//===----------------------------------------------------------------------===//
3425// spv.ReturnValue
3426//===----------------------------------------------------------------------===//
3427
3428LogicalResult spirv::ReturnValueOp::verify() {
3429 // Verification is performed in spv.func op.
3430 return success();
3431}
3432
3433//===----------------------------------------------------------------------===//
3434// spv.Select
3435//===----------------------------------------------------------------------===//
3436
3437LogicalResult spirv::SelectOp::verify() {
3438 if (auto conditionTy = condition().getType().dyn_cast<VectorType>()) {
3439 auto resultVectorTy = result().getType().dyn_cast<VectorType>();
3440 if (!resultVectorTy) {
3441 return emitOpError("result expected to be of vector type when "
3442 "condition is of vector type");
3443 }
3444 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3445 return emitOpError("result should have the same number of elements as "
3446 "the condition when condition is of vector type");
3447 }
3448 }
3449 return success();
3450}
3451
3452//===----------------------------------------------------------------------===//
3453// spv.mlir.selection
3454//===----------------------------------------------------------------------===//
3455
3456ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3457 OperationState &result) {
3458 if (parseControlAttribute<spirv::SelectionControlAttr,
3459 spirv::SelectionControl>(parser, result))
3460 return failure();
3461 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3462}
3463
3464void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3465 auto control = selection_control();
3466 if (control != spirv::SelectionControl::None)
3467 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3468 printer << ' ';
3469 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3470 /*printBlockTerminators=*/true);
3471}
3472
3473LogicalResult spirv::SelectionOp::verifyRegions() {
3474 auto *op = getOperation();
3475
3476 // We need to verify that the blocks follow the following layout:
3477 //
3478 // +--------------+
3479 // | header block |
3480 // +--------------+
3481 // / | \
3482 // ...
3483 //
3484 //
3485 // +---------+ +---------+ +---------+
3486 // | case #0 | | case #1 | | case #2 | ...
3487 // +---------+ +---------+ +---------+
3488 //
3489 //
3490 // ...
3491 // \ | /
3492 // v
3493 // +-------------+
3494 // | merge block |
3495 // +-------------+
3496
3497 auto &region = op->getRegion(0);
3498 // Allow empty region as a degenerated case, which can come from
3499 // optimizations.
3500 if (region.empty())
3501 return success();
3502
3503 // The last block is the merge block.
3504 if (!isMergeBlock(region.back()))
3505 return emitOpError(
3506 "last block must be the merge block with only one 'spv.mlir.merge' op");
3507
3508 if (std::next(region.begin()) == region.end())
3509 return emitOpError("must have a selection header block");
3510
3511 return success();
3512}
3513
3514Block *spirv::SelectionOp::getHeaderBlock() {
3515 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3515, __extension__
__PRETTY_FUNCTION__))
;
3516 // The first block is the loop header block.
3517 return &body().front();
3518}
3519
3520Block *spirv::SelectionOp::getMergeBlock() {
3521 assert(!body().empty() && "op region should not be empty!")(static_cast <bool> (!body().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!body().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3521, __extension__
__PRETTY_FUNCTION__))
;
3522 // The last block is the loop merge block.
3523 return &body().back();
3524}
3525
3526void spirv::SelectionOp::addMergeBlock() {
3527 assert(body().empty() && "entry and merge block already exist")(static_cast <bool> (body().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("body().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3527, __extension__
__PRETTY_FUNCTION__))
;
3528 auto *mergeBlock = new Block();
3529 body().push_back(mergeBlock);
3530 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3531
3532 // Add a spv.mlir.merge op into the merge block.
3533 builder.create<spirv::MergeOp>(getLoc());
3534}
3535
3536spirv::SelectionOp spirv::SelectionOp::createIfThen(
3537 Location loc, Value condition,
3538 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3539 auto selectionOp =
3540 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3541
3542 selectionOp.addMergeBlock();
3543 Block *mergeBlock = selectionOp.getMergeBlock();
3544 Block *thenBlock = nullptr;
3545
3546 // Build the "then" block.
3547 {
3548 OpBuilder::InsertionGuard guard(builder);
3549 thenBlock = builder.createBlock(mergeBlock);
3550 thenBody(builder);
3551 builder.create<spirv::BranchOp>(loc, mergeBlock);
3552 }
3553
3554 // Build the header block.
3555 {
3556 OpBuilder::InsertionGuard guard(builder);
3557 builder.createBlock(thenBlock);
3558 builder.create<spirv::BranchConditionalOp>(
3559 loc, condition, thenBlock,
3560 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3561 /*falseArguments=*/ArrayRef<Value>());
3562 }
3563
3564 return selectionOp;
3565}
3566
3567//===----------------------------------------------------------------------===//
3568// spv.SpecConstant
3569//===----------------------------------------------------------------------===//
3570
3571ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3572 OperationState &result) {
3573 StringAttr nameAttr;
3574 Attribute valueAttr;
3575
3576 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3577 result.attributes))
3578 return failure();
3579
3580 // Parse optional spec_id.
3581 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3582 IntegerAttr specIdAttr;
3583 if (parser.parseLParen() ||
3584 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3585 parser.parseRParen())
3586 return failure();
3587 }
3588
3589 if (parser.parseEqual() ||
3590 parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3591 result.attributes))
3592 return failure();
3593
3594 return success();
3595}
3596
3597void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3598 printer << ' ';
3599 printer.printSymbolName(sym_name());
3600 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3601 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3602 printer << " = " << default_value();
3603}
3604
3605LogicalResult spirv::SpecConstantOp::verify() {
3606 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3607 if (specID.getValue().isNegative())
3608 return emitOpError("SpecId cannot be negative");
3609
3610 auto value = default_value();
3611 if (value.isa<IntegerAttr, FloatAttr>()) {
3612 // Make sure bitwidth is allowed.
3613 if (!value.getType().isa<spirv::SPIRVType>())
3614 return emitOpError("default value bitwidth disallowed");
3615 return success();
3616 }
3617 return emitOpError(
3618 "default value can only be a bool, integer, or float scalar");
3619}
3620
3621//===----------------------------------------------------------------------===//
3622// spv.StoreOp
3623//===----------------------------------------------------------------------===//
3624
3625ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3626 // Parse the storage class specification
3627 spirv::StorageClass storageClass;
3628 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3629 auto loc = parser.getCurrentLocation();
3630 Type elementType;
3631 if (parseEnumStrAttr(storageClass, parser) ||
3632 parser.parseOperandList(operandInfo, 2) ||
3633 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3634 parser.parseType(elementType)) {
3635 return failure();
3636 }
3637
3638 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3639 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3640 result.operands)) {
3641 return failure();
3642 }
3643 return success();
3644}
3645
3646void spirv::StoreOp::print(OpAsmPrinter &printer) {
3647 SmallVector<StringRef, 4> elidedAttrs;
3648 StringRef sc = stringifyStorageClass(
3649 ptr().getType().cast<spirv::PointerType>().getStorageClass());
3650 printer << " \"" << sc << "\" " << ptr() << ", " << value();
3651
3652 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3653
3654 printer << " : " << value().getType();
3655 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3656}
3657
3658LogicalResult spirv::StoreOp::verify() {
3659 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3660 // OpTypePointer whose Type operand is the same as the type of Object."
3661 if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value())))
3662 return failure();
3663 return verifyMemoryAccessAttribute(*this);
3664}
3665
3666//===----------------------------------------------------------------------===//
3667// spv.Unreachable
3668//===----------------------------------------------------------------------===//
3669
3670LogicalResult spirv::UnreachableOp::verify() {
3671 auto *block = (*this)->getBlock();
3672 // Fast track: if this is in entry block, its invalid. Otherwise, if no
3673 // predecessors, it's valid.
3674 if (block->isEntryBlock())
3675 return emitOpError("cannot be used in reachable block");
3676 if (block->hasNoPredecessors())
3677 return success();
3678
3679 // TODO: further verification needs to analyze reachability from
3680 // the entry block.
3681
3682 return success();
3683}
3684
3685//===----------------------------------------------------------------------===//
3686// spv.Variable
3687//===----------------------------------------------------------------------===//
3688
3689ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3690 OperationState &result) {
3691 // Parse optional initializer
3692 Optional<OpAsmParser::UnresolvedOperand> initInfo;
3693 if (succeeded(parser.parseOptionalKeyword("init"))) {
3694 initInfo = OpAsmParser::UnresolvedOperand();
3695 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3696 parser.parseRParen())
3697 return failure();
3698 }
3699
3700 if (parseVariableDecorations(parser, result)) {
3701 return failure();
3702 }
3703
3704 // Parse result pointer type
3705 Type type;
3706 if (parser.parseColon())
3707 return failure();
3708 auto loc = parser.getCurrentLocation();
3709 if (parser.parseType(type))
3710 return failure();
3711
3712 auto ptrType = type.dyn_cast<spirv::PointerType>();
3713 if (!ptrType)
3714 return parser.emitError(loc, "expected spv.ptr type");
3715 result.addTypes(ptrType);
3716
3717 // Resolve the initializer operand
3718 if (initInfo) {
3719 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3720 result.operands))
3721 return failure();
3722 }
3723
3724 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3725 ptrType.getStorageClass());
3726 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3727
3728 return success();
3729}
3730
3731void spirv::VariableOp::print(OpAsmPrinter &printer) {
3732 SmallVector<StringRef, 4> elidedAttrs{
3733 spirv::attributeName<spirv::StorageClass>()};
3734 // Print optional initializer
3735 if (getNumOperands() != 0)
3736 printer << " init(" << initializer() << ")";
3737
3738 printVariableDecorations(*this, printer, elidedAttrs);
3739 printer << " : " << getType();
3740}
3741
3742LogicalResult spirv::VariableOp::verify() {
3743 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3744 // object. It cannot be Generic. It must be the same as the Storage Class
3745 // operand of the Result Type."
3746 if (storage_class() != spirv::StorageClass::Function) {
3747 return emitOpError(
3748 "can only be used to model function-level variables. Use "
3749 "spv.GlobalVariable for module-level variables.");
3750 }
3751
3752 auto pointerType = pointer().getType().cast<spirv::PointerType>();
3753 if (storage_class() != pointerType.getStorageClass())
3754 return emitOpError(
3755 "storage class must match result pointer's storage class");
3756
3757 if (getNumOperands() != 0) {
3758 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3759 // a global (module scope) OpVariable instruction".
3760 auto *initOp = getOperand(0).getDefiningOp();
3761 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3762 spirv::ReferenceOfOp, // for spec constant
3763 spirv::AddressOfOp>(initOp))
3764 return emitOpError("initializer must be the result of a "
3765 "constant or spv.GlobalVariable op");
3766 }
3767
3768 // TODO: generate these strings using ODS.
3769 auto *op = getOperation();
3770 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3771 stringifyDecoration(spirv::Decoration::DescriptorSet));
3772 auto bindingName = llvm::convertToSnakeFromCamelCase(
3773 stringifyDecoration(spirv::Decoration::Binding));
3774 auto builtInName = llvm::convertToSnakeFromCamelCase(
3775 stringifyDecoration(spirv::Decoration::BuiltIn));
3776
3777 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3778 if (op->getAttr(attr))
3779 return emitOpError("cannot have '")
3780 << attr << "' attribute (only allowed in spv.GlobalVariable)";
3781 }
3782
3783 return success();
3784}
3785
3786//===----------------------------------------------------------------------===//
3787// spv.VectorShuffle
3788//===----------------------------------------------------------------------===//
3789
3790LogicalResult spirv::VectorShuffleOp::verify() {
3791 VectorType resultType = getType().cast<VectorType>();
3792
3793 size_t numResultElements = resultType.getNumElements();
3794 if (numResultElements != components().size())
3795 return emitOpError("result type element count (")
3796 << numResultElements
3797 << ") mismatch with the number of component selectors ("
3798 << components().size() << ")";
3799
3800 size_t totalSrcElements =
3801 vector1().getType().cast<VectorType>().getNumElements() +
3802 vector2().getType().cast<VectorType>().getNumElements();
3803
3804 for (const auto &selector : components().getAsValueRange<IntegerAttr>()) {
3805 uint32_t index = selector.getZExtValue();
3806 if (index >= totalSrcElements &&
3807 index != std::numeric_limits<uint32_t>().max())
3808 return emitOpError("component selector ")
3809 << index << " out of range: expected to be in [0, "
3810 << totalSrcElements << ") or 0xffffffff";
3811 }
3812 return success();
3813}
3814
3815//===----------------------------------------------------------------------===//
3816// spv.CooperativeMatrixLoadNV
3817//===----------------------------------------------------------------------===//
3818
3819ParseResult spirv::CooperativeMatrixLoadNVOp::parse(OpAsmParser &parser,
3820 OperationState &result) {
3821 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3822 Type strideType = parser.getBuilder().getIntegerType(32);
3823 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3824 Type ptrType;
3825 Type elementType;
3826 if (parser.parseOperandList(operandInfo, 3) ||
3827 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3828 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3829 return failure();
3830 }
3831 if (parser.resolveOperands(operandInfo,
3832 {ptrType, strideType, columnMajorType},
3833 parser.getNameLoc(), result.operands)) {
3834 return failure();
3835 }
3836
3837 result.addTypes(elementType);
3838 return success();
3839}
3840
3841void spirv::CooperativeMatrixLoadNVOp::print(OpAsmPrinter &printer) {
3842 printer << " " << pointer() << ", " << stride() << ", " << columnmajor();
3843 // Print optional memory access attribute.
3844 if (auto memAccess = memory_access())
3845 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3846 printer << " : " << pointer().getType() << " as " << getType();
3847}
3848
3849static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
3850 Type coopMatrix) {
3851 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3852 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3853 return op->emitError(
3854 "Pointer must point to a scalar or vector type but provided ")
3855 << pointeeType;
3856 spirv::StorageClass storage =
3857 pointer.cast<spirv::PointerType>().getStorageClass();
3858 if (storage != spirv::StorageClass::Workgroup &&
3859 storage != spirv::StorageClass::StorageBuffer &&
3860 storage != spirv::StorageClass::PhysicalStorageBuffer)
3861 return op->emitError(
3862 "Pointer storage class must be Workgroup, StorageBuffer or "
3863 "PhysicalStorageBufferEXT but provided ")
3864 << stringifyStorageClass(storage);
3865 return success();
3866}
3867
3868LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() {
3869 return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3870 result().getType());
3871}
3872
3873//===----------------------------------------------------------------------===//
3874// spv.CooperativeMatrixStoreNV
3875//===----------------------------------------------------------------------===//
3876
3877ParseResult spirv::CooperativeMatrixStoreNVOp::parse(OpAsmParser &parser,
3878 OperationState &result) {
3879 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
3880 Type strideType = parser.getBuilder().getIntegerType(32);
3881 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3882 Type ptrType;
3883 Type elementType;
3884 if (parser.parseOperandList(operandInfo, 4) ||
3885 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3886 parser.parseType(ptrType) || parser.parseComma() ||
3887 parser.parseType(elementType)) {
3888 return failure();
3889 }
3890 if (parser.resolveOperands(
3891 operandInfo, {ptrType, elementType, strideType, columnMajorType},
3892 parser.getNameLoc(), result.operands)) {
3893 return failure();
3894 }
3895
3896 return success();
3897}
3898
3899void spirv::CooperativeMatrixStoreNVOp::print(OpAsmPrinter &printer) {
3900 printer << " " << pointer() << ", " << object() << ", " << stride() << ", "
3901 << columnmajor();
3902 // Print optional memory access attribute.
3903 if (auto memAccess = memory_access())
3904 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3905 printer << " : " << pointer().getType() << ", " << getOperand(1).getType();
3906}
3907
3908LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() {
3909 return verifyPointerAndCoopMatrixType(*this, pointer().getType(),
3910 object().getType());
3911}
3912
3913//===----------------------------------------------------------------------===//
3914// spv.CooperativeMatrixMulAddNV
3915//===----------------------------------------------------------------------===//
3916
3917static LogicalResult
3918verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) {
3919 if (op.c().getType() != op.result().getType())
3920 return op.emitOpError("result and third operand must have the same type");
3921 auto typeA = op.a().getType().cast<spirv::CooperativeMatrixNVType>();
3922 auto typeB = op.b().getType().cast<spirv::CooperativeMatrixNVType>();
3923 auto typeC = op.c().getType().cast<spirv::CooperativeMatrixNVType>();
3924 auto typeR = op.result().getType().cast<spirv::CooperativeMatrixNVType>();
3925 if (typeA.getRows() != typeR.getRows() ||
3926 typeA.getColumns() != typeB.getRows() ||
3927 typeB.getColumns() != typeR.getColumns())
3928 return op.emitOpError("matrix size must match");
3929 if (typeR.getScope() != typeA.getScope() ||
3930 typeR.getScope() != typeB.getScope() ||
3931 typeR.getScope() != typeC.getScope())
3932 return op.emitOpError("matrix scope must match");
3933 if (typeA.getElementType() != typeB.getElementType() ||
3934 typeR.getElementType() != typeC.getElementType())
3935 return op.emitOpError("matrix element type must match");
3936 return success();
3937}
3938
3939LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() {
3940 return verifyCoopMatrixMulAdd(*this);
3941}
3942
3943static LogicalResult
3944verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
3945 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
3946 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
3947 return op->emitError(
3948 "Pointer must point to a scalar or vector type but provided ")
3949 << pointeeType;
3950 spirv::StorageClass storage =
3951 pointer.cast<spirv::PointerType>().getStorageClass();
3952 if (storage != spirv::StorageClass::Workgroup &&
3953 storage != spirv::StorageClass::CrossWorkgroup &&
3954 storage != spirv::StorageClass::UniformConstant &&
3955 storage != spirv::StorageClass::Generic)
3956 return op->emitError("Pointer storage class must be Workgroup or "
3957 "CrossWorkgroup but provided ")
3958 << stringifyStorageClass(storage);
3959 return success();
3960}
3961
3962//===----------------------------------------------------------------------===//
3963// spv.JointMatrixLoadINTEL
3964//===----------------------------------------------------------------------===//
3965
3966LogicalResult spirv::JointMatrixLoadINTELOp::verify() {
3967 return verifyPointerAndJointMatrixType(*this, pointer().getType(),
3968 result().getType());
3969}
3970
3971//===----------------------------------------------------------------------===//
3972// spv.JointMatrixStoreINTEL
3973//===----------------------------------------------------------------------===//
3974
3975LogicalResult spirv::JointMatrixStoreINTELOp::verify() {
3976 return verifyPointerAndJointMatrixType(*this, pointer().getType(),
3977 object().getType());
3978}
3979
3980//===----------------------------------------------------------------------===//
3981// spv.JointMatrixMadINTEL
3982//===----------------------------------------------------------------------===//
3983
3984static LogicalResult verifyJointMatrixMad(spirv::JointMatrixMadINTELOp op) {
3985 if (op.c().getType() != op.result().getType())
3986 return op.emitOpError("result and third operand must have the same type");
3987 auto typeA = op.a().getType().cast<spirv::JointMatrixINTELType>();
3988 auto typeB = op.b().getType().cast<spirv::JointMatrixINTELType>();
3989 auto typeC = op.c().getType().cast<spirv::JointMatrixINTELType>();
3990 auto typeR = op.result().getType().cast<spirv::JointMatrixINTELType>();
3991 if (typeA.getRows() != typeR.getRows() ||
3992 typeA.getColumns() != typeB.getRows() ||
3993 typeB.getColumns() != typeR.getColumns())
3994 return op.emitOpError("matrix size must match");
3995 if (typeR.getScope() != typeA.getScope() ||
3996 typeR.getScope() != typeB.getScope() ||
3997 typeR.getScope() != typeC.getScope())
3998 return op.emitOpError("matrix scope must match");
3999 if (typeA.getElementType() != typeB.getElementType() ||
4000 typeR.getElementType() != typeC.getElementType())
4001 return op.emitOpError("matrix element type must match");
4002 return success();
4003}
4004
4005LogicalResult spirv::JointMatrixMadINTELOp::verify() {
4006 return verifyJointMatrixMad(*this);
4007}
4008
4009//===----------------------------------------------------------------------===//
4010// spv.MatrixTimesScalar
4011//===----------------------------------------------------------------------===//
4012
4013LogicalResult spirv::MatrixTimesScalarOp::verify() {
4014 // We already checked that result and matrix are both of matrix type in the
4015 // auto-generated verify method.
4016
4017 auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
4018 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4019
4020 // Check that the scalar type is the same as the matrix element type.
4021 if (scalar().getType() != inputMatrix.getElementType())
4022 return emitError("input matrix components' type and scaling value must "
4023 "have the same type");
4024
4025 // Note that the next three checks could be done using the AllTypesMatch
4026 // trait in the Op definition file but it generates a vague error message.
4027
4028 // Check that the input and result matrices have the same columns' count
4029 if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns())
4030 return emitError("input and result matrices must have the same "
4031 "number of columns");
4032
4033 // Check that the input and result matrices' have the same rows count
4034 if (inputMatrix.getNumRows() != resultMatrix.getNumRows())
4035 return emitError("input and result matrices' columns must have "
4036 "the same size");
4037
4038 // Check that the input and result matrices' have the same component type
4039 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4040 return emitError("input and result matrices' columns must have "
4041 "the same component type");
4042
4043 return success();
4044}
4045
4046//===----------------------------------------------------------------------===//
4047// spv.CopyMemory
4048//===----------------------------------------------------------------------===//
4049
4050void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
4051 printer << ' ';
4052
4053 StringRef targetStorageClass = stringifyStorageClass(
4054 target().getType().cast<spirv::PointerType>().getStorageClass());
4055 printer << " \"" << targetStorageClass << "\" " << target() << ", ";
4056
4057 StringRef sourceStorageClass = stringifyStorageClass(
4058 source().getType().cast<spirv::PointerType>().getStorageClass());
4059 printer << " \"" << sourceStorageClass << "\" " << source();
4060
4061 SmallVector<StringRef, 4> elidedAttrs;
4062 printMemoryAccessAttribute(*this, printer, elidedAttrs);
4063 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4064 source_memory_access(), source_alignment());
4065
4066 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4067
4068 Type pointeeType =
4069 target().getType().cast<spirv::PointerType>().getPointeeType();
4070 printer << " : " << pointeeType;
4071}
4072
4073ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4074 OperationState &result) {
4075 spirv::StorageClass targetStorageClass;
4076 OpAsmParser::UnresolvedOperand targetPtrInfo;
4077
4078 spirv::StorageClass sourceStorageClass;
4079 OpAsmParser::UnresolvedOperand sourcePtrInfo;
4080
4081 Type elementType;
4082
4083 if (parseEnumStrAttr(targetStorageClass, parser) ||
4084 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4085 parseEnumStrAttr(sourceStorageClass, parser) ||
4086 parser.parseOperand(sourcePtrInfo) ||
4087 parseMemoryAccessAttributes(parser, result)) {
4088 return failure();
4089 }
4090
4091 if (!parser.parseOptionalComma()) {
4092 // Parse 2nd memory access attributes.
4093 if (parseSourceMemoryAccessAttributes(parser, result)) {
4094 return failure();
4095 }
4096 }
4097
4098 if (parser.parseColon() || parser.parseType(elementType))
4099 return failure();
4100
4101 if (parser.parseOptionalAttrDict(result.attributes))
4102 return failure();
4103
4104 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4105 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
4106
4107 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4108 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4109 return failure();
4110 }
4111
4112 return success();
4113}
4114
4115LogicalResult spirv::CopyMemoryOp::verify() {
4116 Type targetType =
4117 target().getType().cast<spirv::PointerType>().getPointeeType();
4118
4119 Type sourceType =
4120 source().getType().cast<spirv::PointerType>().getPointeeType();
4121
4122 if (targetType != sourceType)
4123 return emitOpError("both operands must be pointers to the same type");
4124
4125 if (failed(verifyMemoryAccessAttribute(*this)))
4126 return failure();
4127
4128 // TODO - According to the spec:
4129 //
4130 // If two masks are present, the first applies to Target and cannot include
4131 // MakePointerVisible, and the second applies to Source and cannot include
4132 // MakePointerAvailable.
4133 //
4134 // Add such verification here.
4135
4136 return verifySourceMemoryAccessAttribute(*this);
4137}
4138
4139//===----------------------------------------------------------------------===//
4140// spv.Transpose
4141//===----------------------------------------------------------------------===//
4142
4143LogicalResult spirv::TransposeOp::verify() {
4144 auto inputMatrix = matrix().getType().cast<spirv::MatrixType>();
4145 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4146
4147 // Verify that the input and output matrices have correct shapes.
4148 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4149 return emitError("input matrix rows count must be equal to "
4150 "output matrix columns count");
4151
4152 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4153 return emitError("input matrix columns count must be equal to "
4154 "output matrix rows count");
4155
4156 // Verify that the input and output matrices have the same component type
4157 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4158 return emitError("input and output matrices must have the same "
4159 "component type");
4160
4161 return success();
4162}
4163
4164//===----------------------------------------------------------------------===//
4165// spv.MatrixTimesMatrix
4166//===----------------------------------------------------------------------===//
4167
4168LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4169 auto leftMatrix = leftmatrix().getType().cast<spirv::MatrixType>();
4170 auto rightMatrix = rightmatrix().getType().cast<spirv::MatrixType>();
4171 auto resultMatrix = result().getType().cast<spirv::MatrixType>();
4172
4173 // left matrix columns' count and right matrix rows' count must be equal
4174 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4175 return emitError("left matrix columns' count must be equal to "
4176 "the right matrix rows' count");
4177
4178 // right and result matrices columns' count must be the same
4179 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4180 return emitError(
4181 "right and result matrices must have equal columns' count");
4182
4183 // right and result matrices component type must be the same
4184 if (rightMatrix.getElementType() != resultMatrix.getElementType())
4185 return emitError("right and result matrices' component type must"
4186 " be the same");
4187
4188 // left and result matrices component type must be the same
4189 if (leftMatrix.getElementType() != resultMatrix.getElementType())
4190 return emitError("left and result matrices' component type"
4191 " must be the same");
4192
4193 // left and result matrices rows count must be the same
4194 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4195 return emitError("left and result matrices must have equal rows' count");
4196
4197 return success();
4198}
4199
4200//===----------------------------------------------------------------------===//
4201// spv.SpecConstantComposite
4202//===----------------------------------------------------------------------===//
4203
4204ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4205 OperationState &result) {
4206
4207 StringAttr compositeName;
4208 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4209 result.attributes))
4210 return failure();
4211
4212 if (parser.parseLParen())
4213 return failure();
4214
4215 SmallVector<Attribute, 4> constituents;
4216
4217 do {
4218 // The name of the constituent attribute isn't important
4219 const char *attrName = "spec_const";
4220 FlatSymbolRefAttr specConstRef;
4221 NamedAttrList attrs;
4222
4223 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4224 return failure();
4225
4226 constituents.push_back(specConstRef);
4227 } while (!parser.parseOptionalComma());
4228
4229 if (parser.parseRParen())
4230 return failure();
4231
4232 result.addAttribute(kCompositeSpecConstituentsName,
4233 parser.getBuilder().getArrayAttr(constituents));
4234
4235 Type type;
4236 if (parser.parseColonType(type))
4237 return failure();
4238
4239 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4240
4241 return success();
4242}
4243
4244void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4245 printer << " ";
4246 printer.printSymbolName(sym_name());
4247 printer << " (";
4248 auto constituents = this->constituents().getValue();
4249
4250 if (!constituents.empty())
4251 llvm::interleaveComma(constituents, printer);
4252
4253 printer << ") : " << type();
4254}
4255
4256LogicalResult spirv::SpecConstantCompositeOp::verify() {
4257 auto cType = type().dyn_cast<spirv::CompositeType>();
4258 auto constituents = this->constituents().getValue();
4259
4260 if (!cType)
4261 return emitError("result type must be a composite type, but provided ")
4262 << type();
4263
4264 if (cType.isa<spirv::CooperativeMatrixNVType>())
4265 return emitError("unsupported composite type ") << cType;
4266 if (cType.isa<spirv::JointMatrixINTELType>())
4267 return emitError("unsupported composite type ") << cType;
4268 if (constituents.size() != cType.getNumElements())
4269 return emitError("has incorrect number of operands: expected ")
4270 << cType.getNumElements() << ", but provided "
4271 << constituents.size();
4272
4273 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4274 auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4275
4276 auto constituentSpecConstOp =
4277 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4278 (*this)->getParentOp(), constituent.getAttr()));
4279
4280 if (constituentSpecConstOp.default_value().getType() !=
4281 cType.getElementType(index))
4282 return emitError("has incorrect types of operands: expected ")
4283 << cType.getElementType(index) << ", but provided "
4284 << constituentSpecConstOp.default_value().getType();
4285 }
4286
4287 return success();
4288}
4289
4290//===----------------------------------------------------------------------===//
4291// spv.SpecConstantOperation
4292//===----------------------------------------------------------------------===//
4293
4294ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4295 OperationState &result) {
4296 Region *body = result.addRegion();
4297
4298 if (parser.parseKeyword("wraps"))
4299 return failure();
4300
4301 body->push_back(new Block);
4302 Block &block = body->back();
4303 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4304
4305 if (!wrappedOp)
4306 return failure();
4307
4308 OpBuilder builder(parser.getContext());
4309 builder.setInsertionPointToEnd(&block);
4310 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4311 result.location = wrappedOp->getLoc();
4312
4313 result.addTypes(wrappedOp->getResult(0).getType());
4314
4315 if (parser.parseOptionalAttrDict(result.attributes))
4316 return failure();
4317
4318 return success();
4319}
4320
4321void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4322 printer << " wraps ";
4323 printer.printGenericOp(&body().front().front());
4324}
4325
4326LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4327 Block &block = getRegion().getBlocks().front();
4328
4329 if (block.getOperations().size() != 2)
4330 return emitOpError("expected exactly 2 nested ops");
4331
4332 Operation &enclosedOp = block.getOperations().front();
4333
4334 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4335 return emitOpError("invalid enclosed op");
4336
4337 for (auto operand : enclosedOp.getOperands())
4338 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4339 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4340 return emitOpError(
4341 "invalid operand, must be defined by a constant operation");
4342
4343 return success();
4344}
4345
4346//===----------------------------------------------------------------------===//
4347// spv.GL.FrexpStruct
4348//===----------------------------------------------------------------------===//
4349
4350LogicalResult spirv::GLFrexpStructOp::verify() {
4351 spirv::StructType structTy = result().getType().dyn_cast<spirv::StructType>();
4352
4353 if (structTy.getNumElements() != 2)
4354 return emitError("result type must be a struct type with two memebers");
4355
4356 Type significandTy = structTy.getElementType(0);
4357 Type exponentTy = structTy.getElementType(1);
4358 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4359 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4360
4361 Type operandTy = operand().getType();
4362 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4363 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4364
4365 if (significandTy != operandTy)
4366 return emitError("member zero of the resulting struct type must be the "
4367 "same type as the operand");
4368
4369 if (exponentVecTy) {
4370 IntegerType componentIntTy =
4371 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4372 if (!componentIntTy || componentIntTy.getWidth() != 32)
4373 return emitError("member one of the resulting struct type must"
4374 "be a scalar or vector of 32 bit integer type");
4375 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4376 return emitError("member one of the resulting struct type "
4377 "must be a scalar or vector of 32 bit integer type");
4378 }
4379
4380 // Check that the two member types have the same number of components
4381 if (operandVecTy && exponentVecTy &&
4382 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4383 return success();
4384
4385 if (operandFTy && exponentIntTy)
4386 return success();
4387
4388 return emitError("member one of the resulting struct type must have the same "
4389 "number of components as the operand type");
4390}
4391
4392//===----------------------------------------------------------------------===//
4393// spv.GL.Ldexp
4394//===----------------------------------------------------------------------===//
4395
4396LogicalResult spirv::GLLdexpOp::verify() {
4397 Type significandType = x().getType();
4398 Type exponentType = exp().getType();
4399
4400 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4401 return emitOpError("operands must both be scalars or vectors");
4402
4403 auto getNumElements = [](Type type) -> unsigned {
4404 if (auto vectorType = type.dyn_cast<VectorType>())
4405 return vectorType.getNumElements();
4406 return 1;
4407 };
4408
4409 if (getNumElements(significandType) != getNumElements(exponentType))
4410 return emitOpError("operands must have the same number of elements");
4411
4412 return success();
4413}
4414
4415//===----------------------------------------------------------------------===//
4416// spv.ImageDrefGather
4417//===----------------------------------------------------------------------===//
4418
4419LogicalResult spirv::ImageDrefGatherOp::verify() {
4420 VectorType resultType = result().getType().cast<VectorType>();
4421 auto sampledImageType =
4422 sampledimage().getType().cast<spirv::SampledImageType>();
4423 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4424
4425 if (resultType.getNumElements() != 4)
4426 return emitOpError("result type must be a vector of four components");
4427
4428 Type elementType = resultType.getElementType();
4429 Type sampledElementType = imageType.getElementType();
4430 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4431 return emitOpError(
4432 "the component type of result must be the same as sampled type of the "
4433 "underlying image type");
4434
4435 spirv::Dim imageDim = imageType.getDim();
4436 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4437
4438 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4439 imageDim != spirv::Dim::Rect)
4440 return emitOpError(
4441 "the Dim operand of the underlying image type must be 2D, Cube, or "
4442 "Rect");
4443
4444 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4445 return emitOpError("the MS operand of the underlying image type must be 0");
4446
4447 spirv::ImageOperandsAttr attr = imageoperandsAttr();
4448 auto operandArguments = operand_arguments();
4449
4450 return verifyImageOperands(*this, attr, operandArguments);
4451}
4452
4453//===----------------------------------------------------------------------===//
4454// spv.ShiftLeftLogicalOp
4455//===----------------------------------------------------------------------===//
4456
4457LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4458 return verifyShiftOp(*this);
4459}
4460
4461//===----------------------------------------------------------------------===//
4462// spv.ShiftRightArithmeticOp
4463//===----------------------------------------------------------------------===//
4464
4465LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4466 return verifyShiftOp(*this);
4467}
4468
4469//===----------------------------------------------------------------------===//
4470// spv.ShiftRightLogicalOp
4471//===----------------------------------------------------------------------===//
4472
4473LogicalResult spirv::ShiftRightLogicalOp::verify() {
4474 return verifyShiftOp(*this);
4475}
4476
4477//===----------------------------------------------------------------------===//
4478// spv.ImageQuerySize
4479//===----------------------------------------------------------------------===//
4480
4481LogicalResult spirv::ImageQuerySizeOp::verify() {
4482 spirv::ImageType imageType = image().getType().cast<spirv::ImageType>();
4483 Type resultType = result().getType();
4484
4485 spirv::Dim dim = imageType.getDim();
4486 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4487 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4488 switch (dim) {
4489 case spirv::Dim::Dim1D:
4490 case spirv::Dim::Dim2D:
4491 case spirv::Dim::Dim3D:
4492 case spirv::Dim::Cube:
4493 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4494 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4495 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4496 return emitError(
4497 "if Dim is 1D, 2D, 3D, or Cube, "
4498 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4499 break;
4500 case spirv::Dim::Buffer:
4501 case spirv::Dim::Rect:
4502 break;
4503 default:
4504 return emitError("the Dim operand of the image type must "
4505 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4506 }
4507
4508 unsigned componentNumber = 0;
4509 switch (dim) {
4510 case spirv::Dim::Dim1D:
4511 case spirv::Dim::Buffer:
4512 componentNumber = 1;
4513 break;
4514 case spirv::Dim::Dim2D:
4515 case spirv::Dim::Cube:
4516 case spirv::Dim::Rect:
4517 componentNumber = 2;
4518 break;
4519 case spirv::Dim::Dim3D:
4520 componentNumber = 3;
4521 break;
4522 default:
4523 break;
4524 }
4525
4526 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4527 componentNumber += 1;
4528
4529 unsigned resultComponentNumber = 1;
4530 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4531 resultComponentNumber = resultVectorType.getNumElements();
4532
4533 if (componentNumber != resultComponentNumber)
4534 return emitError("expected the result to have ")
4535 << componentNumber << " component(s), but found "
4536 << resultComponentNumber << " component(s)";
4537
4538 return success();
4539}
4540
4541static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4542 OpAsmParser &parser,
4543 OperationState &state) {
4544 OpAsmParser::UnresolvedOperand ptrInfo;
4545 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4546 Type type;
4547 auto loc = parser.getCurrentLocation();
4548 SmallVector<Type, 4> indicesTypes;
4549
4550 if (parser.parseOperand(ptrInfo) ||
4551 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4552 parser.parseColonType(type) ||
4553 parser.resolveOperand(ptrInfo, type, state.operands))
4554 return failure();
4555
4556 // Check that the provided indices list is not empty before parsing their
4557 // type list.
4558 if (indicesInfo.empty())
4559 return emitError(state.location) << opName << " expected element";
4560
4561 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4562 return failure();
4563
4564 // Check that the indices types list is not empty and that it has a one-to-one
4565 // mapping to the provided indices.
4566 if (indicesTypes.size() != indicesInfo.size())
4567 return emitError(state.location)
4568 << opName
4569 << " indices types' count must be equal to indices info count";
4570
4571 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4572 return failure();
4573
4574 auto resultType = getElementPtrType(
4575 type, llvm::makeArrayRef(state.operands).drop_front(2), state.location);
4576 if (!resultType)
4577 return failure();
4578
4579 state.addTypes(resultType);
4580 return success();
4581}
4582
4583template <typename Op>
4584static auto concatElemAndIndices(Op op) {
4585 SmallVector<Value> ret(op.indices().size() + 1);
4586 ret[0] = op.element();
4587 llvm::copy(op.indices(), ret.begin() + 1);
4588 return ret;
4589}
4590
4591//===----------------------------------------------------------------------===//
4592// spv.InBoundsPtrAccessChainOp
4593//===----------------------------------------------------------------------===//
4594
4595void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4596 OperationState &state,
4597 Value basePtr, Value element,
4598 ValueRange indices) {
4599 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4600 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4600, __extension__
__PRETTY_FUNCTION__))
;
4601 build(builder, state, type, basePtr, element, indices);
4602}
4603
4604ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4605 OperationState &result) {
4606 return parsePtrAccessChainOpImpl(
4607 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4608}
4609
4610void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4611 printAccessChain(*this, concatElemAndIndices(*this), printer);
4612}
4613
4614LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4615 return verifyAccessChain(*this, indices());
4616}
4617
4618//===----------------------------------------------------------------------===//
4619// spv.PtrAccessChainOp
4620//===----------------------------------------------------------------------===//
4621
4622void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4623 Value basePtr, Value element,
4624 ValueRange indices) {
4625 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4626 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4626, __extension__
__PRETTY_FUNCTION__))
;
4627 build(builder, state, type, basePtr, element, indices);
4628}
4629
4630ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4631 OperationState &result) {
4632 return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4633 parser, result);
4634}
4635
4636void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4637 printAccessChain(*this, concatElemAndIndices(*this), printer);
4638}
4639
4640LogicalResult spirv::PtrAccessChainOp::verify() {
4641 return verifyAccessChain(*this, indices());
4642}
4643
4644//===----------------------------------------------------------------------===//
4645// spv.VectorTimesScalarOp
4646//===----------------------------------------------------------------------===//
4647
4648LogicalResult spirv::VectorTimesScalarOp::verify() {
4649 if (vector().getType() != getType())
4650 return emitOpError("vector operand and result type mismatch");
4651 auto scalarType = getType().cast<VectorType>().getElementType();
4652 if (scalar().getType() != scalarType)
4653 return emitOpError("scalar operand and result element type match");
4654 return success();
4655}
4656
4657// TableGen'erated operation interfaces for querying versions, extensions, and
4658// capabilities.
4659#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4660
// TableGen'erated operation definitions.
4662#define GET_OP_CLASSES
4663#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4664
4665namespace mlir {
4666namespace spirv {
4667// TableGen'erated operation availability interface implementations.
4668#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4669} // namespace spirv
4670} // namespace mlir