Bug Summary

File: build/source/mlir/include/mlir/IR/Builders.h
Warning: line 100, column 12
2nd function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SPIRVOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-17/lib/clang/17 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GLIBCXX_ASSERTIONS -D _GNU_SOURCE -D _LIBCPP_ENABLE_ASSERTIONS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/SPIRV/IR -I /build/source/mlir/lib/Dialect/SPIRV/IR -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-17/lib/clang/17/include -internal-isystem /usr/local/include -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1683717183 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2023-05-10-133810-16478-1 -x c++ /build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

/build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

1//===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the operations in the SPIR-V dialect.
10//
11//===----------------------------------------------------------------------===//
12
13#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14
15#include "mlir/Dialect/SPIRV/IR/ParserUtils.h"
16#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h"
17#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
18#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
19#include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h"
20#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
21#include "mlir/Dialect/SPIRV/IR/TargetAndABI.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinTypes.h"
24#include "mlir/IR/FunctionImplementation.h"
25#include "mlir/IR/OpDefinition.h"
26#include "mlir/IR/OpImplementation.h"
27#include "mlir/IR/Operation.h"
28#include "mlir/IR/TypeUtilities.h"
29#include "mlir/Interfaces/CallInterfaces.h"
30#include "llvm/ADT/APFloat.h"
31#include "llvm/ADT/APInt.h"
32#include "llvm/ADT/ArrayRef.h"
33#include "llvm/ADT/STLExtras.h"
34#include "llvm/ADT/StringExtras.h"
35#include "llvm/Support/FormatVariadic.h"
36#include <cassert>
37#include <numeric>
38
39using namespace mlir;
40
// TODO: generate these strings using ODS.
//
// Names of the attributes (and a few related keywords) used by the custom
// parsers, printers and verifiers in this file, sorted by identifier.
constexpr char kAlignmentAttrName[] = "alignment";
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kClusterSize[] = "cluster_size";
constexpr char kCompositeSpecConstituentsName[] = "constituents";
constexpr char kControl[] = "control";
constexpr char kDefaultValueAttrName[] = "default_value";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kPackedVectorFormatAttrName[] = "format";
constexpr char kSemanticsAttrName[] = "semantics";
constexpr char kSourceAlignmentAttrName[] = "source_alignment";
constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";
constexpr char kValueAttrName[] = "value";
constexpr char kValuesAttrName[] = "values";
67
68//===----------------------------------------------------------------------===//
69// Common utility functions
70//===----------------------------------------------------------------------===//
71
72static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
73 OperationState &result) {
74 SmallVector<OpAsmParser::UnresolvedOperand, 2> ops;
75 Type type;
76 // If the operand list is in-between parentheses, then we have a generic form.
77 // (see the fallback in `printOneResultOp`).
78 SMLoc loc = parser.getCurrentLocation();
79 if (!parser.parseOptionalLParen()) {
80 if (parser.parseOperandList(ops) || parser.parseRParen() ||
81 parser.parseOptionalAttrDict(result.attributes) ||
82 parser.parseColon() || parser.parseType(type))
83 return failure();
84 auto fnType = type.dyn_cast<FunctionType>();
85 if (!fnType) {
86 parser.emitError(loc, "expected function type");
87 return failure();
88 }
89 if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands))
90 return failure();
91 result.addTypes(fnType.getResults());
92 return success();
93 }
94 return failure(parser.parseOperandList(ops) ||
95 parser.parseOptionalAttrDict(result.attributes) ||
96 parser.parseColonType(type) ||
97 parser.resolveOperands(ops, type, result.operands) ||
98 parser.addTypeToList(type, result.types));
99}
100
101static void printOneResultOp(Operation *op, OpAsmPrinter &p) {
102 assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 &&
"op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 102, __extension__
__PRETTY_FUNCTION__))
;
103
104 // If not all the operand and result types are the same, just use the
105 // generic assembly form to avoid omitting information in printing.
106 auto resultType = op->getResult(0).getType();
107 if (llvm::any_of(op->getOperandTypes(),
108 [&](Type type) { return type != resultType; })) {
109 p.printGenericOp(op, /*printOpName=*/false);
110 return;
111 }
112
113 p << ' ';
114 p.printOperands(op->getOperands());
115 p.printOptionalAttrDict(op->getAttrs());
116 // Now we can output only one type for all operands and the result.
117 p << " : " << resultType;
118}
119
120/// Returns true if the given op is a function-like op or nested in a
121/// function-like op without a module-like op in the middle.
122static bool isNestedInFunctionOpInterface(Operation *op) {
123 if (!op)
124 return false;
125 if (op->hasTrait<OpTrait::SymbolTable>())
126 return false;
127 if (isa<FunctionOpInterface>(op))
128 return true;
129 return isNestedInFunctionOpInterface(op->getParentOp());
130}
131
132/// Returns true if the given op is an module-like op that maintains a symbol
133/// table.
134static bool isDirectInModuleLikeOp(Operation *op) {
135 return op && op->hasTrait<OpTrait::SymbolTable>();
136}
137
138static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) {
139 auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op);
140 if (!constOp) {
141 return failure();
142 }
143 auto valueAttr = constOp.getValue();
144 auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>();
145 if (!integerValueAttr) {
146 return failure();
147 }
148
149 if (integerValueAttr.getType().isSignlessInteger())
150 value = integerValueAttr.getInt();
151 else
152 value = integerValueAttr.getSInt();
153
154 return success();
155}
156
157template <typename Ty>
158static ArrayAttr
159getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
160 function_ref<StringRef(Ty)> stringifyFn) {
161 if (enumValues.empty()) {
162 return nullptr;
163 }
164 SmallVector<StringRef, 1> enumValStrs;
165 enumValStrs.reserve(enumValues.size());
166 for (auto val : enumValues) {
167 enumValStrs.emplace_back(stringifyFn(val));
168 }
169 return builder.getStrArrayAttr(enumValStrs);
170}
171
172/// Parses the next string attribute in `parser` as an enumerant of the given
173/// `EnumClass`.
174template <typename EnumClass>
175static ParseResult
176parseEnumStrAttr(EnumClass &value, OpAsmParser &parser,
177 StringRef attrName = spirv::attributeName<EnumClass>()) {
178 Attribute attrVal;
179 NamedAttrList attr;
180 auto loc = parser.getCurrentLocation();
181 if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(),
4
Taking false branch
182 attrName, attr))
183 return failure();
184 if (!attrVal.isa<StringAttr>())
5
Taking true branch
185 return parser.emitError(loc, "expected ")
186 << attrName << " attribute specified as string";
187 auto attrOptional =
188 spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue());
189 if (!attrOptional)
190 return parser.emitError(loc, "invalid ")
191 << attrName << " attribute specification: " << attrVal;
192 value = *attrOptional;
193 return success();
194}
195
196/// Parses the next string attribute in `parser` as an enumerant of the given
197/// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer
198/// attribute with the enum class's name as attribute name.
199template <typename EnumAttrClass,
200 typename EnumClass = typename EnumAttrClass::ValueType>
201static ParseResult
202parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state,
203 StringRef attrName = spirv::attributeName<EnumClass>()) {
204 if (parseEnumStrAttr(value, parser))
3
Calling 'parseEnumStrAttr<mlir::spirv::ExecutionModel>'
6
Returning from 'parseEnumStrAttr<mlir::spirv::ExecutionModel>'
7
Taking false branch
205 return failure();
206 state.addAttribute(attrName,
207 parser.getBuilder().getAttr<EnumAttrClass>(value));
8
Calling 'Builder::getAttr'
208 return success();
209}
210
211/// Parses the next keyword in `parser` as an enumerant of the given `EnumClass`
212/// and inserts the enumerant into `state` as an 32-bit integer attribute with
213/// the enum class's name as attribute name.
214template <typename EnumAttrClass,
215 typename EnumClass = typename EnumAttrClass::ValueType>
216static ParseResult
217parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser,
218 OperationState &state,
219 StringRef attrName = spirv::attributeName<EnumClass>()) {
220 if (parseEnumKeywordAttr(value, parser))
221 return failure();
222 state.addAttribute(attrName,
223 parser.getBuilder().getAttr<EnumAttrClass>(value));
224 return success();
225}
226
227/// Parses Function, Selection and Loop control attributes. If no control is
228/// specified, "None" is used as a default.
229template <typename EnumAttrClass, typename EnumClass>
230static ParseResult
231parseControlAttribute(OpAsmParser &parser, OperationState &state,
232 StringRef attrName = spirv::attributeName<EnumClass>()) {
233 if (succeeded(parser.parseOptionalKeyword(kControl))) {
234 EnumClass control;
235 if (parser.parseLParen() ||
236 parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) ||
237 parser.parseRParen())
238 return failure();
239 return success();
240 }
241 // Set control to "None" otherwise.
242 Builder builder = parser.getBuilder();
243 state.addAttribute(attrName,
244 builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0)));
245 return success();
246}
247
248/// Parses optional memory access attributes attached to a memory access
249/// operand/pointer. Specifically, parses the following syntax:
250/// (`[` memory-access `]`)?
251/// where:
252/// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", `
253/// integer-literal | `"NonTemporal"`
254static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser,
255 OperationState &state) {
256 // Parse an optional list of attributes staring with '['
257 if (parser.parseOptionalLSquare()) {
258 // Nothing to do
259 return success();
260 }
261
262 spirv::MemoryAccess memoryAccessAttr;
263 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
264 kMemoryAccessAttrName))
265 return failure();
266
267 if (spirv::bitEnumContainsAll(memoryAccessAttr,
268 spirv::MemoryAccess::Aligned)) {
269 // Parse integer attribute for alignment.
270 Attribute alignmentAttr;
271 Type i32Type = parser.getBuilder().getIntegerType(32);
272 if (parser.parseComma() ||
273 parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName,
274 state.attributes)) {
275 return failure();
276 }
277 }
278 return parser.parseRSquare();
279}
280
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and the alignment attribute name. Doing
// so currently makes VS2017 report an internal error (at the call site) that
// is not detailed enough to understand what is happening.
285static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser,
286 OperationState &state) {
287 // Parse an optional list of attributes staring with '['
288 if (parser.parseOptionalLSquare()) {
289 // Nothing to do
290 return success();
291 }
292
293 spirv::MemoryAccess memoryAccessAttr;
294 if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state,
295 kSourceMemoryAccessAttrName))
296 return failure();
297
298 if (spirv::bitEnumContainsAll(memoryAccessAttr,
299 spirv::MemoryAccess::Aligned)) {
300 // Parse integer attribute for alignment.
301 Attribute alignmentAttr;
302 Type i32Type = parser.getBuilder().getIntegerType(32);
303 if (parser.parseComma() ||
304 parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName,
305 state.attributes)) {
306 return failure();
307 }
308 }
309 return parser.parseRSquare();
310}
311
312template <typename MemoryOpTy>
313static void printMemoryAccessAttribute(
314 MemoryOpTy memoryOp, OpAsmPrinter &printer,
315 SmallVectorImpl<StringRef> &elidedAttrs,
316 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
317 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
318 // Print optional memory access attribute.
319 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
320 : memoryOp.getMemoryAccess())) {
321 elidedAttrs.push_back(kMemoryAccessAttrName);
322
323 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
324
325 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
326 // Print integer alignment attribute.
327 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
328 : memoryOp.getAlignment())) {
329 elidedAttrs.push_back(kAlignmentAttrName);
330 printer << ", " << *alignment;
331 }
332 }
333 printer << "]";
334 }
335 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
336}
337
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and the alignment attribute name. Doing
// so currently makes VS2017 report an internal error (at the call site) that
// is not detailed enough to understand what is happening.
342template <typename MemoryOpTy>
343static void printSourceMemoryAccessAttribute(
344 MemoryOpTy memoryOp, OpAsmPrinter &printer,
345 SmallVectorImpl<StringRef> &elidedAttrs,
346 std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt,
347 std::optional<uint32_t> alignmentAttrValue = std::nullopt) {
348
349 printer << ", ";
350
351 // Print optional memory access attribute.
352 if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue
353 : memoryOp.getMemoryAccess())) {
354 elidedAttrs.push_back(kSourceMemoryAccessAttrName);
355
356 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"";
357
358 if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) {
359 // Print integer alignment attribute.
360 if (auto alignment = (alignmentAttrValue ? alignmentAttrValue
361 : memoryOp.getAlignment())) {
362 elidedAttrs.push_back(kSourceAlignmentAttrName);
363 printer << ", " << *alignment;
364 }
365 }
366 printer << "]";
367 }
368 elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>());
369}
370
371static ParseResult parseImageOperands(OpAsmParser &parser,
372 spirv::ImageOperandsAttr &attr) {
373 // Expect image operands
374 if (parser.parseOptionalLSquare())
375 return success();
376
377 spirv::ImageOperands imageOperands;
378 if (parseEnumStrAttr(imageOperands, parser))
379 return failure();
380
381 attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands);
382
383 return parser.parseRSquare();
384}
385
386static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp,
387 spirv::ImageOperandsAttr attr) {
388 if (attr) {
389 auto strImageOperands = stringifyImageOperands(attr.getValue());
390 printer << "[\"" << strImageOperands << "\"]";
391 }
392}
393
394template <typename Op>
395static LogicalResult verifyImageOperands(Op imageOp,
396 spirv::ImageOperandsAttr attr,
397 Operation::operand_range operands) {
398 if (!attr) {
399 if (operands.empty())
400 return success();
401
402 return imageOp.emitError("the Image Operands should encode what operands "
403 "follow, as per Image Operands");
404 }
405
406 // TODO: Add the validation rules for the following Image Operands.
407 spirv::ImageOperands noSupportOperands =
408 spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
409 spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
410 spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
411 spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
412 spirv::ImageOperands::MakeTexelAvailable |
413 spirv::ImageOperands::MakeTexelVisible |
414 spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
415
416 if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands))
417 llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 417)
;
418
419 return success();
420}
421
422static LogicalResult verifyCastOp(Operation *op,
423 bool requireSameBitWidth = true,
424 bool skipBitWidthCheck = false) {
425 // Some CastOps have no limit on bit widths for result and operand type.
426 if (skipBitWidthCheck)
427 return success();
428
429 Type operandType = op->getOperand(0).getType();
430 Type resultType = op->getResult(0).getType();
431
432 // ODS checks that result type and operand type have the same shape.
433 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
434 operandType = vectorType.getElementType();
435 resultType = resultType.cast<VectorType>().getElementType();
436 }
437
438 if (auto coopMatrixType =
439 operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
440 operandType = coopMatrixType.getElementType();
441 resultType =
442 resultType.cast<spirv::CooperativeMatrixNVType>().getElementType();
443 }
444
445 if (auto jointMatrixType =
446 operandType.dyn_cast<spirv::JointMatrixINTELType>()) {
447 operandType = jointMatrixType.getElementType();
448 resultType =
449 resultType.cast<spirv::JointMatrixINTELType>().getElementType();
450 }
451
452 auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth();
453 auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth();
454 auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth;
455
456 if (requireSameBitWidth) {
457 if (!isSameBitWidth) {
458 return op->emitOpError(
459 "expected the same bit widths for operand type and result "
460 "type, but provided ")
461 << operandType << " and " << resultType;
462 }
463 return success();
464 }
465
466 if (isSameBitWidth) {
467 return op->emitOpError(
468 "expected the different bit widths for operand type and result "
469 "type, but provided ")
470 << operandType << " and " << resultType;
471 }
472 return success();
473}
474
475template <typename MemoryOpTy>
476static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) {
477 // ODS checks for attributes values. Just need to verify that if the
478 // memory-access attribute is Aligned, then the alignment attribute must be
479 // present.
480 auto *op = memoryOp.getOperation();
481 auto memAccessAttr = op->getAttr(kMemoryAccessAttrName);
482 if (!memAccessAttr) {
483 // Alignment attribute shouldn't be present if memory access attribute is
484 // not present.
485 if (op->getAttr(kAlignmentAttrName)) {
486 return memoryOp.emitOpError(
487 "invalid alignment specification without aligned memory access "
488 "specification");
489 }
490 return success();
491 }
492
493 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
494
495 if (!memAccess) {
496 return memoryOp.emitOpError("invalid memory access specifier: ")
497 << memAccessAttr;
498 }
499
500 if (spirv::bitEnumContainsAll(memAccess.getValue(),
501 spirv::MemoryAccess::Aligned)) {
502 if (!op->getAttr(kAlignmentAttrName)) {
503 return memoryOp.emitOpError("missing alignment value");
504 }
505 } else {
506 if (op->getAttr(kAlignmentAttrName)) {
507 return memoryOp.emitOpError(
508 "invalid alignment specification with non-aligned memory access "
509 "specification");
510 }
511 }
512 return success();
513}
514
// TODO: Merge this and the previous function into one template parameterized
// by the memory access attribute name and the alignment attribute name. Doing
// so currently makes VS2017 report an internal error (at the call site) that
// is not detailed enough to understand what is happening.
519template <typename MemoryOpTy>
520static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) {
521 // ODS checks for attributes values. Just need to verify that if the
522 // memory-access attribute is Aligned, then the alignment attribute must be
523 // present.
524 auto *op = memoryOp.getOperation();
525 auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName);
526 if (!memAccessAttr) {
527 // Alignment attribute shouldn't be present if memory access attribute is
528 // not present.
529 if (op->getAttr(kSourceAlignmentAttrName)) {
530 return memoryOp.emitOpError(
531 "invalid alignment specification without aligned memory access "
532 "specification");
533 }
534 return success();
535 }
536
537 auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>();
538
539 if (!memAccess) {
540 return memoryOp.emitOpError("invalid memory access specifier: ")
541 << memAccess;
542 }
543
544 if (spirv::bitEnumContainsAll(memAccess.getValue(),
545 spirv::MemoryAccess::Aligned)) {
546 if (!op->getAttr(kSourceAlignmentAttrName)) {
547 return memoryOp.emitOpError("missing alignment value");
548 }
549 } else {
550 if (op->getAttr(kSourceAlignmentAttrName)) {
551 return memoryOp.emitOpError(
552 "invalid alignment specification with non-aligned memory access "
553 "specification");
554 }
555 }
556 return success();
557}
558
559static LogicalResult
560verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) {
561 // According to the SPIR-V specification:
562 // "Despite being a mask and allowing multiple bits to be combined, it is
563 // invalid for more than one of these four bits to be set: Acquire, Release,
564 // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and
565 // Release semantics is done by setting the AcquireRelease bit, not by setting
566 // two bits."
567 auto atMostOneInSet = spirv::MemorySemantics::Acquire |
568 spirv::MemorySemantics::Release |
569 spirv::MemorySemantics::AcquireRelease |
570 spirv::MemorySemantics::SequentiallyConsistent;
571
572 auto bitCount =
573 llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet));
574 if (bitCount > 1) {
575 return op->emitError(
576 "expected at most one of these four memory constraints "
577 "to be set: `Acquire`, `Release`,"
578 "`AcquireRelease` or `SequentiallyConsistent`");
579 }
580 return success();
581}
582
583template <typename LoadStoreOpTy>
584static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr,
585 Value val) {
586 // ODS already checks ptr is spirv::PointerType. Just check that the pointee
587 // type of the pointer and the type of the value are the same
588 //
589 // TODO: Check that the value type satisfies restrictions of
590 // SPIR-V OpLoad/OpStore operations
591 if (val.getType() !=
592 ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
593 return op.emitOpError("mismatch in result type and pointer type");
594 }
595 return success();
596}
597
598template <typename BlockReadWriteOpTy>
599static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op,
600 Value ptr, Value val) {
601 auto valType = val.getType();
602 if (auto valVecTy = valType.dyn_cast<VectorType>())
603 valType = valVecTy.getElementType();
604
605 if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) {
606 return op.emitOpError("mismatch in result type and pointer type");
607 }
608 return success();
609}
610
611static ParseResult parseVariableDecorations(OpAsmParser &parser,
612 OperationState &state) {
613 auto builtInName = llvm::convertToSnakeFromCamelCase(
614 stringifyDecoration(spirv::Decoration::BuiltIn));
615 if (succeeded(parser.parseOptionalKeyword("bind"))) {
616 Attribute set, binding;
617 // Parse optional descriptor binding
618 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
619 stringifyDecoration(spirv::Decoration::DescriptorSet));
620 auto bindingName = llvm::convertToSnakeFromCamelCase(
621 stringifyDecoration(spirv::Decoration::Binding));
622 Type i32Type = parser.getBuilder().getIntegerType(32);
623 if (parser.parseLParen() ||
624 parser.parseAttribute(set, i32Type, descriptorSetName,
625 state.attributes) ||
626 parser.parseComma() ||
627 parser.parseAttribute(binding, i32Type, bindingName,
628 state.attributes) ||
629 parser.parseRParen()) {
630 return failure();
631 }
632 } else if (succeeded(parser.parseOptionalKeyword(builtInName))) {
633 StringAttr builtIn;
634 if (parser.parseLParen() ||
635 parser.parseAttribute(builtIn, builtInName, state.attributes) ||
636 parser.parseRParen()) {
637 return failure();
638 }
639 }
640
641 // Parse other attributes
642 if (parser.parseOptionalAttrDict(state.attributes))
643 return failure();
644
645 return success();
646}
647
648static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
649 SmallVectorImpl<StringRef> &elidedAttrs) {
650 // Print optional descriptor binding
651 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
652 stringifyDecoration(spirv::Decoration::DescriptorSet));
653 auto bindingName = llvm::convertToSnakeFromCamelCase(
654 stringifyDecoration(spirv::Decoration::Binding));
655 auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName);
656 auto binding = op->getAttrOfType<IntegerAttr>(bindingName);
657 if (descriptorSet && binding) {
658 elidedAttrs.push_back(descriptorSetName);
659 elidedAttrs.push_back(bindingName);
660 printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt()
661 << ")";
662 }
663
664 // Print BuiltIn attribute if present
665 auto builtInName = llvm::convertToSnakeFromCamelCase(
666 stringifyDecoration(spirv::Decoration::BuiltIn));
667 if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) {
668 printer << " " << builtInName << "(\"" << builtin.getValue() << "\")";
669 elidedAttrs.push_back(builtInName);
670 }
671
672 printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
673}
674
675// Get bit width of types.
676static unsigned getBitWidth(Type type) {
677 if (type.isa<spirv::PointerType>()) {
678 // Just return 64 bits for pointer types for now.
679 // TODO: Make sure not caller relies on the actual pointer width value.
680 return 64;
681 }
682
683 if (type.isIntOrFloat())
684 return type.getIntOrFloatBitWidth();
685
686 if (auto vectorType = type.dyn_cast<VectorType>()) {
687 assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 687, __extension__
__PRETTY_FUNCTION__))
;
688 return vectorType.getNumElements() *
689 vectorType.getElementType().getIntOrFloatBitWidth();
690 }
691 llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 691)
;
692}
693
694/// Walks the given type hierarchy with the given indices, potentially down
695/// to component granularity, to select an element type. Returns null type and
696/// emits errors with the given loc on failure.
697static Type
698getElementType(Type type, ArrayRef<int32_t> indices,
699 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
700 if (indices.empty()) {
701 emitErrorFn("expected at least one index for spirv.CompositeExtract");
702 return nullptr;
703 }
704
705 for (auto index : indices) {
706 if (auto cType = type.dyn_cast<spirv::CompositeType>()) {
707 if (cType.hasCompileTimeKnownNumElements() &&
708 (index < 0 ||
709 static_cast<uint64_t>(index) >= cType.getNumElements())) {
710 emitErrorFn("index ") << index << " out of bounds for " << type;
711 return nullptr;
712 }
713 type = cType.getElementType(index);
714 } else {
715 emitErrorFn("cannot extract from non-composite type ")
716 << type << " with index " << index;
717 return nullptr;
718 }
719 }
720 return type;
721}
722
723static Type
724getElementType(Type type, Attribute indices,
725 function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) {
726 auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>();
727 if (!indicesArrayAttr) {
728 emitErrorFn("expected a 32-bit integer array attribute for 'indices'");
729 return nullptr;
730 }
731 if (indicesArrayAttr.empty()) {
732 emitErrorFn("expected at least one index for spirv.CompositeExtract");
733 return nullptr;
734 }
735
736 SmallVector<int32_t, 2> indexVals;
737 for (auto indexAttr : indicesArrayAttr) {
738 auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>();
739 if (!indexIntAttr) {
740 emitErrorFn("expected an 32-bit integer for index, but found '")
741 << indexAttr << "'";
742 return nullptr;
743 }
744 indexVals.push_back(indexIntAttr.getInt());
745 }
746 return getElementType(type, indexVals, emitErrorFn);
747}
748
749static Type getElementType(Type type, Attribute indices, Location loc) {
750 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
751 return ::mlir::emitError(loc, err);
752 };
753 return getElementType(type, indices, errorFn);
754}
755
756static Type getElementType(Type type, Attribute indices, OpAsmParser &parser,
757 SMLoc loc) {
758 auto errorFn = [&](StringRef err) -> InFlightDiagnostic {
759 return parser.emitError(loc, err);
760 };
761 return getElementType(type, indices, errorFn);
762}
763
764/// Returns true if the given `block` only contains one `spirv.mlir.merge` op.
765static inline bool isMergeBlock(Block &block) {
766 return !block.empty() && std::next(block.begin()) == block.end() &&
767 isa<spirv::MergeOp>(block.front());
768}
769
770template <typename ExtendedBinaryOp>
771static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) {
772 auto resultType = op.getType().template cast<spirv::StructType>();
773 if (resultType.getNumElements() != 2)
774 return op.emitOpError("expected result struct type containing two members");
775
776 if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(),
777 resultType.getElementType(0),
778 resultType.getElementType(1)}))
779 return op.emitOpError(
780 "expected all operand types and struct member types are the same");
781
782 return success();
783}
784
785static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser,
786 OperationState &result) {
787 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
788 if (parser.parseOptionalAttrDict(result.attributes) ||
789 parser.parseOperandList(operands) || parser.parseColon())
790 return failure();
791
792 Type resultType;
793 SMLoc loc = parser.getCurrentLocation();
794 if (parser.parseType(resultType))
795 return failure();
796
797 auto structType = resultType.dyn_cast<spirv::StructType>();
798 if (!structType || structType.getNumElements() != 2)
799 return parser.emitError(loc, "expected spirv.struct type with two members");
800
801 SmallVector<Type, 2> operandTypes(2, structType.getElementType(0));
802 if (parser.resolveOperands(operands, operandTypes, loc, result.operands))
803 return failure();
804
805 result.addTypes(resultType);
806 return success();
807}
808
809static void printArithmeticExtendedBinaryOp(Operation *op,
810 OpAsmPrinter &printer) {
811 printer << ' ';
812 printer.printOptionalAttrDict(op->getAttrs());
813 printer.printOperands(op->getOperands());
814 printer << " : " << op->getResultTypes().front();
815}
816
817//===----------------------------------------------------------------------===//
818// Common parsers and printers
819//===----------------------------------------------------------------------===//
820
821// Parses an atomic update op. If the update op does not take a value (like
822// AtomicIIncrement) `hasValue` must be false.
823static ParseResult parseAtomicUpdateOp(OpAsmParser &parser,
824 OperationState &state, bool hasValue) {
825 spirv::Scope scope;
826 spirv::MemorySemantics memoryScope;
827 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
828 OpAsmParser::UnresolvedOperand ptrInfo, valueInfo;
829 Type type;
830 SMLoc loc;
831 if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state,
832 kMemoryScopeAttrName) ||
833 parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state,
834 kSemanticsAttrName) ||
835 parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) ||
836 parser.getCurrentLocation(&loc) || parser.parseColonType(type))
837 return failure();
838
839 auto ptrType = type.dyn_cast<spirv::PointerType>();
840 if (!ptrType)
841 return parser.emitError(loc, "expected pointer type");
842
843 SmallVector<Type, 2> operandTypes;
844 operandTypes.push_back(ptrType);
845 if (hasValue)
846 operandTypes.push_back(ptrType.getPointeeType());
847 if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(),
848 state.operands))
849 return failure();
850 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
851}
852
853// Prints an atomic update op.
854static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) {
855 printer << " \"";
856 auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName);
857 printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \"";
858 auto memorySemanticsAttr =
859 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName);
860 printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue())
861 << "\" " << op->getOperands() << " : " << op->getOperand(0).getType();
862}
863
864template <typename T>
865static StringRef stringifyTypeName();
866
867template <>
868StringRef stringifyTypeName<IntegerType>() {
869 return "integer";
870}
871
872template <>
873StringRef stringifyTypeName<FloatType>() {
874 return "float";
875}
876
877// Verifies an atomic update op.
878template <typename ExpectedElementType>
879static LogicalResult verifyAtomicUpdateOp(Operation *op) {
880 auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>();
881 auto elementType = ptrType.getPointeeType();
882 if (!elementType.isa<ExpectedElementType>())
883 return op->emitOpError() << "pointer operand must point to an "
884 << stringifyTypeName<ExpectedElementType>()
885 << " value, found " << elementType;
886
887 if (op->getNumOperands() > 1) {
888 auto valueType = op->getOperand(1).getType();
889 if (valueType != elementType)
890 return op->emitOpError("expected value to have the same type as the "
891 "pointer operand's pointee type ")
892 << elementType << ", but found " << valueType;
893 }
894 auto memorySemantics =
895 op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName)
896 .getValue();
897 if (failed(verifyMemorySemantics(op, memorySemantics))) {
898 return failure();
899 }
900 return success();
901}
902
903static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser,
904 OperationState &state) {
905 spirv::Scope executionScope;
906 spirv::GroupOperation groupOperation;
907 OpAsmParser::UnresolvedOperand valueInfo;
908 if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state,
909 kExecutionScopeAttrName) ||
910 parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state,
911 kGroupOperationAttrName) ||
912 parser.parseOperand(valueInfo))
913 return failure();
914
915 std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo;
916 if (succeeded(parser.parseOptionalKeyword(kClusterSize))) {
917 clusterSizeInfo = OpAsmParser::UnresolvedOperand();
918 if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) ||
919 parser.parseRParen())
920 return failure();
921 }
922
923 Type resultType;
924 if (parser.parseColonType(resultType))
925 return failure();
926
927 if (parser.resolveOperand(valueInfo, resultType, state.operands))
928 return failure();
929
930 if (clusterSizeInfo) {
931 Type i32Type = parser.getBuilder().getIntegerType(32);
932 if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands))
933 return failure();
934 }
935
936 return parser.addTypeToList(resultType, state.types);
937}
938
939static void printGroupNonUniformArithmeticOp(Operation *groupOp,
940 OpAsmPrinter &printer) {
941 printer
942 << " \""
943 << stringifyScope(
944 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
945 .getValue())
946 << "\" \""
947 << stringifyGroupOperation(groupOp
948 ->getAttrOfType<spirv::GroupOperationAttr>(
949 kGroupOperationAttrName)
950 .getValue())
951 << "\" " << groupOp->getOperand(0);
952
953 if (groupOp->getNumOperands() > 1)
954 printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')';
955 printer << " : " << groupOp->getResult(0).getType();
956}
957
958static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) {
959 spirv::Scope scope =
960 groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName)
961 .getValue();
962 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
963 return groupOp->emitOpError(
964 "execution scope must be 'Workgroup' or 'Subgroup'");
965
966 spirv::GroupOperation operation =
967 groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName)
968 .getValue();
969 if (operation == spirv::GroupOperation::ClusteredReduce &&
970 groupOp->getNumOperands() == 1)
971 return groupOp->emitOpError("cluster size operand must be provided for "
972 "'ClusteredReduce' group operation");
973 if (groupOp->getNumOperands() > 1) {
974 Operation *sizeOp = groupOp->getOperand(1).getDefiningOp();
975 int32_t clusterSize = 0;
976
977 // TODO: support specialization constant here.
978 if (failed(extractValueFromConstOp(sizeOp, clusterSize)))
979 return groupOp->emitOpError(
980 "cluster size operand must come from a constant op");
981
982 if (!llvm::isPowerOf2_32(clusterSize))
983 return groupOp->emitOpError(
984 "cluster size operand must be a power of two");
985 }
986 return success();
987}
988
989/// Result of a logical op must be a scalar or vector of boolean type.
990static Type getUnaryOpResultType(Type operandType) {
991 Builder builder(operandType.getContext());
992 Type resultType = builder.getIntegerType(1);
993 if (auto vecType = operandType.dyn_cast<VectorType>())
994 return VectorType::get(vecType.getNumElements(), resultType);
995 return resultType;
996}
997
998static LogicalResult verifyShiftOp(Operation *op) {
999 if (op->getOperand(0).getType() != op->getResult(0).getType()) {
1000 return op->emitError("expected the same type for the first operand and "
1001 "result, but provided ")
1002 << op->getOperand(0).getType() << " and "
1003 << op->getResult(0).getType();
1004 }
1005 return success();
1006}
1007
1008//===----------------------------------------------------------------------===//
1009// spirv.AccessChainOp
1010//===----------------------------------------------------------------------===//
1011
1012static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) {
1013 auto ptrType = type.dyn_cast<spirv::PointerType>();
1014 if (!ptrType) {
1015 emitError(baseLoc, "'spirv.AccessChain' op expected a pointer "
1016 "to composite type, but provided ")
1017 << type;
1018 return nullptr;
1019 }
1020
1021 auto resultType = ptrType.getPointeeType();
1022 auto resultStorageClass = ptrType.getStorageClass();
1023 int32_t index = 0;
1024
1025 for (auto indexSSA : indices) {
1026 auto cType = resultType.dyn_cast<spirv::CompositeType>();
1027 if (!cType) {
1028 emitError(
1029 baseLoc,
1030 "'spirv.AccessChain' op cannot extract from non-composite type ")
1031 << resultType << " with index " << index;
1032 return nullptr;
1033 }
1034 index = 0;
1035 if (resultType.isa<spirv::StructType>()) {
1036 Operation *op = indexSSA.getDefiningOp();
1037 if (!op) {
1038 emitError(baseLoc, "'spirv.AccessChain' op index must be an "
1039 "integer spirv.Constant to access "
1040 "element of spirv.struct");
1041 return nullptr;
1042 }
1043
1044 // TODO: this should be relaxed to allow
1045 // integer literals of other bitwidths.
1046 if (failed(extractValueFromConstOp(op, index))) {
1047 emitError(
1048 baseLoc,
1049 "'spirv.AccessChain' index must be an integer spirv.Constant to "
1050 "access element of spirv.struct, but provided ")
1051 << op->getName();
1052 return nullptr;
1053 }
1054 if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) {
1055 emitError(baseLoc, "'spirv.AccessChain' op index ")
1056 << index << " out of bounds for " << resultType;
1057 return nullptr;
1058 }
1059 }
1060 resultType = cType.getElementType(index);
1061 }
1062 return spirv::PointerType::get(resultType, resultStorageClass);
1063}
1064
1065void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state,
1066 Value basePtr, ValueRange indices) {
1067 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
1068 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1068, __extension__
__PRETTY_FUNCTION__))
;
1069 build(builder, state, type, basePtr, indices);
1070}
1071
1072ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser,
1073 OperationState &result) {
1074 OpAsmParser::UnresolvedOperand ptrInfo;
1075 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
1076 Type type;
1077 auto loc = parser.getCurrentLocation();
1078 SmallVector<Type, 4> indicesTypes;
1079
1080 if (parser.parseOperand(ptrInfo) ||
1081 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
1082 parser.parseColonType(type) ||
1083 parser.resolveOperand(ptrInfo, type, result.operands)) {
1084 return failure();
1085 }
1086
1087 // Check that the provided indices list is not empty before parsing their
1088 // type list.
1089 if (indicesInfo.empty()) {
1090 return mlir::emitError(result.location,
1091 "'spirv.AccessChain' op expected at "
1092 "least one index ");
1093 }
1094
1095 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
1096 return failure();
1097
1098 // Check that the indices types list is not empty and that it has a one-to-one
1099 // mapping to the provided indices.
1100 if (indicesTypes.size() != indicesInfo.size()) {
1101 return mlir::emitError(
1102 result.location, "'spirv.AccessChain' op indices types' count must be "
1103 "equal to indices info count");
1104 }
1105
1106 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands))
1107 return failure();
1108
1109 auto resultType = getElementPtrType(
1110 type, llvm::ArrayRef(result.operands).drop_front(), result.location);
1111 if (!resultType) {
1112 return failure();
1113 }
1114
1115 result.addTypes(resultType);
1116 return success();
1117}
1118
1119template <typename Op>
1120static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) {
1121 printer << ' ' << op.getBasePtr() << '[' << indices
1122 << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes();
1123}
1124
1125void spirv::AccessChainOp::print(OpAsmPrinter &printer) {
1126 printAccessChain(*this, getIndices(), printer);
1127}
1128
1129template <typename Op>
1130static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) {
1131 auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(),
1132 indices, accessChainOp.getLoc());
1133 if (!resultType)
1134 return failure();
1135
1136 auto providedResultType =
1137 accessChainOp.getType().template dyn_cast<spirv::PointerType>();
1138 if (!providedResultType)
1139 return accessChainOp.emitOpError(
1140 "result type must be a pointer, but provided")
1141 << providedResultType;
1142
1143 if (resultType != providedResultType)
1144 return accessChainOp.emitOpError("invalid result type: expected ")
1145 << resultType << ", but provided " << providedResultType;
1146
1147 return success();
1148}
1149
1150LogicalResult spirv::AccessChainOp::verify() {
1151 return verifyAccessChain(*this, getIndices());
1152}
1153
1154//===----------------------------------------------------------------------===//
1155// spirv.mlir.addressof
1156//===----------------------------------------------------------------------===//
1157
1158void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state,
1159 spirv::GlobalVariableOp var) {
1160 build(builder, state, var.getType(), SymbolRefAttr::get(var));
1161}
1162
1163LogicalResult spirv::AddressOfOp::verify() {
1164 auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>(
1165 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(),
1166 getVariableAttr()));
1167 if (!varOp) {
1168 return emitOpError("expected spirv.GlobalVariable symbol");
1169 }
1170 if (getPointer().getType() != varOp.getType()) {
1171 return emitOpError(
1172 "result type mismatch with the referenced global variable's type");
1173 }
1174 return success();
1175}
1176
1177template <typename T>
1178static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) {
1179 printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \""
1180 << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \""
1181 << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" "
1182 << atomOp.getOperands() << " : " << atomOp.getPointer().getType();
1183}
1184
1185static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser,
1186 OperationState &state) {
1187 spirv::Scope memoryScope;
1188 spirv::MemorySemantics equalSemantics, unequalSemantics;
1189 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
1190 Type type;
1191 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state,
1192 kMemoryScopeAttrName) ||
1193 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1194 equalSemantics, parser, state, kEqualSemanticsAttrName) ||
1195 parseEnumStrAttr<spirv::MemorySemanticsAttr>(
1196 unequalSemantics, parser, state, kUnequalSemanticsAttrName) ||
1197 parser.parseOperandList(operandInfo, 3))
1198 return failure();
1199
1200 auto loc = parser.getCurrentLocation();
1201 if (parser.parseColonType(type))
1202 return failure();
1203
1204 auto ptrType = type.dyn_cast<spirv::PointerType>();
1205 if (!ptrType)
1206 return parser.emitError(loc, "expected pointer type");
1207
1208 if (parser.resolveOperands(
1209 operandInfo,
1210 {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()},
1211 parser.getNameLoc(), state.operands))
1212 return failure();
1213
1214 return parser.addTypeToList(ptrType.getPointeeType(), state.types);
1215}
1216
1217template <typename T>
1218static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) {
1219 // According to the spec:
1220 // "The type of Value must be the same as Result Type. The type of the value
1221 // pointed to by Pointer must be the same as Result Type. This type must also
1222 // match the type of Comparator."
1223 if (atomOp.getType() != atomOp.getValue().getType())
1224 return atomOp.emitOpError("value operand must have the same type as the op "
1225 "result, but found ")
1226 << atomOp.getValue().getType() << " vs " << atomOp.getType();
1227
1228 if (atomOp.getType() != atomOp.getComparator().getType())
1229 return atomOp.emitOpError(
1230 "comparator operand must have the same type as the op "
1231 "result, but found ")
1232 << atomOp.getComparator().getType() << " vs " << atomOp.getType();
1233
1234 Type pointeeType = atomOp.getPointer()
1235 .getType()
1236 .template cast<spirv::PointerType>()
1237 .getPointeeType();
1238 if (atomOp.getType() != pointeeType)
1239 return atomOp.emitOpError(
1240 "pointer operand's pointee type must have the same "
1241 "as the op result type, but found ")
1242 << pointeeType << " vs " << atomOp.getType();
1243
1244 // TODO: Unequal cannot be set to Release or Acquire and Release.
1245 // In addition, Unequal cannot be set to a stronger memory-order then Equal.
1246
1247 return success();
1248}
1249
1250//===----------------------------------------------------------------------===//
1251// spirv.AtomicAndOp
1252//===----------------------------------------------------------------------===//
1253
1254LogicalResult spirv::AtomicAndOp::verify() {
1255 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1256}
1257
1258ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser,
1259 OperationState &result) {
1260 return ::parseAtomicUpdateOp(parser, result, true);
1261}
1262void spirv::AtomicAndOp::print(OpAsmPrinter &p) {
1263 ::printAtomicUpdateOp(*this, p);
1264}
1265
1266//===----------------------------------------------------------------------===//
1267// spirv.AtomicCompareExchangeOp
1268//===----------------------------------------------------------------------===//
1269
1270LogicalResult spirv::AtomicCompareExchangeOp::verify() {
1271 return ::verifyAtomicCompareExchangeImpl(*this);
1272}
1273
1274ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser,
1275 OperationState &result) {
1276 return ::parseAtomicCompareExchangeImpl(parser, result);
1277}
1278void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) {
1279 ::printAtomicCompareExchangeImpl(*this, p);
1280}
1281
1282//===----------------------------------------------------------------------===//
1283// spirv.AtomicCompareExchangeWeakOp
1284//===----------------------------------------------------------------------===//
1285
1286LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() {
1287 return ::verifyAtomicCompareExchangeImpl(*this);
1288}
1289
1290ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser,
1291 OperationState &result) {
1292 return ::parseAtomicCompareExchangeImpl(parser, result);
1293}
1294void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) {
1295 ::printAtomicCompareExchangeImpl(*this, p);
1296}
1297
1298//===----------------------------------------------------------------------===//
1299// spirv.AtomicExchange
1300//===----------------------------------------------------------------------===//
1301
1302void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) {
1303 printer << " \"" << stringifyScope(getMemoryScope()) << "\" \""
1304 << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands()
1305 << " : " << getPointer().getType();
1306}
1307
1308ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser,
1309 OperationState &result) {
1310 spirv::Scope memoryScope;
1311 spirv::MemorySemantics semantics;
1312 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
1313 Type type;
1314 if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result,
1315 kMemoryScopeAttrName) ||
1316 parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result,
1317 kSemanticsAttrName) ||
1318 parser.parseOperandList(operandInfo, 2))
1319 return failure();
1320
1321 auto loc = parser.getCurrentLocation();
1322 if (parser.parseColonType(type))
1323 return failure();
1324
1325 auto ptrType = type.dyn_cast<spirv::PointerType>();
1326 if (!ptrType)
1327 return parser.emitError(loc, "expected pointer type");
1328
1329 if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()},
1330 parser.getNameLoc(), result.operands))
1331 return failure();
1332
1333 return parser.addTypeToList(ptrType.getPointeeType(), result.types);
1334}
1335
1336LogicalResult spirv::AtomicExchangeOp::verify() {
1337 if (getType() != getValue().getType())
1338 return emitOpError("value operand must have the same type as the op "
1339 "result, but found ")
1340 << getValue().getType() << " vs " << getType();
1341
1342 Type pointeeType =
1343 getPointer().getType().cast<spirv::PointerType>().getPointeeType();
1344 if (getType() != pointeeType)
1345 return emitOpError("pointer operand's pointee type must have the same "
1346 "as the op result type, but found ")
1347 << pointeeType << " vs " << getType();
1348
1349 return success();
1350}
1351
1352//===----------------------------------------------------------------------===//
1353// spirv.AtomicIAddOp
1354//===----------------------------------------------------------------------===//
1355
1356LogicalResult spirv::AtomicIAddOp::verify() {
1357 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1358}
1359
1360ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser,
1361 OperationState &result) {
1362 return ::parseAtomicUpdateOp(parser, result, true);
1363}
1364void spirv::AtomicIAddOp::print(OpAsmPrinter &p) {
1365 ::printAtomicUpdateOp(*this, p);
1366}
1367
1368//===----------------------------------------------------------------------===//
1369// spirv.EXT.AtomicFAddOp
1370//===----------------------------------------------------------------------===//
1371
1372LogicalResult spirv::EXTAtomicFAddOp::verify() {
1373 return ::verifyAtomicUpdateOp<FloatType>(getOperation());
1374}
1375
1376ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser,
1377 OperationState &result) {
1378 return ::parseAtomicUpdateOp(parser, result, true);
1379}
1380void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) {
1381 ::printAtomicUpdateOp(*this, p);
1382}
1383
1384//===----------------------------------------------------------------------===//
1385// spirv.AtomicIDecrementOp
1386//===----------------------------------------------------------------------===//
1387
1388LogicalResult spirv::AtomicIDecrementOp::verify() {
1389 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1390}
1391
1392ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser,
1393 OperationState &result) {
1394 return ::parseAtomicUpdateOp(parser, result, false);
1395}
1396void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) {
1397 ::printAtomicUpdateOp(*this, p);
1398}
1399
1400//===----------------------------------------------------------------------===//
1401// spirv.AtomicIIncrementOp
1402//===----------------------------------------------------------------------===//
1403
1404LogicalResult spirv::AtomicIIncrementOp::verify() {
1405 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1406}
1407
1408ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser,
1409 OperationState &result) {
1410 return ::parseAtomicUpdateOp(parser, result, false);
1411}
1412void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) {
1413 ::printAtomicUpdateOp(*this, p);
1414}
1415
1416//===----------------------------------------------------------------------===//
1417// spirv.AtomicISubOp
1418//===----------------------------------------------------------------------===//
1419
1420LogicalResult spirv::AtomicISubOp::verify() {
1421 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1422}
1423
1424ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser,
1425 OperationState &result) {
1426 return ::parseAtomicUpdateOp(parser, result, true);
1427}
1428void spirv::AtomicISubOp::print(OpAsmPrinter &p) {
1429 ::printAtomicUpdateOp(*this, p);
1430}
1431
1432//===----------------------------------------------------------------------===//
1433// spirv.AtomicOrOp
1434//===----------------------------------------------------------------------===//
1435
1436LogicalResult spirv::AtomicOrOp::verify() {
1437 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1438}
1439
1440ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser,
1441 OperationState &result) {
1442 return ::parseAtomicUpdateOp(parser, result, true);
1443}
1444void spirv::AtomicOrOp::print(OpAsmPrinter &p) {
1445 ::printAtomicUpdateOp(*this, p);
1446}
1447
1448//===----------------------------------------------------------------------===//
1449// spirv.AtomicSMaxOp
1450//===----------------------------------------------------------------------===//
1451
1452LogicalResult spirv::AtomicSMaxOp::verify() {
1453 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1454}
1455
1456ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser,
1457 OperationState &result) {
1458 return ::parseAtomicUpdateOp(parser, result, true);
1459}
1460void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) {
1461 ::printAtomicUpdateOp(*this, p);
1462}
1463
1464//===----------------------------------------------------------------------===//
1465// spirv.AtomicSMinOp
1466//===----------------------------------------------------------------------===//
1467
1468LogicalResult spirv::AtomicSMinOp::verify() {
1469 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1470}
1471
1472ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser,
1473 OperationState &result) {
1474 return ::parseAtomicUpdateOp(parser, result, true);
1475}
1476void spirv::AtomicSMinOp::print(OpAsmPrinter &p) {
1477 ::printAtomicUpdateOp(*this, p);
1478}
1479
1480//===----------------------------------------------------------------------===//
1481// spirv.AtomicUMaxOp
1482//===----------------------------------------------------------------------===//
1483
1484LogicalResult spirv::AtomicUMaxOp::verify() {
1485 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1486}
1487
1488ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser,
1489 OperationState &result) {
1490 return ::parseAtomicUpdateOp(parser, result, true);
1491}
1492void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) {
1493 ::printAtomicUpdateOp(*this, p);
1494}
1495
1496//===----------------------------------------------------------------------===//
1497// spirv.AtomicUMinOp
1498//===----------------------------------------------------------------------===//
1499
1500LogicalResult spirv::AtomicUMinOp::verify() {
1501 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1502}
1503
1504ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser,
1505 OperationState &result) {
1506 return ::parseAtomicUpdateOp(parser, result, true);
1507}
1508void spirv::AtomicUMinOp::print(OpAsmPrinter &p) {
1509 ::printAtomicUpdateOp(*this, p);
1510}
1511
1512//===----------------------------------------------------------------------===//
1513// spirv.AtomicXorOp
1514//===----------------------------------------------------------------------===//
1515
1516LogicalResult spirv::AtomicXorOp::verify() {
1517 return ::verifyAtomicUpdateOp<IntegerType>(getOperation());
1518}
1519
1520ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser,
1521 OperationState &result) {
1522 return ::parseAtomicUpdateOp(parser, result, true);
1523}
1524void spirv::AtomicXorOp::print(OpAsmPrinter &p) {
1525 ::printAtomicUpdateOp(*this, p);
1526}
1527
1528//===----------------------------------------------------------------------===//
1529// spirv.BitcastOp
1530//===----------------------------------------------------------------------===//
1531
1532LogicalResult spirv::BitcastOp::verify() {
1533 // TODO: The SPIR-V spec validation rules are different for different
1534 // versions.
1535 auto operandType = getOperand().getType();
1536 auto resultType = getResult().getType();
1537 if (operandType == resultType) {
1538 return emitError("result type must be different from operand type");
1539 }
1540 if (operandType.isa<spirv::PointerType>() &&
1541 !resultType.isa<spirv::PointerType>()) {
1542 return emitError(
1543 "unhandled bit cast conversion from pointer type to non-pointer type");
1544 }
1545 if (!operandType.isa<spirv::PointerType>() &&
1546 resultType.isa<spirv::PointerType>()) {
1547 return emitError(
1548 "unhandled bit cast conversion from non-pointer type to pointer type");
1549 }
1550 auto operandBitWidth = getBitWidth(operandType);
1551 auto resultBitWidth = getBitWidth(resultType);
1552 if (operandBitWidth != resultBitWidth) {
1553 return emitOpError("mismatch in result type bitwidth ")
1554 << resultBitWidth << " and operand type bitwidth "
1555 << operandBitWidth;
1556 }
1557 return success();
1558}
1559
1560//===----------------------------------------------------------------------===//
1561// spirv.PtrCastToGenericOp
1562//===----------------------------------------------------------------------===//
1563
1564LogicalResult spirv::PtrCastToGenericOp::verify() {
1565 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1566 auto resultType = getResult().getType().cast<spirv::PointerType>();
1567
1568 spirv::StorageClass operandStorage = operandType.getStorageClass();
1569 if (operandStorage != spirv::StorageClass::Workgroup &&
1570 operandStorage != spirv::StorageClass::CrossWorkgroup &&
1571 operandStorage != spirv::StorageClass::Function)
1572 return emitError("pointer must point to the Workgroup, CrossWorkgroup"
1573 ", or Function Storage Class");
1574
1575 spirv::StorageClass resultStorage = resultType.getStorageClass();
1576 if (resultStorage != spirv::StorageClass::Generic)
1577 return emitError("result type must be of storage class Generic");
1578
1579 Type operandPointeeType = operandType.getPointeeType();
1580 Type resultPointeeType = resultType.getPointeeType();
1581 if (operandPointeeType != resultPointeeType)
1582 return emitOpError("pointer operand's pointee type must have the same "
1583 "as the op result type, but found ")
1584 << operandPointeeType << " vs " << resultPointeeType;
1585 return success();
1586}
1587
1588//===----------------------------------------------------------------------===//
1589// spirv.GenericCastToPtrOp
1590//===----------------------------------------------------------------------===//
1591
1592LogicalResult spirv::GenericCastToPtrOp::verify() {
1593 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1594 auto resultType = getResult().getType().cast<spirv::PointerType>();
1595
1596 spirv::StorageClass operandStorage = operandType.getStorageClass();
1597 if (operandStorage != spirv::StorageClass::Generic)
1598 return emitError("pointer type must be of storage class Generic");
1599
1600 spirv::StorageClass resultStorage = resultType.getStorageClass();
1601 if (resultStorage != spirv::StorageClass::Workgroup &&
1602 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1603 resultStorage != spirv::StorageClass::Function)
1604 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1605 "or Function Storage Class");
1606
1607 Type operandPointeeType = operandType.getPointeeType();
1608 Type resultPointeeType = resultType.getPointeeType();
1609 if (operandPointeeType != resultPointeeType)
1610 return emitOpError("pointer operand's pointee type must have the same "
1611 "as the op result type, but found ")
1612 << operandPointeeType << " vs " << resultPointeeType;
1613 return success();
1614}
1615
1616//===----------------------------------------------------------------------===//
1617// spirv.GenericCastToPtrExplicitOp
1618//===----------------------------------------------------------------------===//
1619
1620LogicalResult spirv::GenericCastToPtrExplicitOp::verify() {
1621 auto operandType = getPointer().getType().cast<spirv::PointerType>();
1622 auto resultType = getResult().getType().cast<spirv::PointerType>();
1623
1624 spirv::StorageClass operandStorage = operandType.getStorageClass();
1625 if (operandStorage != spirv::StorageClass::Generic)
1626 return emitError("pointer type must be of storage class Generic");
1627
1628 spirv::StorageClass resultStorage = resultType.getStorageClass();
1629 if (resultStorage != spirv::StorageClass::Workgroup &&
1630 resultStorage != spirv::StorageClass::CrossWorkgroup &&
1631 resultStorage != spirv::StorageClass::Function)
1632 return emitError("result must point to the Workgroup, CrossWorkgroup, "
1633 "or Function Storage Class");
1634
1635 Type operandPointeeType = operandType.getPointeeType();
1636 Type resultPointeeType = resultType.getPointeeType();
1637 if (operandPointeeType != resultPointeeType)
1638 return emitOpError("pointer operand's pointee type must have the same "
1639 "as the op result type, but found ")
1640 << operandPointeeType << " vs " << resultPointeeType;
1641 return success();
1642}
1643
1644//===----------------------------------------------------------------------===//
1645// spirv.BranchOp
1646//===----------------------------------------------------------------------===//
1647
1648SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) {
1649 assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index"
) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1649, __extension__
__PRETTY_FUNCTION__))
;
1650 return SuccessorOperands(0, getTargetOperandsMutable());
1651}
1652
1653//===----------------------------------------------------------------------===//
1654// spirv.BranchConditionalOp
1655//===----------------------------------------------------------------------===//
1656
1657SuccessorOperands
1658spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) {
1659 assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index"
) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1659, __extension__
__PRETTY_FUNCTION__))
;
1660 return SuccessorOperands(index == kTrueIndex
1661 ? getTrueTargetOperandsMutable()
1662 : getFalseTargetOperandsMutable());
1663}
1664
1665ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser,
1666 OperationState &result) {
1667 auto &builder = parser.getBuilder();
1668 OpAsmParser::UnresolvedOperand condInfo;
1669 Block *dest;
1670
1671 // Parse the condition.
1672 Type boolTy = builder.getI1Type();
1673 if (parser.parseOperand(condInfo) ||
1674 parser.resolveOperand(condInfo, boolTy, result.operands))
1675 return failure();
1676
1677 // Parse the optional branch weights.
1678 if (succeeded(parser.parseOptionalLSquare())) {
1679 IntegerAttr trueWeight, falseWeight;
1680 NamedAttrList weights;
1681
1682 auto i32Type = builder.getIntegerType(32);
1683 if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) ||
1684 parser.parseComma() ||
1685 parser.parseAttribute(falseWeight, i32Type, "weight", weights) ||
1686 parser.parseRSquare())
1687 return failure();
1688
1689 result.addAttribute(kBranchWeightAttrName,
1690 builder.getArrayAttr({trueWeight, falseWeight}));
1691 }
1692
1693 // Parse the true branch.
1694 SmallVector<Value, 4> trueOperands;
1695 if (parser.parseComma() ||
1696 parser.parseSuccessorAndUseList(dest, trueOperands))
1697 return failure();
1698 result.addSuccessors(dest);
1699 result.addOperands(trueOperands);
1700
1701 // Parse the false branch.
1702 SmallVector<Value, 4> falseOperands;
1703 if (parser.parseComma() ||
1704 parser.parseSuccessorAndUseList(dest, falseOperands))
1705 return failure();
1706 result.addSuccessors(dest);
1707 result.addOperands(falseOperands);
1708 result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(),
1709 builder.getDenseI32ArrayAttr(
1710 {1, static_cast<int32_t>(trueOperands.size()),
1711 static_cast<int32_t>(falseOperands.size())}));
1712
1713 return success();
1714}
1715
1716void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) {
1717 printer << ' ' << getCondition();
1718
1719 if (auto weights = getBranchWeights()) {
1720 printer << " [";
1721 llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) {
1722 printer << a.cast<IntegerAttr>().getInt();
1723 });
1724 printer << "]";
1725 }
1726
1727 printer << ", ";
1728 printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments());
1729 printer << ", ";
1730 printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments());
1731}
1732
1733LogicalResult spirv::BranchConditionalOp::verify() {
1734 if (auto weights = getBranchWeights()) {
1735 if (weights->getValue().size() != 2) {
1736 return emitOpError("must have exactly two branch weights");
1737 }
1738 if (llvm::all_of(*weights, [](Attribute attr) {
1739 return attr.cast<IntegerAttr>().getValue().isZero();
1740 }))
1741 return emitOpError("branch weights cannot both be zero");
1742 }
1743
1744 return success();
1745}
1746
1747//===----------------------------------------------------------------------===//
1748// spirv.CompositeConstruct
1749//===----------------------------------------------------------------------===//
1750
1751LogicalResult spirv::CompositeConstructOp::verify() {
1752 auto cType = getType().cast<spirv::CompositeType>();
1753 operand_range constituents = this->getConstituents();
1754
1755 if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) {
1756 if (constituents.size() != 1)
1757 return emitOpError("has incorrect number of operands: expected ")
1758 << "1, but provided " << constituents.size();
1759 if (coopType.getElementType() != constituents.front().getType())
1760 return emitOpError("operand type mismatch: expected operand type ")
1761 << coopType.getElementType() << ", but provided "
1762 << constituents.front().getType();
1763 return success();
1764 }
1765
1766 if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) {
1767 if (constituents.size() != 1)
1768 return emitOpError("has incorrect number of operands: expected ")
1769 << "1, but provided " << constituents.size();
1770 if (jointType.getElementType() != constituents.front().getType())
1771 return emitOpError("operand type mismatch: expected operand type ")
1772 << jointType.getElementType() << ", but provided "
1773 << constituents.front().getType();
1774 return success();
1775 }
1776
1777 if (constituents.size() == cType.getNumElements()) {
1778 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
1779 if (constituents[index].getType() != cType.getElementType(index)) {
1780 return emitOpError("operand type mismatch: expected operand type ")
1781 << cType.getElementType(index) << ", but provided "
1782 << constituents[index].getType();
1783 }
1784 }
1785 return success();
1786 }
1787
1788 // If not constructing a cooperative matrix type, then we must be constructing
1789 // a vector type.
1790 auto resultType = cType.dyn_cast<VectorType>();
1791 if (!resultType)
1792 return emitOpError(
1793 "expected to return a vector or cooperative matrix when the number of "
1794 "constituents is less than what the result needs");
1795
1796 SmallVector<unsigned> sizes;
1797 for (Value component : constituents) {
1798 if (!component.getType().isa<VectorType>() &&
1799 !component.getType().isIntOrFloat())
1800 return emitOpError("operand type mismatch: expected operand to have "
1801 "a scalar or vector type, but provided ")
1802 << component.getType();
1803
1804 Type elementType = component.getType();
1805 if (auto vectorType = component.getType().dyn_cast<VectorType>()) {
1806 sizes.push_back(vectorType.getNumElements());
1807 elementType = vectorType.getElementType();
1808 } else {
1809 sizes.push_back(1);
1810 }
1811
1812 if (elementType != resultType.getElementType())
1813 return emitOpError("operand element type mismatch: expected to be ")
1814 << resultType.getElementType() << ", but provided " << elementType;
1815 }
1816 unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0);
1817 if (totalCount != cType.getNumElements())
1818 return emitOpError("has incorrect number of operands: expected ")
1819 << cType.getNumElements() << ", but provided " << totalCount;
1820 return success();
1821}
1822
1823//===----------------------------------------------------------------------===//
1824// spirv.CompositeExtractOp
1825//===----------------------------------------------------------------------===//
1826
1827void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state,
1828 Value composite,
1829 ArrayRef<int32_t> indices) {
1830 auto indexAttr = builder.getI32ArrayAttr(indices);
1831 auto elementType =
1832 getElementType(composite.getType(), indexAttr, state.location);
1833 if (!elementType) {
1834 return;
1835 }
1836 build(builder, state, elementType, composite, indexAttr);
1837}
1838
1839ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser,
1840 OperationState &result) {
1841 OpAsmParser::UnresolvedOperand compositeInfo;
1842 Attribute indicesAttr;
1843 Type compositeType;
1844 SMLoc attrLocation;
1845
1846 if (parser.parseOperand(compositeInfo) ||
1847 parser.getCurrentLocation(&attrLocation) ||
1848 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1849 parser.parseColonType(compositeType) ||
1850 parser.resolveOperand(compositeInfo, compositeType, result.operands)) {
1851 return failure();
1852 }
1853
1854 Type resultType =
1855 getElementType(compositeType, indicesAttr, parser, attrLocation);
1856 if (!resultType) {
1857 return failure();
1858 }
1859 result.addTypes(resultType);
1860 return success();
1861}
1862
1863void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) {
1864 printer << ' ' << getComposite() << getIndices() << " : "
1865 << getComposite().getType();
1866}
1867
1868LogicalResult spirv::CompositeExtractOp::verify() {
1869 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1870 auto resultType =
1871 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1872 if (!resultType)
1873 return failure();
1874
1875 if (resultType != getType()) {
1876 return emitOpError("invalid result type: expected ")
1877 << resultType << " but provided " << getType();
1878 }
1879
1880 return success();
1881}
1882
1883//===----------------------------------------------------------------------===//
1884// spirv.CompositeInsert
1885//===----------------------------------------------------------------------===//
1886
1887void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state,
1888 Value object, Value composite,
1889 ArrayRef<int32_t> indices) {
1890 auto indexAttr = builder.getI32ArrayAttr(indices);
1891 build(builder, state, composite.getType(), object, composite, indexAttr);
1892}
1893
1894ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser,
1895 OperationState &result) {
1896 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
1897 Type objectType, compositeType;
1898 Attribute indicesAttr;
1899 auto loc = parser.getCurrentLocation();
1900
1901 return failure(
1902 parser.parseOperandList(operands, 2) ||
1903 parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) ||
1904 parser.parseColonType(objectType) ||
1905 parser.parseKeywordType("into", compositeType) ||
1906 parser.resolveOperands(operands, {objectType, compositeType}, loc,
1907 result.operands) ||
1908 parser.addTypesToList(compositeType, result.types));
1909}
1910
1911LogicalResult spirv::CompositeInsertOp::verify() {
1912 auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>();
1913 auto objectType =
1914 getElementType(getComposite().getType(), indicesArrayAttr, getLoc());
1915 if (!objectType)
1916 return failure();
1917
1918 if (objectType != getObject().getType()) {
1919 return emitOpError("object operand type should be ")
1920 << objectType << ", but found " << getObject().getType();
1921 }
1922
1923 if (getComposite().getType() != getType()) {
1924 return emitOpError("result type should be the same as "
1925 "the composite type, but found ")
1926 << getComposite().getType() << " vs " << getType();
1927 }
1928
1929 return success();
1930}
1931
1932void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) {
1933 printer << " " << getObject() << ", " << getComposite() << getIndices()
1934 << " : " << getObject().getType() << " into "
1935 << getComposite().getType();
1936}
1937
1938//===----------------------------------------------------------------------===//
1939// spirv.Constant
1940//===----------------------------------------------------------------------===//
1941
1942ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
1943 OperationState &result) {
1944 Attribute value;
1945 if (parser.parseAttribute(value, kValueAttrName, result.attributes))
1946 return failure();
1947
1948 Type type = NoneType::get(parser.getContext());
1949 if (auto typedAttr = value.dyn_cast<TypedAttr>())
1950 type = typedAttr.getType();
1951 if (type.isa<NoneType, TensorType>()) {
1952 if (parser.parseColonType(type))
1953 return failure();
1954 }
1955
1956 return parser.addTypeToList(type, result.types);
1957}
1958
1959void spirv::ConstantOp::print(OpAsmPrinter &printer) {
1960 printer << ' ' << getValue();
1961 if (getType().isa<spirv::ArrayType>())
1962 printer << " : " << getType();
1963}
1964
1965static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value,
1966 Type opType) {
1967 if (value.isa<IntegerAttr, FloatAttr>()) {
1968 auto valueType = value.cast<TypedAttr>().getType();
1969 if (valueType != opType)
1970 return op.emitOpError("result type (")
1971 << opType << ") does not match value type (" << valueType << ")";
1972 return success();
1973 }
1974 if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) {
1975 auto valueType = value.cast<TypedAttr>().getType();
1976 if (valueType == opType)
1977 return success();
1978 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
1979 auto shapedType = valueType.dyn_cast<ShapedType>();
1980 if (!arrayType)
1981 return op.emitOpError("result or element type (")
1982 << opType << ") does not match value type (" << valueType
1983 << "), must be the same or spirv.array";
1984
1985 int numElements = arrayType.getNumElements();
1986 auto opElemType = arrayType.getElementType();
1987 while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) {
1988 numElements *= t.getNumElements();
1989 opElemType = t.getElementType();
1990 }
1991 if (!opElemType.isIntOrFloat())
1992 return op.emitOpError("only support nested array result type");
1993
1994 auto valueElemType = shapedType.getElementType();
1995 if (valueElemType != opElemType) {
1996 return op.emitOpError("result element type (")
1997 << opElemType << ") does not match value element type ("
1998 << valueElemType << ")";
1999 }
2000
2001 if (numElements != shapedType.getNumElements()) {
2002 return op.emitOpError("result number of elements (")
2003 << numElements << ") does not match value number of elements ("
2004 << shapedType.getNumElements() << ")";
2005 }
2006 return success();
2007 }
2008 if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) {
2009 auto arrayType = opType.dyn_cast<spirv::ArrayType>();
2010 if (!arrayType)
2011 return op.emitOpError(
2012 "must have spirv.array result type for array value");
2013 Type elemType = arrayType.getElementType();
2014 for (Attribute element : arrayAttr.getValue()) {
2015 // Verify array elements recursively.
2016 if (failed(verifyConstantType(op, element, elemType)))
2017 return failure();
2018 }
2019 return success();
2020 }
2021 return op.emitOpError("cannot have attribute: ") << value;
2022}
2023
2024LogicalResult spirv::ConstantOp::verify() {
2025 // ODS already generates checks to make sure the result type is valid. We just
2026 // need to additionally check that the value's attribute type is consistent
2027 // with the result type.
2028 return verifyConstantType(*this, getValueAttr(), getType());
2029}
2030
2031bool spirv::ConstantOp::isBuildableWith(Type type) {
2032 // Must be valid SPIR-V type first.
2033 if (!type.isa<spirv::SPIRVType>())
2034 return false;
2035
2036 if (isa<SPIRVDialect>(type.getDialect())) {
2037 // TODO: support constant struct
2038 return type.isa<spirv::ArrayType>();
2039 }
2040
2041 return true;
2042}
2043
2044spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc,
2045 OpBuilder &builder) {
2046 if (auto intType = type.dyn_cast<IntegerType>()) {
2047 unsigned width = intType.getWidth();
2048 if (width == 1)
2049 return builder.create<spirv::ConstantOp>(loc, type,
2050 builder.getBoolAttr(false));
2051 return builder.create<spirv::ConstantOp>(
2052 loc, type, builder.getIntegerAttr(type, APInt(width, 0)));
2053 }
2054 if (auto floatType = type.dyn_cast<FloatType>()) {
2055 return builder.create<spirv::ConstantOp>(
2056 loc, type, builder.getFloatAttr(floatType, 0.0));
2057 }
2058 if (auto vectorType = type.dyn_cast<VectorType>()) {
2059 Type elemType = vectorType.getElementType();
2060 if (elemType.isa<IntegerType>()) {
2061 return builder.create<spirv::ConstantOp>(
2062 loc, type,
2063 DenseElementsAttr::get(vectorType,
2064 IntegerAttr::get(elemType, 0).getValue()));
2065 }
2066 if (elemType.isa<FloatType>()) {
2067 return builder.create<spirv::ConstantOp>(
2068 loc, type,
2069 DenseFPElementsAttr::get(vectorType,
2070 FloatAttr::get(elemType, 0.0).getValue()));
2071 }
2072 }
2073
2074 llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2074)
;
2075}
2076
2077spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc,
2078 OpBuilder &builder) {
2079 if (auto intType = type.dyn_cast<IntegerType>()) {
2080 unsigned width = intType.getWidth();
2081 if (width == 1)
2082 return builder.create<spirv::ConstantOp>(loc, type,
2083 builder.getBoolAttr(true));
2084 return builder.create<spirv::ConstantOp>(
2085 loc, type, builder.getIntegerAttr(type, APInt(width, 1)));
2086 }
2087 if (auto floatType = type.dyn_cast<FloatType>()) {
2088 return builder.create<spirv::ConstantOp>(
2089 loc, type, builder.getFloatAttr(floatType, 1.0));
2090 }
2091 if (auto vectorType = type.dyn_cast<VectorType>()) {
2092 Type elemType = vectorType.getElementType();
2093 if (elemType.isa<IntegerType>()) {
2094 return builder.create<spirv::ConstantOp>(
2095 loc, type,
2096 DenseElementsAttr::get(vectorType,
2097 IntegerAttr::get(elemType, 1).getValue()));
2098 }
2099 if (elemType.isa<FloatType>()) {
2100 return builder.create<spirv::ConstantOp>(
2101 loc, type,
2102 DenseFPElementsAttr::get(vectorType,
2103 FloatAttr::get(elemType, 1.0).getValue()));
2104 }
2105 }
2106
2107 llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2107)
;
2108}
2109
2110void mlir::spirv::ConstantOp::getAsmResultNames(
2111 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2112 Type type = getType();
2113
2114 SmallString<32> specialNameBuffer;
2115 llvm::raw_svector_ostream specialName(specialNameBuffer);
2116 specialName << "cst";
2117
2118 IntegerType intTy = type.dyn_cast<IntegerType>();
2119
2120 if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) {
2121 if (intTy && intTy.getWidth() == 1) {
2122 return setNameFn(getResult(), (intCst.getInt() ? "true" : "false"));
2123 }
2124
2125 if (intTy.isSignless()) {
2126 specialName << intCst.getInt();
2127 } else if (intTy.isUnsigned()) {
2128 specialName << intCst.getUInt();
2129 } else {
2130 specialName << intCst.getSInt();
2131 }
2132 }
2133
2134 if (intTy || type.isa<FloatType>()) {
2135 specialName << '_' << type;
2136 }
2137
2138 if (auto vecType = type.dyn_cast<VectorType>()) {
2139 specialName << "_vec_";
2140 specialName << vecType.getDimSize(0);
2141
2142 Type elementType = vecType.getElementType();
2143
2144 if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) {
2145 specialName << "x" << elementType;
2146 }
2147 }
2148
2149 setNameFn(getResult(), specialName.str());
2150}
2151
2152void mlir::spirv::AddressOfOp::getAsmResultNames(
2153 llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) {
2154 SmallString<32> specialNameBuffer;
2155 llvm::raw_svector_ostream specialName(specialNameBuffer);
2156 specialName << getVariable() << "_addr";
2157 setNameFn(getResult(), specialName.str());
2158}
2159
2160//===----------------------------------------------------------------------===//
2161// spirv.ControlBarrierOp
2162//===----------------------------------------------------------------------===//
2163
2164LogicalResult spirv::ControlBarrierOp::verify() {
2165 return verifyMemorySemantics(getOperation(), getMemorySemantics());
2166}
2167
2168//===----------------------------------------------------------------------===//
2169// spirv.ConvertFToSOp
2170//===----------------------------------------------------------------------===//
2171
2172LogicalResult spirv::ConvertFToSOp::verify() {
2173 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2174 /*skipBitWidthCheck=*/true);
2175}
2176
2177//===----------------------------------------------------------------------===//
2178// spirv.ConvertFToUOp
2179//===----------------------------------------------------------------------===//
2180
2181LogicalResult spirv::ConvertFToUOp::verify() {
2182 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2183 /*skipBitWidthCheck=*/true);
2184}
2185
2186//===----------------------------------------------------------------------===//
2187// spirv.ConvertSToFOp
2188//===----------------------------------------------------------------------===//
2189
2190LogicalResult spirv::ConvertSToFOp::verify() {
2191 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2192 /*skipBitWidthCheck=*/true);
2193}
2194
2195//===----------------------------------------------------------------------===//
2196// spirv.ConvertUToFOp
2197//===----------------------------------------------------------------------===//
2198
2199LogicalResult spirv::ConvertUToFOp::verify() {
2200 return verifyCastOp(*this, /*requireSameBitWidth=*/false,
2201 /*skipBitWidthCheck=*/true);
2202}
2203
2204//===----------------------------------------------------------------------===//
2205// spirv.INTELConvertBF16ToFOp
2206//===----------------------------------------------------------------------===//
2207
2208LogicalResult spirv::INTELConvertBF16ToFOp::verify() {
2209 auto operandType = getOperand().getType();
2210 auto resultType = getResult().getType();
2211 // ODS checks that vector result type and vector operand type have the same
2212 // shape.
2213 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
2214 unsigned operandNumElements = vectorType.getNumElements();
2215 unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
2216 if (operandNumElements != resultNumElements) {
2217 return emitOpError(
2218 "operand and result must have same number of elements");
2219 }
2220 }
2221 return success();
2222}
2223
2224//===----------------------------------------------------------------------===//
2225// spirv.INTELConvertFToBF16Op
2226//===----------------------------------------------------------------------===//
2227
2228LogicalResult spirv::INTELConvertFToBF16Op::verify() {
2229 auto operandType = getOperand().getType();
2230 auto resultType = getResult().getType();
2231 // ODS checks that vector result type and vector operand type have the same
2232 // shape.
2233 if (auto vectorType = operandType.dyn_cast<VectorType>()) {
2234 unsigned operandNumElements = vectorType.getNumElements();
2235 unsigned resultNumElements = resultType.cast<VectorType>().getNumElements();
2236 if (operandNumElements != resultNumElements) {
2237 return emitOpError(
2238 "operand and result must have same number of elements");
2239 }
2240 }
2241 return success();
2242}
2243
2244//===----------------------------------------------------------------------===//
2245// spirv.EntryPoint
2246//===----------------------------------------------------------------------===//
2247
2248void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state,
2249 spirv::ExecutionModel executionModel,
2250 spirv::FuncOp function,
2251 ArrayRef<Attribute> interfaceVars) {
2252 build(builder, state,
2253 spirv::ExecutionModelAttr::get(builder.getContext(), executionModel),
2254 SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars));
2255}
2256
2257ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser,
2258 OperationState &result) {
2259 spirv::ExecutionModel execModel;
1
'execModel' declared without an initial value
2260 SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers;
2261 SmallVector<Type, 0> idTypes;
2262 SmallVector<Attribute, 4> interfaceVars;
2263
2264 FlatSymbolRefAttr fn;
2265 if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) ||
2
Calling 'parseEnumStrAttr<mlir::spirv::ExecutionModelAttr, mlir::spirv::ExecutionModel>'
2266 parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) {
2267 return failure();
2268 }
2269
2270 if (!parser.parseOptionalComma()) {
2271 // Parse the interface variables
2272 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
2273 // The name of the interface variable attribute isnt important
2274 FlatSymbolRefAttr var;
2275 NamedAttrList attrs;
2276 if (parser.parseAttribute(var, Type(), "var_symbol", attrs))
2277 return failure();
2278 interfaceVars.push_back(var);
2279 return success();
2280 }))
2281 return failure();
2282 }
2283 result.addAttribute(kInterfaceAttrName,
2284 parser.getBuilder().getArrayAttr(interfaceVars));
2285 return success();
2286}
2287
2288void spirv::EntryPointOp::print(OpAsmPrinter &printer) {
2289 printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" ";
2290 printer.printSymbolName(getFn());
2291 auto interfaceVars = getInterface().getValue();
2292 if (!interfaceVars.empty()) {
2293 printer << ", ";
2294 llvm::interleaveComma(interfaceVars, printer);
2295 }
2296}
2297
2298LogicalResult spirv::EntryPointOp::verify() {
2299 // Checks for fn and interface symbol reference are done in spirv::ModuleOp
2300 // verification.
2301 return success();
2302}
2303
2304//===----------------------------------------------------------------------===//
2305// spirv.ExecutionMode
2306//===----------------------------------------------------------------------===//
2307
2308void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state,
2309 spirv::FuncOp function,
2310 spirv::ExecutionMode executionMode,
2311 ArrayRef<int32_t> params) {
2312 build(builder, state, SymbolRefAttr::get(function),
2313 spirv::ExecutionModeAttr::get(builder.getContext(), executionMode),
2314 builder.getI32ArrayAttr(params));
2315}
2316
2317ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser,
2318 OperationState &result) {
2319 spirv::ExecutionMode execMode;
2320 Attribute fn;
2321 if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) ||
2322 parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) {
2323 return failure();
2324 }
2325
2326 SmallVector<int32_t, 4> values;
2327 Type i32Type = parser.getBuilder().getIntegerType(32);
2328 while (!parser.parseOptionalComma()) {
2329 NamedAttrList attr;
2330 Attribute value;
2331 if (parser.parseAttribute(value, i32Type, "value", attr)) {
2332 return failure();
2333 }
2334 values.push_back(value.cast<IntegerAttr>().getInt());
2335 }
2336 result.addAttribute(kValuesAttrName,
2337 parser.getBuilder().getI32ArrayAttr(values));
2338 return success();
2339}
2340
2341void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) {
2342 printer << " ";
2343 printer.printSymbolName(getFn());
2344 printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\"";
2345 auto values = this->getValues();
2346 if (values.empty())
2347 return;
2348 printer << ", ";
2349 llvm::interleaveComma(values, printer, [&](Attribute a) {
2350 printer << a.cast<IntegerAttr>().getInt();
2351 });
2352}
2353
2354//===----------------------------------------------------------------------===//
2355// spirv.FConvertOp
2356//===----------------------------------------------------------------------===//
2357
2358LogicalResult spirv::FConvertOp::verify() {
2359 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2360}
2361
2362//===----------------------------------------------------------------------===//
2363// spirv.SConvertOp
2364//===----------------------------------------------------------------------===//
2365
2366LogicalResult spirv::SConvertOp::verify() {
2367 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2368}
2369
2370//===----------------------------------------------------------------------===//
2371// spirv.UConvertOp
2372//===----------------------------------------------------------------------===//
2373
2374LogicalResult spirv::UConvertOp::verify() {
2375 return verifyCastOp(*this, /*requireSameBitWidth=*/false);
2376}
2377
2378//===----------------------------------------------------------------------===//
2379// spirv.func
2380//===----------------------------------------------------------------------===//
2381
2382ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) {
2383 SmallVector<OpAsmParser::Argument> entryArgs;
2384 SmallVector<DictionaryAttr> resultAttrs;
2385 SmallVector<Type> resultTypes;
2386 auto &builder = parser.getBuilder();
2387
2388 // Parse the name as a symbol.
2389 StringAttr nameAttr;
2390 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2391 result.attributes))
2392 return failure();
2393
2394 // Parse the function signature.
2395 bool isVariadic = false;
2396 if (function_interface_impl::parseFunctionSignature(
2397 parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
2398 resultAttrs))
2399 return failure();
2400
2401 SmallVector<Type> argTypes;
2402 for (auto &arg : entryArgs)
2403 argTypes.push_back(arg.type);
2404 auto fnType = builder.getFunctionType(argTypes, resultTypes);
2405 result.addAttribute(getFunctionTypeAttrName(result.name),
2406 TypeAttr::get(fnType));
2407
2408 // Parse the optional function control keyword.
2409 spirv::FunctionControl fnControl;
2410 if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result))
2411 return failure();
2412
2413 // If additional attributes are present, parse them.
2414 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
2415 return failure();
2416
2417 // Add the attributes to the function arguments.
2418 assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes.
size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()"
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2418, __extension__
__PRETTY_FUNCTION__))
;
2419 function_interface_impl::addArgAndResultAttrs(
2420 builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
2421 getResAttrsAttrName(result.name));
2422
2423 // Parse the optional function body.
2424 auto *body = result.addRegion();
2425 OptionalParseResult parseResult =
2426 parser.parseOptionalRegion(*body, entryArgs);
2427 return failure(parseResult.has_value() && failed(*parseResult));
2428}
2429
2430void spirv::FuncOp::print(OpAsmPrinter &printer) {
2431 // Print function name, signature, and control.
2432 printer << " ";
2433 printer.printSymbolName(getSymName());
2434 auto fnType = getFunctionType();
2435 function_interface_impl::printFunctionSignature(
2436 printer, *this, fnType.getInputs(),
2437 /*isVariadic=*/false, fnType.getResults());
2438 printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl())
2439 << "\"";
2440 function_interface_impl::printFunctionAttributes(
2441 printer, *this,
2442 {spirv::attributeName<spirv::FunctionControl>(),
2443 getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(),
2444 getFunctionControlAttrName()});
2445
2446 // Print the body if this is not an external function.
2447 Region &body = this->getBody();
2448 if (!body.empty()) {
2449 printer << ' ';
2450 printer.printRegion(body, /*printEntryBlockArgs=*/false,
2451 /*printBlockTerminators=*/true);
2452 }
2453}
2454
2455LogicalResult spirv::FuncOp::verifyType() {
2456 if (getFunctionType().getNumResults() > 1)
2457 return emitOpError("cannot have more than one result");
2458 return success();
2459}
2460
2461LogicalResult spirv::FuncOp::verifyBody() {
2462 FunctionType fnType = getFunctionType();
2463
2464 auto walkResult = walk([fnType](Operation *op) -> WalkResult {
2465 if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) {
2466 if (fnType.getNumResults() != 0)
2467 return retOp.emitOpError("cannot be used in functions returning value");
2468 } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) {
2469 if (fnType.getNumResults() != 1)
2470 return retOp.emitOpError(
2471 "returns 1 value but enclosing function requires ")
2472 << fnType.getNumResults() << " results";
2473
2474 auto retOperandType = retOp.getValue().getType();
2475 auto fnResultType = fnType.getResult(0);
2476 if (retOperandType != fnResultType)
2477 return retOp.emitOpError(" return value's type (")
2478 << retOperandType << ") mismatch with function's result type ("
2479 << fnResultType << ")";
2480 }
2481 return WalkResult::advance();
2482 });
2483
2484 // TODO: verify other bits like linkage type.
2485
2486 return failure(walkResult.wasInterrupted());
2487}
2488
2489void spirv::FuncOp::build(OpBuilder &builder, OperationState &state,
2490 StringRef name, FunctionType type,
2491 spirv::FunctionControl control,
2492 ArrayRef<NamedAttribute> attrs) {
2493 state.addAttribute(SymbolTable::getSymbolAttrName(),
2494 builder.getStringAttr(name));
2495 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
2496 state.addAttribute(spirv::attributeName<spirv::FunctionControl>(),
2497 builder.getAttr<spirv::FunctionControlAttr>(control));
2498 state.attributes.append(attrs.begin(), attrs.end());
2499 state.addRegion();
2500}
2501
2502// CallableOpInterface
2503Region *spirv::FuncOp::getCallableRegion() {
2504 return isExternal() ? nullptr : &getBody();
2505}
2506
2507// CallableOpInterface
2508ArrayRef<Type> spirv::FuncOp::getCallableResults() {
2509 return getFunctionType().getResults();
2510}
2511
2512// CallableOpInterface
2513::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() {
2514 return getArgAttrs().value_or(nullptr);
2515}
2516
2517// CallableOpInterface
2518::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() {
2519 return getResAttrs().value_or(nullptr);
2520}
2521
2522//===----------------------------------------------------------------------===//
2523// spirv.FunctionCall
2524//===----------------------------------------------------------------------===//
2525
2526LogicalResult spirv::FunctionCallOp::verify() {
2527 auto fnName = getCalleeAttr();
2528
2529 auto funcOp = dyn_cast_or_null<spirv::FuncOp>(
2530 SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName));
2531 if (!funcOp) {
2532 return emitOpError("callee function '")
2533 << fnName.getValue() << "' not found in nearest symbol table";
2534 }
2535
2536 auto functionType = funcOp.getFunctionType();
2537
2538 if (getNumResults() > 1) {
2539 return emitOpError(
2540 "expected callee function to have 0 or 1 result, but provided ")
2541 << getNumResults();
2542 }
2543
2544 if (functionType.getNumInputs() != getNumOperands()) {
2545 return emitOpError("has incorrect number of operands for callee: expected ")
2546 << functionType.getNumInputs() << ", but provided "
2547 << getNumOperands();
2548 }
2549
2550 for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) {
2551 if (getOperand(i).getType() != functionType.getInput(i)) {
2552 return emitOpError("operand type mismatch: expected operand type ")
2553 << functionType.getInput(i) << ", but provided "
2554 << getOperand(i).getType() << " for operand number " << i;
2555 }
2556 }
2557
2558 if (functionType.getNumResults() != getNumResults()) {
2559 return emitOpError(
2560 "has incorrect number of results has for callee: expected ")
2561 << functionType.getNumResults() << ", but provided "
2562 << getNumResults();
2563 }
2564
2565 if (getNumResults() &&
2566 (getResult(0).getType() != functionType.getResult(0))) {
2567 return emitOpError("result type mismatch: expected ")
2568 << functionType.getResult(0) << ", but provided "
2569 << getResult(0).getType();
2570 }
2571
2572 return success();
2573}
2574
2575CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() {
2576 return (*this)->getAttrOfType<SymbolRefAttr>(kCallee);
2577}
2578
2579void spirv::FunctionCallOp::setCalleeFromCallable(
2580 CallInterfaceCallable callee) {
2581 (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>());
2582}
2583
2584Operation::operand_range spirv::FunctionCallOp::getArgOperands() {
2585 return getArguments();
2586}
2587
2588//===----------------------------------------------------------------------===//
2589// spirv.GLFClampOp
2590//===----------------------------------------------------------------------===//
2591
2592ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser,
2593 OperationState &result) {
2594 return parseOneResultSameOperandTypeOp(parser, result);
2595}
2596void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2597
2598//===----------------------------------------------------------------------===//
2599// spirv.GLUClampOp
2600//===----------------------------------------------------------------------===//
2601
2602ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser,
2603 OperationState &result) {
2604 return parseOneResultSameOperandTypeOp(parser, result);
2605}
2606void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2607
2608//===----------------------------------------------------------------------===//
2609// spirv.GLSClampOp
2610//===----------------------------------------------------------------------===//
2611
2612ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser,
2613 OperationState &result) {
2614 return parseOneResultSameOperandTypeOp(parser, result);
2615}
2616void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2617
2618//===----------------------------------------------------------------------===//
2619// spirv.GLFmaOp
2620//===----------------------------------------------------------------------===//
2621
2622ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) {
2623 return parseOneResultSameOperandTypeOp(parser, result);
2624}
2625void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); }
2626
2627//===----------------------------------------------------------------------===//
2628// spirv.GlobalVariable
2629//===----------------------------------------------------------------------===//
2630
2631void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2632 Type type, StringRef name,
2633 unsigned descriptorSet, unsigned binding) {
2634 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2635 state.addAttribute(
2636 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet),
2637 builder.getI32IntegerAttr(descriptorSet));
2638 state.addAttribute(
2639 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding),
2640 builder.getI32IntegerAttr(binding));
2641}
2642
2643void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state,
2644 Type type, StringRef name,
2645 spirv::BuiltIn builtin) {
2646 build(builder, state, TypeAttr::get(type), builder.getStringAttr(name));
2647 state.addAttribute(
2648 spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn),
2649 builder.getStringAttr(spirv::stringifyBuiltIn(builtin)));
2650}
2651
2652ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser,
2653 OperationState &result) {
2654 // Parse variable name.
2655 StringAttr nameAttr;
2656 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
2657 result.attributes)) {
2658 return failure();
2659 }
2660
2661 // Parse optional initializer
2662 if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) {
2663 FlatSymbolRefAttr initSymbol;
2664 if (parser.parseLParen() ||
2665 parser.parseAttribute(initSymbol, Type(), kInitializerAttrName,
2666 result.attributes) ||
2667 parser.parseRParen())
2668 return failure();
2669 }
2670
2671 if (parseVariableDecorations(parser, result)) {
2672 return failure();
2673 }
2674
2675 Type type;
2676 auto loc = parser.getCurrentLocation();
2677 if (parser.parseColonType(type)) {
2678 return failure();
2679 }
2680 if (!type.isa<spirv::PointerType>()) {
2681 return parser.emitError(loc, "expected spirv.ptr type");
2682 }
2683 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
2684
2685 return success();
2686}
2687
2688void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) {
2689 SmallVector<StringRef, 4> elidedAttrs{
2690 spirv::attributeName<spirv::StorageClass>()};
2691
2692 // Print variable name.
2693 printer << ' ';
2694 printer.printSymbolName(getSymName());
2695 elidedAttrs.push_back(SymbolTable::getSymbolAttrName());
2696
2697 // Print optional initializer
2698 if (auto initializer = this->getInitializer()) {
2699 printer << " " << kInitializerAttrName << '(';
2700 printer.printSymbolName(*initializer);
2701 printer << ')';
2702 elidedAttrs.push_back(kInitializerAttrName);
2703 }
2704
2705 elidedAttrs.push_back(kTypeAttrName);
2706 printVariableDecorations(*this, printer, elidedAttrs);
2707 printer << " : " << getType();
2708}
2709
2710LogicalResult spirv::GlobalVariableOp::verify() {
2711 if (!getType().isa<spirv::PointerType>())
2712 return emitOpError("result must be of a !spv.ptr type");
2713
2714 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
2715 // object. It cannot be Generic. It must be the same as the Storage Class
2716 // operand of the Result Type."
2717 // Also, Function storage class is reserved by spirv.Variable.
2718 auto storageClass = this->storageClass();
2719 if (storageClass == spirv::StorageClass::Generic ||
2720 storageClass == spirv::StorageClass::Function) {
2721 return emitOpError("storage class cannot be '")
2722 << stringifyStorageClass(storageClass) << "'";
2723 }
2724
2725 if (auto init =
2726 (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) {
2727 Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
2728 (*this)->getParentOp(), init.getAttr());
2729 // TODO: Currently only variable initialization with specialization
2730 // constants and other variables is supported. They could be normal
2731 // constants in the module scope as well.
2732 if (!initOp ||
2733 !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
2734 return emitOpError("initializer must be result of a "
2735 "spirv.SpecConstant or spirv.GlobalVariable op");
2736 }
2737 }
2738
2739 return success();
2740}
2741
2742//===----------------------------------------------------------------------===//
2743// spirv.GroupBroadcast
2744//===----------------------------------------------------------------------===//
2745
2746LogicalResult spirv::GroupBroadcastOp::verify() {
2747 spirv::Scope scope = getExecutionScope();
2748 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2749 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2750
2751 if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>())
2752 if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3)
2753 return emitOpError("localid is a vector and can be with only "
2754 " 2 or 3 components, actual number is ")
2755 << localIdTy.getNumElements();
2756
2757 return success();
2758}
2759
2760//===----------------------------------------------------------------------===//
2761// spirv.GroupNonUniformBallotOp
2762//===----------------------------------------------------------------------===//
2763
2764LogicalResult spirv::GroupNonUniformBallotOp::verify() {
2765 spirv::Scope scope = getExecutionScope();
2766 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2767 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2768
2769 return success();
2770}
2771
2772//===----------------------------------------------------------------------===//
2773// spirv.GroupNonUniformBroadcast
2774//===----------------------------------------------------------------------===//
2775
2776LogicalResult spirv::GroupNonUniformBroadcastOp::verify() {
2777 spirv::Scope scope = getExecutionScope();
2778 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2779 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2780
2781 // SPIR-V spec: "Before version 1.5, Id must come from a
2782 // constant instruction.
2783 auto targetEnv = spirv::getDefaultTargetEnv(getContext());
2784 if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>())
2785 targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule);
2786
2787 if (targetEnv.getVersion() < spirv::Version::V_1_5) {
2788 auto *idOp = getId().getDefiningOp();
2789 if (!idOp || !isa<spirv::ConstantOp, // for normal constant
2790 spirv::ReferenceOfOp>(idOp)) // for spec constant
2791 return emitOpError("id must be the result of a constant op");
2792 }
2793
2794 return success();
2795}
2796
2797//===----------------------------------------------------------------------===//
2798// spirv.GroupNonUniformShuffle*
2799//===----------------------------------------------------------------------===//
2800
2801template <typename OpTy>
2802static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) {
2803 spirv::Scope scope = op.getExecutionScope();
2804 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2805 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2806
2807 if (op.getOperands().back().getType().isSignedInteger())
2808 return op.emitOpError("second operand must be a singless/unsigned integer");
2809
2810 return success();
2811}
2812
2813LogicalResult spirv::GroupNonUniformShuffleOp::verify() {
2814 return verifyGroupNonUniformShuffleOp(*this);
2815}
2816LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() {
2817 return verifyGroupNonUniformShuffleOp(*this);
2818}
2819LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() {
2820 return verifyGroupNonUniformShuffleOp(*this);
2821}
2822LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() {
2823 return verifyGroupNonUniformShuffleOp(*this);
2824}
2825
2826//===----------------------------------------------------------------------===//
2827// spirv.INTEL.SubgroupBlockRead
2828//===----------------------------------------------------------------------===//
2829
2830ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser,
2831 OperationState &result) {
2832 // Parse the storage class specification
2833 spirv::StorageClass storageClass;
2834 OpAsmParser::UnresolvedOperand ptrInfo;
2835 Type elementType;
2836 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
2837 parser.parseColon() || parser.parseType(elementType)) {
2838 return failure();
2839 }
2840
2841 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2842 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2843 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2844
2845 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
2846 return failure();
2847 }
2848
2849 result.addTypes(elementType);
2850 return success();
2851}
2852
2853void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) {
2854 printer << " " << getPtr() << " : " << getType();
2855}
2856
2857LogicalResult spirv::INTELSubgroupBlockReadOp::verify() {
2858 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2859 return failure();
2860
2861 return success();
2862}
2863
2864//===----------------------------------------------------------------------===//
2865// spirv.INTEL.SubgroupBlockWrite
2866//===----------------------------------------------------------------------===//
2867
2868ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser,
2869 OperationState &result) {
2870 // Parse the storage class specification
2871 spirv::StorageClass storageClass;
2872 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
2873 auto loc = parser.getCurrentLocation();
2874 Type elementType;
2875 if (parseEnumStrAttr(storageClass, parser) ||
2876 parser.parseOperandList(operandInfo, 2) || parser.parseColon() ||
2877 parser.parseType(elementType)) {
2878 return failure();
2879 }
2880
2881 auto ptrType = spirv::PointerType::get(elementType, storageClass);
2882 if (auto valVecTy = elementType.dyn_cast<VectorType>())
2883 ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass);
2884
2885 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
2886 result.operands)) {
2887 return failure();
2888 }
2889 return success();
2890}
2891
2892void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) {
2893 printer << " " << getPtr() << ", " << getValue() << " : "
2894 << getValue().getType();
2895}
2896
2897LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() {
2898 if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue())))
2899 return failure();
2900
2901 return success();
2902}
2903
2904//===----------------------------------------------------------------------===//
2905// spirv.GroupNonUniformElectOp
2906//===----------------------------------------------------------------------===//
2907
2908LogicalResult spirv::GroupNonUniformElectOp::verify() {
2909 spirv::Scope scope = getExecutionScope();
2910 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
2911 return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
2912
2913 return success();
2914}
2915
2916//===----------------------------------------------------------------------===//
2917// spirv.GroupNonUniformFAddOp
2918//===----------------------------------------------------------------------===//
2919
2920LogicalResult spirv::GroupNonUniformFAddOp::verify() {
2921 return verifyGroupNonUniformArithmeticOp(*this);
2922}
2923
2924ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser,
2925 OperationState &result) {
2926 return parseGroupNonUniformArithmeticOp(parser, result);
2927}
2928void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) {
2929 printGroupNonUniformArithmeticOp(*this, p);
2930}
2931
2932//===----------------------------------------------------------------------===//
2933// spirv.GroupNonUniformFMaxOp
2934//===----------------------------------------------------------------------===//
2935
2936LogicalResult spirv::GroupNonUniformFMaxOp::verify() {
2937 return verifyGroupNonUniformArithmeticOp(*this);
2938}
2939
2940ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser,
2941 OperationState &result) {
2942 return parseGroupNonUniformArithmeticOp(parser, result);
2943}
2944void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) {
2945 printGroupNonUniformArithmeticOp(*this, p);
2946}
2947
2948//===----------------------------------------------------------------------===//
2949// spirv.GroupNonUniformFMinOp
2950//===----------------------------------------------------------------------===//
2951
2952LogicalResult spirv::GroupNonUniformFMinOp::verify() {
2953 return verifyGroupNonUniformArithmeticOp(*this);
2954}
2955
2956ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser,
2957 OperationState &result) {
2958 return parseGroupNonUniformArithmeticOp(parser, result);
2959}
2960void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) {
2961 printGroupNonUniformArithmeticOp(*this, p);
2962}
2963
2964//===----------------------------------------------------------------------===//
2965// spirv.GroupNonUniformFMulOp
2966//===----------------------------------------------------------------------===//
2967
2968LogicalResult spirv::GroupNonUniformFMulOp::verify() {
2969 return verifyGroupNonUniformArithmeticOp(*this);
2970}
2971
2972ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser,
2973 OperationState &result) {
2974 return parseGroupNonUniformArithmeticOp(parser, result);
2975}
2976void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) {
2977 printGroupNonUniformArithmeticOp(*this, p);
2978}
2979
2980//===----------------------------------------------------------------------===//
2981// spirv.GroupNonUniformIAddOp
2982//===----------------------------------------------------------------------===//
2983
2984LogicalResult spirv::GroupNonUniformIAddOp::verify() {
2985 return verifyGroupNonUniformArithmeticOp(*this);
2986}
2987
2988ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser,
2989 OperationState &result) {
2990 return parseGroupNonUniformArithmeticOp(parser, result);
2991}
2992void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) {
2993 printGroupNonUniformArithmeticOp(*this, p);
2994}
2995
2996//===----------------------------------------------------------------------===//
2997// spirv.GroupNonUniformIMulOp
2998//===----------------------------------------------------------------------===//
2999
3000LogicalResult spirv::GroupNonUniformIMulOp::verify() {
3001 return verifyGroupNonUniformArithmeticOp(*this);
3002}
3003
3004ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser,
3005 OperationState &result) {
3006 return parseGroupNonUniformArithmeticOp(parser, result);
3007}
3008void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) {
3009 printGroupNonUniformArithmeticOp(*this, p);
3010}
3011
3012//===----------------------------------------------------------------------===//
3013// spirv.GroupNonUniformSMaxOp
3014//===----------------------------------------------------------------------===//
3015
3016LogicalResult spirv::GroupNonUniformSMaxOp::verify() {
3017 return verifyGroupNonUniformArithmeticOp(*this);
3018}
3019
3020ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser,
3021 OperationState &result) {
3022 return parseGroupNonUniformArithmeticOp(parser, result);
3023}
3024void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) {
3025 printGroupNonUniformArithmeticOp(*this, p);
3026}
3027
3028//===----------------------------------------------------------------------===//
3029// spirv.GroupNonUniformSMinOp
3030//===----------------------------------------------------------------------===//
3031
3032LogicalResult spirv::GroupNonUniformSMinOp::verify() {
3033 return verifyGroupNonUniformArithmeticOp(*this);
3034}
3035
3036ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser,
3037 OperationState &result) {
3038 return parseGroupNonUniformArithmeticOp(parser, result);
3039}
3040void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) {
3041 printGroupNonUniformArithmeticOp(*this, p);
3042}
3043
3044//===----------------------------------------------------------------------===//
3045// spirv.GroupNonUniformUMaxOp
3046//===----------------------------------------------------------------------===//
3047
3048LogicalResult spirv::GroupNonUniformUMaxOp::verify() {
3049 return verifyGroupNonUniformArithmeticOp(*this);
3050}
3051
3052ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser,
3053 OperationState &result) {
3054 return parseGroupNonUniformArithmeticOp(parser, result);
3055}
3056void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) {
3057 printGroupNonUniformArithmeticOp(*this, p);
3058}
3059
3060//===----------------------------------------------------------------------===//
3061// spirv.GroupNonUniformUMinOp
3062//===----------------------------------------------------------------------===//
3063
3064LogicalResult spirv::GroupNonUniformUMinOp::verify() {
3065 return verifyGroupNonUniformArithmeticOp(*this);
3066}
3067
3068ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser,
3069 OperationState &result) {
3070 return parseGroupNonUniformArithmeticOp(parser, result);
3071}
3072void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) {
3073 printGroupNonUniformArithmeticOp(*this, p);
3074}
3075
3076//===----------------------------------------------------------------------===//
3077// spirv.IAddCarryOp
3078//===----------------------------------------------------------------------===//
3079
3080LogicalResult spirv::IAddCarryOp::verify() {
3081 return ::verifyArithmeticExtendedBinaryOp(*this);
3082}
3083
3084ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser,
3085 OperationState &result) {
3086 return ::parseArithmeticExtendedBinaryOp(parser, result);
3087}
3088
3089void spirv::IAddCarryOp::print(OpAsmPrinter &printer) {
3090 ::printArithmeticExtendedBinaryOp(*this, printer);
3091}
3092
3093//===----------------------------------------------------------------------===//
3094// spirv.ISubBorrowOp
3095//===----------------------------------------------------------------------===//
3096
3097LogicalResult spirv::ISubBorrowOp::verify() {
3098 return ::verifyArithmeticExtendedBinaryOp(*this);
3099}
3100
3101ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser,
3102 OperationState &result) {
3103 return ::parseArithmeticExtendedBinaryOp(parser, result);
3104}
3105
3106void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) {
3107 ::printArithmeticExtendedBinaryOp(*this, printer);
3108}
3109
3110//===----------------------------------------------------------------------===//
3111// spirv.SMulExtended
3112//===----------------------------------------------------------------------===//
3113
3114LogicalResult spirv::SMulExtendedOp::verify() {
3115 return ::verifyArithmeticExtendedBinaryOp(*this);
3116}
3117
3118ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser,
3119 OperationState &result) {
3120 return ::parseArithmeticExtendedBinaryOp(parser, result);
3121}
3122
3123void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) {
3124 ::printArithmeticExtendedBinaryOp(*this, printer);
3125}
3126
3127//===----------------------------------------------------------------------===//
3128// spirv.UMulExtended
3129//===----------------------------------------------------------------------===//
3130
3131LogicalResult spirv::UMulExtendedOp::verify() {
3132 return ::verifyArithmeticExtendedBinaryOp(*this);
3133}
3134
3135ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser,
3136 OperationState &result) {
3137 return ::parseArithmeticExtendedBinaryOp(parser, result);
3138}
3139
3140void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) {
3141 ::printArithmeticExtendedBinaryOp(*this, printer);
3142}
3143
3144//===----------------------------------------------------------------------===//
3145// spirv.LoadOp
3146//===----------------------------------------------------------------------===//
3147
3148void spirv::LoadOp::build(OpBuilder &builder, OperationState &state,
3149 Value basePtr, MemoryAccessAttr memoryAccess,
3150 IntegerAttr alignment) {
3151 auto ptrType = basePtr.getType().cast<spirv::PointerType>();
3152 build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess,
3153 alignment);
3154}
3155
3156ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) {
3157 // Parse the storage class specification
3158 spirv::StorageClass storageClass;
3159 OpAsmParser::UnresolvedOperand ptrInfo;
3160 Type elementType;
3161 if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) ||
3162 parseMemoryAccessAttributes(parser, result) ||
3163 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
3164 parser.parseType(elementType)) {
3165 return failure();
3166 }
3167
3168 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3169 if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) {
3170 return failure();
3171 }
3172
3173 result.addTypes(elementType);
3174 return success();
3175}
3176
3177void spirv::LoadOp::print(OpAsmPrinter &printer) {
3178 SmallVector<StringRef, 4> elidedAttrs;
3179 StringRef sc = stringifyStorageClass(
3180 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3181 printer << " \"" << sc << "\" " << getPtr();
3182
3183 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3184
3185 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3186 printer << " : " << getType();
3187}
3188
3189LogicalResult spirv::LoadOp::verify() {
3190 // SPIR-V spec : "Result Type is the type of the loaded object. It must be a
3191 // type with fixed size; i.e., it cannot be, nor include, any
3192 // OpTypeRuntimeArray types."
3193 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) {
3194 return failure();
3195 }
3196 return verifyMemoryAccessAttribute(*this);
3197}
3198
3199//===----------------------------------------------------------------------===//
3200// spirv.mlir.loop
3201//===----------------------------------------------------------------------===//
3202
3203void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) {
3204 state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>(
3205 spirv::LoopControl::None));
3206 state.addRegion();
3207}
3208
3209ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) {
3210 if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
3211 result))
3212 return failure();
3213 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3214}
3215
3216void spirv::LoopOp::print(OpAsmPrinter &printer) {
3217 auto control = getLoopControl();
3218 if (control != spirv::LoopControl::None)
3219 printer << " control(" << spirv::stringifyLoopControl(control) << ")";
3220 printer << ' ';
3221 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3222 /*printBlockTerminators=*/true);
3223}
3224
3225/// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the
3226/// given `dstBlock`.
3227static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) {
3228 // Check that there is only one op in the `srcBlock`.
3229 if (!llvm::hasSingleElement(srcBlock))
3230 return false;
3231
3232 auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back());
3233 return branchOp && branchOp.getSuccessor() == &dstBlock;
3234}
3235
3236LogicalResult spirv::LoopOp::verifyRegions() {
3237 auto *op = getOperation();
3238
3239 // We need to verify that the blocks follow the following layout:
3240 //
3241 // +-------------+
3242 // | entry block |
3243 // +-------------+
3244 // |
3245 // v
3246 // +-------------+
3247 // | loop header | <-----+
3248 // +-------------+ |
3249 // |
3250 // ... |
3251 // \ | / |
3252 // v |
3253 // +---------------+ |
3254 // | loop continue | -----+
3255 // +---------------+
3256 //
3257 // ...
3258 // \ | /
3259 // v
3260 // +-------------+
3261 // | merge block |
3262 // +-------------+
3263
3264 auto &region = op->getRegion(0);
3265 // Allow empty region as a degenerated case, which can come from
3266 // optimizations.
3267 if (region.empty())
3268 return success();
3269
3270 // The last block is the merge block.
3271 Block &merge = region.back();
3272 if (!isMergeBlock(merge))
3273 return emitOpError("last block must be the merge block with only one "
3274 "'spirv.mlir.merge' op");
3275
3276 if (std::next(region.begin()) == region.end())
3277 return emitOpError(
3278 "must have an entry block branching to the loop header block");
3279 // The first block is the entry block.
3280 Block &entry = region.front();
3281
3282 if (std::next(region.begin(), 2) == region.end())
3283 return emitOpError(
3284 "must have a loop header block branched from the entry block");
3285 // The second block is the loop header block.
3286 Block &header = *std::next(region.begin(), 1);
3287
3288 if (!hasOneBranchOpTo(entry, header))
3289 return emitOpError(
3290 "entry block must only have one 'spirv.Branch' op to the second block");
3291
3292 if (std::next(region.begin(), 3) == region.end())
3293 return emitOpError(
3294 "requires a loop continue block branching to the loop header block");
3295 // The second to last block is the loop continue block.
3296 Block &cont = *std::prev(region.end(), 2);
3297
3298 // Make sure that we have a branch from the loop continue block to the loop
3299 // header block.
3300 if (llvm::none_of(
3301 llvm::seq<unsigned>(0, cont.getNumSuccessors()),
3302 [&](unsigned index) { return cont.getSuccessor(index) == &header; }))
3303 return emitOpError("second to last block must be the loop continue "
3304 "block that branches to the loop header block");
3305
3306 // Make sure that no other blocks (except the entry and loop continue block)
3307 // branches to the loop header block.
3308 for (auto &block : llvm::make_range(std::next(region.begin(), 2),
3309 std::prev(region.end(), 2))) {
3310 for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) {
3311 if (block.getSuccessor(i) == &header) {
3312 return emitOpError("can only have the entry and loop continue "
3313 "block branching to the loop header block");
3314 }
3315 }
3316 }
3317
3318 return success();
3319}
3320
3321Block *spirv::LoopOp::getEntryBlock() {
3322 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3322, __extension__
__PRETTY_FUNCTION__))
;
3323 return &getBody().front();
3324}
3325
3326Block *spirv::LoopOp::getHeaderBlock() {
3327 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3327, __extension__
__PRETTY_FUNCTION__))
;
3328 // The second block is the loop header block.
3329 return &*std::next(getBody().begin());
3330}
3331
3332Block *spirv::LoopOp::getContinueBlock() {
3333 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3333, __extension__
__PRETTY_FUNCTION__))
;
3334 // The second to last block is the loop continue block.
3335 return &*std::prev(getBody().end(), 2);
3336}
3337
3338Block *spirv::LoopOp::getMergeBlock() {
3339 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3339, __extension__
__PRETTY_FUNCTION__))
;
3340 // The last block is the loop merge block.
3341 return &getBody().back();
3342}
3343
3344void spirv::LoopOp::addEntryAndMergeBlock() {
3345 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3345, __extension__
__PRETTY_FUNCTION__))
;
3346 getBody().push_back(new Block());
3347 auto *mergeBlock = new Block();
3348 getBody().push_back(mergeBlock);
3349 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3350
3351 // Add a spirv.mlir.merge op into the merge block.
3352 builder.create<spirv::MergeOp>(getLoc());
3353}
3354
3355//===----------------------------------------------------------------------===//
3356// spirv.MemoryBarrierOp
3357//===----------------------------------------------------------------------===//
3358
3359LogicalResult spirv::MemoryBarrierOp::verify() {
3360 return verifyMemorySemantics(getOperation(), getMemorySemantics());
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// spirv.mlir.merge
3365//===----------------------------------------------------------------------===//
3366
3367LogicalResult spirv::MergeOp::verify() {
3368 auto *parentOp = (*this)->getParentOp();
3369 if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp))
3370 return emitOpError(
3371 "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'");
3372
3373 // TODO: This check should be done in `verifyRegions` of parent op.
3374 Block &parentLastBlock = (*this)->getParentRegion()->back();
3375 if (getOperation() != parentLastBlock.getTerminator())
3376 return emitOpError("can only be used in the last block of "
3377 "'spirv.mlir.selection' or 'spirv.mlir.loop'");
3378 return success();
3379}
3380
3381//===----------------------------------------------------------------------===//
3382// spirv.module
3383//===----------------------------------------------------------------------===//
3384
3385void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3386 std::optional<StringRef> name) {
3387 OpBuilder::InsertionGuard guard(builder);
3388 builder.createBlock(state.addRegion());
3389 if (name) {
3390 state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
3391 builder.getStringAttr(*name));
3392 }
3393}
3394
3395void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
3396 spirv::AddressingModel addressingModel,
3397 spirv::MemoryModel memoryModel,
3398 std::optional<VerCapExtAttr> vceTriple,
3399 std::optional<StringRef> name) {
3400 state.addAttribute(
3401 "addressing_model",
3402 builder.getAttr<spirv::AddressingModelAttr>(addressingModel));
3403 state.addAttribute("memory_model",
3404 builder.getAttr<spirv::MemoryModelAttr>(memoryModel));
3405 OpBuilder::InsertionGuard guard(builder);
3406 builder.createBlock(state.addRegion());
3407 if (vceTriple)
3408 state.addAttribute(getVCETripleAttrName(), *vceTriple);
3409 if (name)
3410 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
3411 builder.getStringAttr(*name));
3412}
3413
3414ParseResult spirv::ModuleOp::parse(OpAsmParser &parser,
3415 OperationState &result) {
3416 Region *body = result.addRegion();
3417
3418 // If the name is present, parse it.
3419 StringAttr nameAttr;
3420 (void)parser.parseOptionalSymbolName(
3421 nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes);
3422
3423 // Parse attributes
3424 spirv::AddressingModel addrModel;
3425 spirv::MemoryModel memoryModel;
3426 if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser,
3427 result) ||
3428 ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser,
3429 result))
3430 return failure();
3431
3432 if (succeeded(parser.parseOptionalKeyword("requires"))) {
3433 spirv::VerCapExtAttr vceTriple;
3434 if (parser.parseAttribute(vceTriple,
3435 spirv::ModuleOp::getVCETripleAttrName(),
3436 result.attributes))
3437 return failure();
3438 }
3439
3440 if (parser.parseOptionalAttrDictWithKeyword(result.attributes) ||
3441 parser.parseRegion(*body, /*arguments=*/{}))
3442 return failure();
3443
3444 // Make sure we have at least one block.
3445 if (body->empty())
3446 body->push_back(new Block());
3447
3448 return success();
3449}
3450
3451void spirv::ModuleOp::print(OpAsmPrinter &printer) {
3452 if (std::optional<StringRef> name = getName()) {
3453 printer << ' ';
3454 printer.printSymbolName(*name);
3455 }
3456
3457 SmallVector<StringRef, 2> elidedAttrs;
3458
3459 printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " "
3460 << spirv::stringifyMemoryModel(getMemoryModel());
3461 auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
3462 auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
3463 elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
3464 mlir::SymbolTable::getSymbolAttrName()});
3465
3466 if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) {
3467 printer << " requires " << *triple;
3468 elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName());
3469 }
3470
3471 printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
3472 printer << ' ';
3473 printer.printRegion(getRegion());
3474}
3475
3476LogicalResult spirv::ModuleOp::verifyRegions() {
3477 Dialect *dialect = (*this)->getDialect();
3478 DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
3479 entryPoints;
3480 mlir::SymbolTable table(*this);
3481
3482 for (auto &op : *getBody()) {
3483 if (op.getDialect() != dialect)
3484 return op.emitError("'spirv.module' can only contain spirv.* ops");
3485
3486 // For EntryPoint op, check that the function and execution model is not
3487 // duplicated in EntryPointOps. Also verify that the interface specified
3488 // comes from globalVariables here to make this check cheaper.
3489 if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) {
3490 auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn());
3491 if (!funcOp) {
3492 return entryPointOp.emitError("function '")
3493 << entryPointOp.getFn() << "' not found in 'spirv.module'";
3494 }
3495 if (auto interface = entryPointOp.getInterface()) {
3496 for (Attribute varRef : interface) {
3497 auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>();
3498 if (!varSymRef) {
3499 return entryPointOp.emitError(
3500 "expected symbol reference for interface "
3501 "specification instead of '")
3502 << varRef;
3503 }
3504 auto variableOp =
3505 table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue());
3506 if (!variableOp) {
3507 return entryPointOp.emitError("expected spirv.GlobalVariable "
3508 "symbol reference instead of'")
3509 << varSymRef << "'";
3510 }
3511 }
3512 }
3513
3514 auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>(
3515 funcOp, entryPointOp.getExecutionModel());
3516 auto entryPtIt = entryPoints.find(key);
3517 if (entryPtIt != entryPoints.end()) {
3518 return entryPointOp.emitError("duplicate of a previous EntryPointOp");
3519 }
3520 entryPoints[key] = entryPointOp;
3521 } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) {
3522 if (funcOp.isExternal())
3523 return op.emitError("'spirv.module' cannot contain external functions");
3524
3525 // TODO: move this check to spirv.func.
3526 for (auto &block : funcOp)
3527 for (auto &op : block) {
3528 if (op.getDialect() != dialect)
3529 return op.emitError(
3530 "functions in 'spirv.module' can only contain spirv.* ops");
3531 }
3532 }
3533 }
3534
3535 return success();
3536}
3537
3538//===----------------------------------------------------------------------===//
3539// spirv.mlir.referenceof
3540//===----------------------------------------------------------------------===//
3541
3542LogicalResult spirv::ReferenceOfOp::verify() {
3543 auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
3544 (*this)->getParentOp(), getSpecConstAttr());
3545 Type constType;
3546
3547 auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
3548 if (specConstOp)
3549 constType = specConstOp.getDefaultValue().getType();
3550
3551 auto specConstCompositeOp =
3552 dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
3553 if (specConstCompositeOp)
3554 constType = specConstCompositeOp.getType();
3555
3556 if (!specConstOp && !specConstCompositeOp)
3557 return emitOpError(
3558 "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol");
3559
3560 if (getReference().getType() != constType)
3561 return emitOpError("result type mismatch with the referenced "
3562 "specialization constant's type");
3563
3564 return success();
3565}
3566
3567//===----------------------------------------------------------------------===//
3568// spirv.Return
3569//===----------------------------------------------------------------------===//
3570
3571LogicalResult spirv::ReturnOp::verify() {
3572 // Verification is performed in spirv.func op.
3573 return success();
3574}
3575
3576//===----------------------------------------------------------------------===//
3577// spirv.ReturnValue
3578//===----------------------------------------------------------------------===//
3579
3580LogicalResult spirv::ReturnValueOp::verify() {
3581 // Verification is performed in spirv.func op.
3582 return success();
3583}
3584
3585//===----------------------------------------------------------------------===//
3586// spirv.Select
3587//===----------------------------------------------------------------------===//
3588
3589LogicalResult spirv::SelectOp::verify() {
3590 if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) {
3591 auto resultVectorTy = getResult().getType().dyn_cast<VectorType>();
3592 if (!resultVectorTy) {
3593 return emitOpError("result expected to be of vector type when "
3594 "condition is of vector type");
3595 }
3596 if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) {
3597 return emitOpError("result should have the same number of elements as "
3598 "the condition when condition is of vector type");
3599 }
3600 }
3601 return success();
3602}
3603
3604//===----------------------------------------------------------------------===//
3605// spirv.mlir.selection
3606//===----------------------------------------------------------------------===//
3607
3608ParseResult spirv::SelectionOp::parse(OpAsmParser &parser,
3609 OperationState &result) {
3610 if (parseControlAttribute<spirv::SelectionControlAttr,
3611 spirv::SelectionControl>(parser, result))
3612 return failure();
3613 return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
3614}
3615
3616void spirv::SelectionOp::print(OpAsmPrinter &printer) {
3617 auto control = getSelectionControl();
3618 if (control != spirv::SelectionControl::None)
3619 printer << " control(" << spirv::stringifySelectionControl(control) << ")";
3620 printer << ' ';
3621 printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3622 /*printBlockTerminators=*/true);
3623}
3624
3625LogicalResult spirv::SelectionOp::verifyRegions() {
3626 auto *op = getOperation();
3627
3628 // We need to verify that the blocks follow the following layout:
3629 //
3630 // +--------------+
3631 // | header block |
3632 // +--------------+
3633 // / | \
3634 // ...
3635 //
3636 //
3637 // +---------+ +---------+ +---------+
3638 // | case #0 | | case #1 | | case #2 | ...
3639 // +---------+ +---------+ +---------+
3640 //
3641 //
3642 // ...
3643 // \ | /
3644 // v
3645 // +-------------+
3646 // | merge block |
3647 // +-------------+
3648
3649 auto &region = op->getRegion(0);
3650 // Allow empty region as a degenerated case, which can come from
3651 // optimizations.
3652 if (region.empty())
3653 return success();
3654
3655 // The last block is the merge block.
3656 if (!isMergeBlock(region.back()))
3657 return emitOpError("last block must be the merge block with only one "
3658 "'spirv.mlir.merge' op");
3659
3660 if (std::next(region.begin()) == region.end())
3661 return emitOpError("must have a selection header block");
3662
3663 return success();
3664}
3665
3666Block *spirv::SelectionOp::getHeaderBlock() {
3667 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3667, __extension__
__PRETTY_FUNCTION__))
;
3668 // The first block is the loop header block.
3669 return &getBody().front();
3670}
3671
3672Block *spirv::SelectionOp::getMergeBlock() {
3673 assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!"
) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3673, __extension__
__PRETTY_FUNCTION__))
;
3674 // The last block is the loop merge block.
3675 return &getBody().back();
3676}
3677
3678void spirv::SelectionOp::addMergeBlock() {
3679 assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist"
) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3679, __extension__
__PRETTY_FUNCTION__))
;
3680 auto *mergeBlock = new Block();
3681 getBody().push_back(mergeBlock);
3682 OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock);
3683
3684 // Add a spirv.mlir.merge op into the merge block.
3685 builder.create<spirv::MergeOp>(getLoc());
3686}
3687
3688spirv::SelectionOp spirv::SelectionOp::createIfThen(
3689 Location loc, Value condition,
3690 function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) {
3691 auto selectionOp =
3692 builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None);
3693
3694 selectionOp.addMergeBlock();
3695 Block *mergeBlock = selectionOp.getMergeBlock();
3696 Block *thenBlock = nullptr;
3697
3698 // Build the "then" block.
3699 {
3700 OpBuilder::InsertionGuard guard(builder);
3701 thenBlock = builder.createBlock(mergeBlock);
3702 thenBody(builder);
3703 builder.create<spirv::BranchOp>(loc, mergeBlock);
3704 }
3705
3706 // Build the header block.
3707 {
3708 OpBuilder::InsertionGuard guard(builder);
3709 builder.createBlock(thenBlock);
3710 builder.create<spirv::BranchConditionalOp>(
3711 loc, condition, thenBlock,
3712 /*trueArguments=*/ArrayRef<Value>(), mergeBlock,
3713 /*falseArguments=*/ArrayRef<Value>());
3714 }
3715
3716 return selectionOp;
3717}
3718
3719//===----------------------------------------------------------------------===//
3720// spirv.SpecConstant
3721//===----------------------------------------------------------------------===//
3722
3723ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser,
3724 OperationState &result) {
3725 StringAttr nameAttr;
3726 Attribute valueAttr;
3727
3728 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
3729 result.attributes))
3730 return failure();
3731
3732 // Parse optional spec_id.
3733 if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) {
3734 IntegerAttr specIdAttr;
3735 if (parser.parseLParen() ||
3736 parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) ||
3737 parser.parseRParen())
3738 return failure();
3739 }
3740
3741 if (parser.parseEqual() ||
3742 parser.parseAttribute(valueAttr, kDefaultValueAttrName,
3743 result.attributes))
3744 return failure();
3745
3746 return success();
3747}
3748
3749void spirv::SpecConstantOp::print(OpAsmPrinter &printer) {
3750 printer << ' ';
3751 printer.printSymbolName(getSymName());
3752 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3753 printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')';
3754 printer << " = " << getDefaultValue();
3755}
3756
3757LogicalResult spirv::SpecConstantOp::verify() {
3758 if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName))
3759 if (specID.getValue().isNegative())
3760 return emitOpError("SpecId cannot be negative");
3761
3762 auto value = getDefaultValue();
3763 if (value.isa<IntegerAttr, FloatAttr>()) {
3764 // Make sure bitwidth is allowed.
3765 if (!value.getType().isa<spirv::SPIRVType>())
3766 return emitOpError("default value bitwidth disallowed");
3767 return success();
3768 }
3769 return emitOpError(
3770 "default value can only be a bool, integer, or float scalar");
3771}
3772
3773//===----------------------------------------------------------------------===//
3774// spirv.StoreOp
3775//===----------------------------------------------------------------------===//
3776
3777ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) {
3778 // Parse the storage class specification
3779 spirv::StorageClass storageClass;
3780 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo;
3781 auto loc = parser.getCurrentLocation();
3782 Type elementType;
3783 if (parseEnumStrAttr(storageClass, parser) ||
3784 parser.parseOperandList(operandInfo, 2) ||
3785 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3786 parser.parseType(elementType)) {
3787 return failure();
3788 }
3789
3790 auto ptrType = spirv::PointerType::get(elementType, storageClass);
3791 if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc,
3792 result.operands)) {
3793 return failure();
3794 }
3795 return success();
3796}
3797
3798void spirv::StoreOp::print(OpAsmPrinter &printer) {
3799 SmallVector<StringRef, 4> elidedAttrs;
3800 StringRef sc = stringifyStorageClass(
3801 getPtr().getType().cast<spirv::PointerType>().getStorageClass());
3802 printer << " \"" << sc << "\" " << getPtr() << ", " << getValue();
3803
3804 printMemoryAccessAttribute(*this, printer, elidedAttrs);
3805
3806 printer << " : " << getValue().getType();
3807 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
3808}
3809
3810LogicalResult spirv::StoreOp::verify() {
3811 // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an
3812 // OpTypePointer whose Type operand is the same as the type of Object."
3813 if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue())))
3814 return failure();
3815 return verifyMemoryAccessAttribute(*this);
3816}
3817
3818//===----------------------------------------------------------------------===//
3819// spirv.Unreachable
3820//===----------------------------------------------------------------------===//
3821
3822LogicalResult spirv::UnreachableOp::verify() {
3823 auto *block = (*this)->getBlock();
3824 // Fast track: if this is in entry block, its invalid. Otherwise, if no
3825 // predecessors, it's valid.
3826 if (block->isEntryBlock())
3827 return emitOpError("cannot be used in reachable block");
3828 if (block->hasNoPredecessors())
3829 return success();
3830
3831 // TODO: further verification needs to analyze reachability from
3832 // the entry block.
3833
3834 return success();
3835}
3836
3837//===----------------------------------------------------------------------===//
3838// spirv.Variable
3839//===----------------------------------------------------------------------===//
3840
3841ParseResult spirv::VariableOp::parse(OpAsmParser &parser,
3842 OperationState &result) {
3843 // Parse optional initializer
3844 std::optional<OpAsmParser::UnresolvedOperand> initInfo;
3845 if (succeeded(parser.parseOptionalKeyword("init"))) {
3846 initInfo = OpAsmParser::UnresolvedOperand();
3847 if (parser.parseLParen() || parser.parseOperand(*initInfo) ||
3848 parser.parseRParen())
3849 return failure();
3850 }
3851
3852 if (parseVariableDecorations(parser, result)) {
3853 return failure();
3854 }
3855
3856 // Parse result pointer type
3857 Type type;
3858 if (parser.parseColon())
3859 return failure();
3860 auto loc = parser.getCurrentLocation();
3861 if (parser.parseType(type))
3862 return failure();
3863
3864 auto ptrType = type.dyn_cast<spirv::PointerType>();
3865 if (!ptrType)
3866 return parser.emitError(loc, "expected spirv.ptr type");
3867 result.addTypes(ptrType);
3868
3869 // Resolve the initializer operand
3870 if (initInfo) {
3871 if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(),
3872 result.operands))
3873 return failure();
3874 }
3875
3876 auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>(
3877 ptrType.getStorageClass());
3878 result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr);
3879
3880 return success();
3881}
3882
3883void spirv::VariableOp::print(OpAsmPrinter &printer) {
3884 SmallVector<StringRef, 4> elidedAttrs{
3885 spirv::attributeName<spirv::StorageClass>()};
3886 // Print optional initializer
3887 if (getNumOperands() != 0)
3888 printer << " init(" << getInitializer() << ")";
3889
3890 printVariableDecorations(*this, printer, elidedAttrs);
3891 printer << " : " << getType();
3892}
3893
3894LogicalResult spirv::VariableOp::verify() {
3895 // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the
3896 // object. It cannot be Generic. It must be the same as the Storage Class
3897 // operand of the Result Type."
3898 if (getStorageClass() != spirv::StorageClass::Function) {
3899 return emitOpError(
3900 "can only be used to model function-level variables. Use "
3901 "spirv.GlobalVariable for module-level variables.");
3902 }
3903
3904 auto pointerType = getPointer().getType().cast<spirv::PointerType>();
3905 if (getStorageClass() != pointerType.getStorageClass())
3906 return emitOpError(
3907 "storage class must match result pointer's storage class");
3908
3909 if (getNumOperands() != 0) {
3910 // SPIR-V spec: "Initializer must be an <id> from a constant instruction or
3911 // a global (module scope) OpVariable instruction".
3912 auto *initOp = getOperand(0).getDefiningOp();
3913 if (!initOp || !isa<spirv::ConstantOp, // for normal constant
3914 spirv::ReferenceOfOp, // for spec constant
3915 spirv::AddressOfOp>(initOp))
3916 return emitOpError("initializer must be the result of a "
3917 "constant or spirv.GlobalVariable op");
3918 }
3919
3920 // TODO: generate these strings using ODS.
3921 auto *op = getOperation();
3922 auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
3923 stringifyDecoration(spirv::Decoration::DescriptorSet));
3924 auto bindingName = llvm::convertToSnakeFromCamelCase(
3925 stringifyDecoration(spirv::Decoration::Binding));
3926 auto builtInName = llvm::convertToSnakeFromCamelCase(
3927 stringifyDecoration(spirv::Decoration::BuiltIn));
3928
3929 for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
3930 if (op->getAttr(attr))
3931 return emitOpError("cannot have '")
3932 << attr << "' attribute (only allowed in spirv.GlobalVariable)";
3933 }
3934
3935 return success();
3936}
3937
3938//===----------------------------------------------------------------------===//
3939// spirv.VectorShuffle
3940//===----------------------------------------------------------------------===//
3941
3942LogicalResult spirv::VectorShuffleOp::verify() {
3943 VectorType resultType = getType().cast<VectorType>();
3944
3945 size_t numResultElements = resultType.getNumElements();
3946 if (numResultElements != getComponents().size())
3947 return emitOpError("result type element count (")
3948 << numResultElements
3949 << ") mismatch with the number of component selectors ("
3950 << getComponents().size() << ")";
3951
3952 size_t totalSrcElements =
3953 getVector1().getType().cast<VectorType>().getNumElements() +
3954 getVector2().getType().cast<VectorType>().getNumElements();
3955
3956 for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) {
3957 uint32_t index = selector.getZExtValue();
3958 if (index >= totalSrcElements &&
3959 index != std::numeric_limits<uint32_t>().max())
3960 return emitOpError("component selector ")
3961 << index << " out of range: expected to be in [0, "
3962 << totalSrcElements << ") or 0xffffffff";
3963 }
3964 return success();
3965}
3966
3967//===----------------------------------------------------------------------===//
3968// spirv.NV.CooperativeMatrixLoad
3969//===----------------------------------------------------------------------===//
3970
3971ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser,
3972 OperationState &result) {
3973 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo;
3974 Type strideType = parser.getBuilder().getIntegerType(32);
3975 Type columnMajorType = parser.getBuilder().getIntegerType(1);
3976 Type ptrType;
3977 Type elementType;
3978 if (parser.parseOperandList(operandInfo, 3) ||
3979 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
3980 parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) {
3981 return failure();
3982 }
3983 if (parser.resolveOperands(operandInfo,
3984 {ptrType, strideType, columnMajorType},
3985 parser.getNameLoc(), result.operands)) {
3986 return failure();
3987 }
3988
3989 result.addTypes(elementType);
3990 return success();
3991}
3992
3993void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) {
3994 printer << " " << getPointer() << ", " << getStride() << ", "
3995 << getColumnmajor();
3996 // Print optional memory access attribute.
3997 if (auto memAccess = getMemoryAccess())
3998 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
3999 printer << " : " << getPointer().getType() << " as " << getType();
4000}
4001
4002static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer,
4003 Type coopMatrix) {
4004 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4005 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4006 return op->emitError(
4007 "Pointer must point to a scalar or vector type but provided ")
4008 << pointeeType;
4009 spirv::StorageClass storage =
4010 pointer.cast<spirv::PointerType>().getStorageClass();
4011 if (storage != spirv::StorageClass::Workgroup &&
4012 storage != spirv::StorageClass::StorageBuffer &&
4013 storage != spirv::StorageClass::PhysicalStorageBuffer)
4014 return op->emitError(
4015 "Pointer storage class must be Workgroup, StorageBuffer or "
4016 "PhysicalStorageBufferEXT but provided ")
4017 << stringifyStorageClass(storage);
4018 return success();
4019}
4020
4021LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() {
4022 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4023 getResult().getType());
4024}
4025
4026//===----------------------------------------------------------------------===//
4027// spirv.NV.CooperativeMatrixStore
4028//===----------------------------------------------------------------------===//
4029
4030ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser,
4031 OperationState &result) {
4032 SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo;
4033 Type strideType = parser.getBuilder().getIntegerType(32);
4034 Type columnMajorType = parser.getBuilder().getIntegerType(1);
4035 Type ptrType;
4036 Type elementType;
4037 if (parser.parseOperandList(operandInfo, 4) ||
4038 parseMemoryAccessAttributes(parser, result) || parser.parseColon() ||
4039 parser.parseType(ptrType) || parser.parseComma() ||
4040 parser.parseType(elementType)) {
4041 return failure();
4042 }
4043 if (parser.resolveOperands(
4044 operandInfo, {ptrType, elementType, strideType, columnMajorType},
4045 parser.getNameLoc(), result.operands)) {
4046 return failure();
4047 }
4048
4049 return success();
4050}
4051
4052void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) {
4053 printer << " " << getPointer() << ", " << getObject() << ", " << getStride()
4054 << ", " << getColumnmajor();
4055 // Print optional memory access attribute.
4056 if (auto memAccess = getMemoryAccess())
4057 printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]";
4058 printer << " : " << getPointer().getType() << ", " << getOperand(1).getType();
4059}
4060
4061LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() {
4062 return verifyPointerAndCoopMatrixType(*this, getPointer().getType(),
4063 getObject().getType());
4064}
4065
4066//===----------------------------------------------------------------------===//
4067// spirv.NV.CooperativeMatrixMulAdd
4068//===----------------------------------------------------------------------===//
4069
4070static LogicalResult
4071verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) {
4072 if (op.getC().getType() != op.getResult().getType())
4073 return op.emitOpError("result and third operand must have the same type");
4074 auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>();
4075 auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>();
4076 auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>();
4077 auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>();
4078 if (typeA.getRows() != typeR.getRows() ||
4079 typeA.getColumns() != typeB.getRows() ||
4080 typeB.getColumns() != typeR.getColumns())
4081 return op.emitOpError("matrix size must match");
4082 if (typeR.getScope() != typeA.getScope() ||
4083 typeR.getScope() != typeB.getScope() ||
4084 typeR.getScope() != typeC.getScope())
4085 return op.emitOpError("matrix scope must match");
4086 auto elementTypeA = typeA.getElementType();
4087 auto elementTypeB = typeB.getElementType();
4088 if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) {
4089 if (elementTypeA.cast<IntegerType>().getWidth() !=
4090 elementTypeB.cast<IntegerType>().getWidth())
4091 return op.emitOpError(
4092 "matrix A and B integer element types must be the same bit width");
4093 } else if (elementTypeA != elementTypeB) {
4094 return op.emitOpError(
4095 "matrix A and B non-integer element types must match");
4096 }
4097 if (typeR.getElementType() != typeC.getElementType())
4098 return op.emitOpError("matrix accumulator element type must match");
4099 return success();
4100}
4101
4102LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() {
4103 return verifyCoopMatrixMulAdd(*this);
4104}
4105
4106static LogicalResult
4107verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) {
4108 Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType();
4109 if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>())
4110 return op->emitError(
4111 "Pointer must point to a scalar or vector type but provided ")
4112 << pointeeType;
4113 spirv::StorageClass storage =
4114 pointer.cast<spirv::PointerType>().getStorageClass();
4115 if (storage != spirv::StorageClass::Workgroup &&
4116 storage != spirv::StorageClass::CrossWorkgroup &&
4117 storage != spirv::StorageClass::UniformConstant &&
4118 storage != spirv::StorageClass::Generic)
4119 return op->emitError("Pointer storage class must be Workgroup or "
4120 "CrossWorkgroup but provided ")
4121 << stringifyStorageClass(storage);
4122 return success();
4123}
4124
4125//===----------------------------------------------------------------------===//
4126// spirv.INTEL.JointMatrixLoad
4127//===----------------------------------------------------------------------===//
4128
4129LogicalResult spirv::INTELJointMatrixLoadOp::verify() {
4130 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4131 getResult().getType());
4132}
4133
4134//===----------------------------------------------------------------------===//
4135// spirv.INTEL.JointMatrixStore
4136//===----------------------------------------------------------------------===//
4137
4138LogicalResult spirv::INTELJointMatrixStoreOp::verify() {
4139 return verifyPointerAndJointMatrixType(*this, getPointer().getType(),
4140 getObject().getType());
4141}
4142
4143//===----------------------------------------------------------------------===//
4144// spirv.INTEL.JointMatrixMad
4145//===----------------------------------------------------------------------===//
4146
4147static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) {
4148 if (op.getC().getType() != op.getResult().getType())
4149 return op.emitOpError("result and third operand must have the same type");
4150 auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>();
4151 auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>();
4152 auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>();
4153 auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>();
4154 if (typeA.getRows() != typeR.getRows() ||
4155 typeA.getColumns() != typeB.getRows() ||
4156 typeB.getColumns() != typeR.getColumns())
4157 return op.emitOpError("matrix size must match");
4158 if (typeR.getScope() != typeA.getScope() ||
4159 typeR.getScope() != typeB.getScope() ||
4160 typeR.getScope() != typeC.getScope())
4161 return op.emitOpError("matrix scope must match");
4162 if (typeA.getElementType() != typeB.getElementType() ||
4163 typeR.getElementType() != typeC.getElementType())
4164 return op.emitOpError("matrix element type must match");
4165 return success();
4166}
4167
4168LogicalResult spirv::INTELJointMatrixMadOp::verify() {
4169 return verifyJointMatrixMad(*this);
4170}
4171
4172//===----------------------------------------------------------------------===//
4173// spirv.MatrixTimesScalar
4174//===----------------------------------------------------------------------===//
4175
4176LogicalResult spirv::MatrixTimesScalarOp::verify() {
4177 if (auto inputCoopmat =
4178 getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) {
4179 if (inputCoopmat.getElementType() != getScalar().getType())
4180 return emitError("input matrix components' type and scaling value must "
4181 "have the same type");
4182 return success();
4183 }
4184
4185 // Check that the scalar type is the same as the matrix element type.
4186 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4187 if (getScalar().getType() != inputMatrix.getElementType())
4188 return emitError("input matrix components' type and scaling value must "
4189 "have the same type");
4190
4191 return success();
4192}
4193
4194//===----------------------------------------------------------------------===//
4195// spirv.CopyMemory
4196//===----------------------------------------------------------------------===//
4197
4198void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) {
4199 printer << ' ';
4200
4201 StringRef targetStorageClass = stringifyStorageClass(
4202 getTarget().getType().cast<spirv::PointerType>().getStorageClass());
4203 printer << " \"" << targetStorageClass << "\" " << getTarget() << ", ";
4204
4205 StringRef sourceStorageClass = stringifyStorageClass(
4206 getSource().getType().cast<spirv::PointerType>().getStorageClass());
4207 printer << " \"" << sourceStorageClass << "\" " << getSource();
4208
4209 SmallVector<StringRef, 4> elidedAttrs;
4210 printMemoryAccessAttribute(*this, printer, elidedAttrs);
4211 printSourceMemoryAccessAttribute(*this, printer, elidedAttrs,
4212 getSourceMemoryAccess(),
4213 getSourceAlignment());
4214
4215 printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
4216
4217 Type pointeeType =
4218 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4219 printer << " : " << pointeeType;
4220}
4221
4222ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser,
4223 OperationState &result) {
4224 spirv::StorageClass targetStorageClass;
4225 OpAsmParser::UnresolvedOperand targetPtrInfo;
4226
4227 spirv::StorageClass sourceStorageClass;
4228 OpAsmParser::UnresolvedOperand sourcePtrInfo;
4229
4230 Type elementType;
4231
4232 if (parseEnumStrAttr(targetStorageClass, parser) ||
4233 parser.parseOperand(targetPtrInfo) || parser.parseComma() ||
4234 parseEnumStrAttr(sourceStorageClass, parser) ||
4235 parser.parseOperand(sourcePtrInfo) ||
4236 parseMemoryAccessAttributes(parser, result)) {
4237 return failure();
4238 }
4239
4240 if (!parser.parseOptionalComma()) {
4241 // Parse 2nd memory access attributes.
4242 if (parseSourceMemoryAccessAttributes(parser, result)) {
4243 return failure();
4244 }
4245 }
4246
4247 if (parser.parseColon() || parser.parseType(elementType))
4248 return failure();
4249
4250 if (parser.parseOptionalAttrDict(result.attributes))
4251 return failure();
4252
4253 auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass);
4254 auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass);
4255
4256 if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) ||
4257 parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) {
4258 return failure();
4259 }
4260
4261 return success();
4262}
4263
4264LogicalResult spirv::CopyMemoryOp::verify() {
4265 Type targetType =
4266 getTarget().getType().cast<spirv::PointerType>().getPointeeType();
4267
4268 Type sourceType =
4269 getSource().getType().cast<spirv::PointerType>().getPointeeType();
4270
4271 if (targetType != sourceType)
4272 return emitOpError("both operands must be pointers to the same type");
4273
4274 if (failed(verifyMemoryAccessAttribute(*this)))
4275 return failure();
4276
4277 // TODO - According to the spec:
4278 //
4279 // If two masks are present, the first applies to Target and cannot include
4280 // MakePointerVisible, and the second applies to Source and cannot include
4281 // MakePointerAvailable.
4282 //
4283 // Add such verification here.
4284
4285 return verifySourceMemoryAccessAttribute(*this);
4286}
4287
4288//===----------------------------------------------------------------------===//
4289// spirv.Transpose
4290//===----------------------------------------------------------------------===//
4291
4292LogicalResult spirv::TransposeOp::verify() {
4293 auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>();
4294 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4295
4296 // Verify that the input and output matrices have correct shapes.
4297 if (inputMatrix.getNumRows() != resultMatrix.getNumColumns())
4298 return emitError("input matrix rows count must be equal to "
4299 "output matrix columns count");
4300
4301 if (inputMatrix.getNumColumns() != resultMatrix.getNumRows())
4302 return emitError("input matrix columns count must be equal to "
4303 "output matrix rows count");
4304
4305 // Verify that the input and output matrices have the same component type
4306 if (inputMatrix.getElementType() != resultMatrix.getElementType())
4307 return emitError("input and output matrices must have the same "
4308 "component type");
4309
4310 return success();
4311}
4312
4313//===----------------------------------------------------------------------===//
4314// spirv.MatrixTimesMatrix
4315//===----------------------------------------------------------------------===//
4316
4317LogicalResult spirv::MatrixTimesMatrixOp::verify() {
4318 auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>();
4319 auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>();
4320 auto resultMatrix = getResult().getType().cast<spirv::MatrixType>();
4321
4322 // left matrix columns' count and right matrix rows' count must be equal
4323 if (leftMatrix.getNumColumns() != rightMatrix.getNumRows())
4324 return emitError("left matrix columns' count must be equal to "
4325 "the right matrix rows' count");
4326
4327 // right and result matrices columns' count must be the same
4328 if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns())
4329 return emitError(
4330 "right and result matrices must have equal columns' count");
4331
4332 // right and result matrices component type must be the same
4333 if (rightMatrix.getElementType() != resultMatrix.getElementType())
4334 return emitError("right and result matrices' component type must"
4335 " be the same");
4336
4337 // left and result matrices component type must be the same
4338 if (leftMatrix.getElementType() != resultMatrix.getElementType())
4339 return emitError("left and result matrices' component type"
4340 " must be the same");
4341
4342 // left and result matrices rows count must be the same
4343 if (leftMatrix.getNumRows() != resultMatrix.getNumRows())
4344 return emitError("left and result matrices must have equal rows' count");
4345
4346 return success();
4347}
4348
4349//===----------------------------------------------------------------------===//
4350// spirv.SpecConstantComposite
4351//===----------------------------------------------------------------------===//
4352
4353ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser,
4354 OperationState &result) {
4355
4356 StringAttr compositeName;
4357 if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(),
4358 result.attributes))
4359 return failure();
4360
4361 if (parser.parseLParen())
4362 return failure();
4363
4364 SmallVector<Attribute, 4> constituents;
4365
4366 do {
4367 // The name of the constituent attribute isn't important
4368 const char *attrName = "spec_const";
4369 FlatSymbolRefAttr specConstRef;
4370 NamedAttrList attrs;
4371
4372 if (parser.parseAttribute(specConstRef, Type(), attrName, attrs))
4373 return failure();
4374
4375 constituents.push_back(specConstRef);
4376 } while (!parser.parseOptionalComma());
4377
4378 if (parser.parseRParen())
4379 return failure();
4380
4381 result.addAttribute(kCompositeSpecConstituentsName,
4382 parser.getBuilder().getArrayAttr(constituents));
4383
4384 Type type;
4385 if (parser.parseColonType(type))
4386 return failure();
4387
4388 result.addAttribute(kTypeAttrName, TypeAttr::get(type));
4389
4390 return success();
4391}
4392
4393void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) {
4394 printer << " ";
4395 printer.printSymbolName(getSymName());
4396 printer << " (";
4397 auto constituents = this->getConstituents().getValue();
4398
4399 if (!constituents.empty())
4400 llvm::interleaveComma(constituents, printer);
4401
4402 printer << ") : " << getType();
4403}
4404
4405LogicalResult spirv::SpecConstantCompositeOp::verify() {
4406 auto cType = getType().dyn_cast<spirv::CompositeType>();
4407 auto constituents = this->getConstituents().getValue();
4408
4409 if (!cType)
4410 return emitError("result type must be a composite type, but provided ")
4411 << getType();
4412
4413 if (cType.isa<spirv::CooperativeMatrixNVType>())
4414 return emitError("unsupported composite type ") << cType;
4415 if (cType.isa<spirv::JointMatrixINTELType>())
4416 return emitError("unsupported composite type ") << cType;
4417 if (constituents.size() != cType.getNumElements())
4418 return emitError("has incorrect number of operands: expected ")
4419 << cType.getNumElements() << ", but provided "
4420 << constituents.size();
4421
4422 for (auto index : llvm::seq<uint32_t>(0, constituents.size())) {
4423 auto constituent = constituents[index].cast<FlatSymbolRefAttr>();
4424
4425 auto constituentSpecConstOp =
4426 dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom(
4427 (*this)->getParentOp(), constituent.getAttr()));
4428
4429 if (constituentSpecConstOp.getDefaultValue().getType() !=
4430 cType.getElementType(index))
4431 return emitError("has incorrect types of operands: expected ")
4432 << cType.getElementType(index) << ", but provided "
4433 << constituentSpecConstOp.getDefaultValue().getType();
4434 }
4435
4436 return success();
4437}
4438
4439//===----------------------------------------------------------------------===//
4440// spirv.SpecConstantOperation
4441//===----------------------------------------------------------------------===//
4442
4443ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser,
4444 OperationState &result) {
4445 Region *body = result.addRegion();
4446
4447 if (parser.parseKeyword("wraps"))
4448 return failure();
4449
4450 body->push_back(new Block);
4451 Block &block = body->back();
4452 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
4453
4454 if (!wrappedOp)
4455 return failure();
4456
4457 OpBuilder builder(parser.getContext());
4458 builder.setInsertionPointToEnd(&block);
4459 builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0));
4460 result.location = wrappedOp->getLoc();
4461
4462 result.addTypes(wrappedOp->getResult(0).getType());
4463
4464 if (parser.parseOptionalAttrDict(result.attributes))
4465 return failure();
4466
4467 return success();
4468}
4469
4470void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) {
4471 printer << " wraps ";
4472 printer.printGenericOp(&getBody().front().front());
4473}
4474
4475LogicalResult spirv::SpecConstantOperationOp::verifyRegions() {
4476 Block &block = getRegion().getBlocks().front();
4477
4478 if (block.getOperations().size() != 2)
4479 return emitOpError("expected exactly 2 nested ops");
4480
4481 Operation &enclosedOp = block.getOperations().front();
4482
4483 if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>())
4484 return emitOpError("invalid enclosed op");
4485
4486 for (auto operand : enclosedOp.getOperands())
4487 if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp,
4488 spirv::SpecConstantOperationOp>(operand.getDefiningOp()))
4489 return emitOpError(
4490 "invalid operand, must be defined by a constant operation");
4491
4492 return success();
4493}
4494
4495//===----------------------------------------------------------------------===//
4496// spirv.GL.FrexpStruct
4497//===----------------------------------------------------------------------===//
4498
4499LogicalResult spirv::GLFrexpStructOp::verify() {
4500 spirv::StructType structTy =
4501 getResult().getType().dyn_cast<spirv::StructType>();
4502
4503 if (structTy.getNumElements() != 2)
4504 return emitError("result type must be a struct type with two memebers");
4505
4506 Type significandTy = structTy.getElementType(0);
4507 Type exponentTy = structTy.getElementType(1);
4508 VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>();
4509 IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>();
4510
4511 Type operandTy = getOperand().getType();
4512 VectorType operandVecTy = operandTy.dyn_cast<VectorType>();
4513 FloatType operandFTy = operandTy.dyn_cast<FloatType>();
4514
4515 if (significandTy != operandTy)
4516 return emitError("member zero of the resulting struct type must be the "
4517 "same type as the operand");
4518
4519 if (exponentVecTy) {
4520 IntegerType componentIntTy =
4521 exponentVecTy.getElementType().dyn_cast<IntegerType>();
4522 if (!componentIntTy || componentIntTy.getWidth() != 32)
4523 return emitError("member one of the resulting struct type must"
4524 "be a scalar or vector of 32 bit integer type");
4525 } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) {
4526 return emitError("member one of the resulting struct type "
4527 "must be a scalar or vector of 32 bit integer type");
4528 }
4529
4530 // Check that the two member types have the same number of components
4531 if (operandVecTy && exponentVecTy &&
4532 (exponentVecTy.getNumElements() == operandVecTy.getNumElements()))
4533 return success();
4534
4535 if (operandFTy && exponentIntTy)
4536 return success();
4537
4538 return emitError("member one of the resulting struct type must have the same "
4539 "number of components as the operand type");
4540}
4541
4542//===----------------------------------------------------------------------===//
4543// spirv.GL.Ldexp
4544//===----------------------------------------------------------------------===//
4545
4546LogicalResult spirv::GLLdexpOp::verify() {
4547 Type significandType = getX().getType();
4548 Type exponentType = getExp().getType();
4549
4550 if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>())
4551 return emitOpError("operands must both be scalars or vectors");
4552
4553 auto getNumElements = [](Type type) -> unsigned {
4554 if (auto vectorType = type.dyn_cast<VectorType>())
4555 return vectorType.getNumElements();
4556 return 1;
4557 };
4558
4559 if (getNumElements(significandType) != getNumElements(exponentType))
4560 return emitOpError("operands must have the same number of elements");
4561
4562 return success();
4563}
4564
4565//===----------------------------------------------------------------------===//
4566// spirv.ImageDrefGather
4567//===----------------------------------------------------------------------===//
4568
4569LogicalResult spirv::ImageDrefGatherOp::verify() {
4570 VectorType resultType = getResult().getType().cast<VectorType>();
4571 auto sampledImageType =
4572 getSampledimage().getType().cast<spirv::SampledImageType>();
4573 auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>();
4574
4575 if (resultType.getNumElements() != 4)
4576 return emitOpError("result type must be a vector of four components");
4577
4578 Type elementType = resultType.getElementType();
4579 Type sampledElementType = imageType.getElementType();
4580 if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType)
4581 return emitOpError(
4582 "the component type of result must be the same as sampled type of the "
4583 "underlying image type");
4584
4585 spirv::Dim imageDim = imageType.getDim();
4586 spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo();
4587
4588 if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube &&
4589 imageDim != spirv::Dim::Rect)
4590 return emitOpError(
4591 "the Dim operand of the underlying image type must be 2D, Cube, or "
4592 "Rect");
4593
4594 if (imageMS != spirv::ImageSamplingInfo::SingleSampled)
4595 return emitOpError("the MS operand of the underlying image type must be 0");
4596
4597 spirv::ImageOperandsAttr attr = getImageoperandsAttr();
4598 auto operandArguments = getOperandArguments();
4599
4600 return verifyImageOperands(*this, attr, operandArguments);
4601}
4602
4603//===----------------------------------------------------------------------===//
4604// spirv.ShiftLeftLogicalOp
4605//===----------------------------------------------------------------------===//
4606
4607LogicalResult spirv::ShiftLeftLogicalOp::verify() {
4608 return verifyShiftOp(*this);
4609}
4610
4611//===----------------------------------------------------------------------===//
4612// spirv.ShiftRightArithmeticOp
4613//===----------------------------------------------------------------------===//
4614
4615LogicalResult spirv::ShiftRightArithmeticOp::verify() {
4616 return verifyShiftOp(*this);
4617}
4618
4619//===----------------------------------------------------------------------===//
4620// spirv.ShiftRightLogicalOp
4621//===----------------------------------------------------------------------===//
4622
4623LogicalResult spirv::ShiftRightLogicalOp::verify() {
4624 return verifyShiftOp(*this);
4625}
4626
4627//===----------------------------------------------------------------------===//
4628// spirv.ImageQuerySize
4629//===----------------------------------------------------------------------===//
4630
4631LogicalResult spirv::ImageQuerySizeOp::verify() {
4632 spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>();
4633 Type resultType = getResult().getType();
4634
4635 spirv::Dim dim = imageType.getDim();
4636 spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
4637 spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
4638 switch (dim) {
4639 case spirv::Dim::Dim1D:
4640 case spirv::Dim::Dim2D:
4641 case spirv::Dim::Dim3D:
4642 case spirv::Dim::Cube:
4643 if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
4644 samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
4645 samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
4646 return emitError(
4647 "if Dim is 1D, 2D, 3D, or Cube, "
4648 "it must also have either an MS of 1 or a Sampled of 0 or 2");
4649 break;
4650 case spirv::Dim::Buffer:
4651 case spirv::Dim::Rect:
4652 break;
4653 default:
4654 return emitError("the Dim operand of the image type must "
4655 "be 1D, 2D, 3D, Buffer, Cube, or Rect");
4656 }
4657
4658 unsigned componentNumber = 0;
4659 switch (dim) {
4660 case spirv::Dim::Dim1D:
4661 case spirv::Dim::Buffer:
4662 componentNumber = 1;
4663 break;
4664 case spirv::Dim::Dim2D:
4665 case spirv::Dim::Cube:
4666 case spirv::Dim::Rect:
4667 componentNumber = 2;
4668 break;
4669 case spirv::Dim::Dim3D:
4670 componentNumber = 3;
4671 break;
4672 default:
4673 break;
4674 }
4675
4676 if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
4677 componentNumber += 1;
4678
4679 unsigned resultComponentNumber = 1;
4680 if (auto resultVectorType = resultType.dyn_cast<VectorType>())
4681 resultComponentNumber = resultVectorType.getNumElements();
4682
4683 if (componentNumber != resultComponentNumber)
4684 return emitError("expected the result to have ")
4685 << componentNumber << " component(s), but found "
4686 << resultComponentNumber << " component(s)";
4687
4688 return success();
4689}
4690
4691static ParseResult parsePtrAccessChainOpImpl(StringRef opName,
4692 OpAsmParser &parser,
4693 OperationState &state) {
4694 OpAsmParser::UnresolvedOperand ptrInfo;
4695 SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo;
4696 Type type;
4697 auto loc = parser.getCurrentLocation();
4698 SmallVector<Type, 4> indicesTypes;
4699
4700 if (parser.parseOperand(ptrInfo) ||
4701 parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) ||
4702 parser.parseColonType(type) ||
4703 parser.resolveOperand(ptrInfo, type, state.operands))
4704 return failure();
4705
4706 // Check that the provided indices list is not empty before parsing their
4707 // type list.
4708 if (indicesInfo.empty())
4709 return emitError(state.location) << opName << " expected element";
4710
4711 if (parser.parseComma() || parser.parseTypeList(indicesTypes))
4712 return failure();
4713
4714 // Check that the indices types list is not empty and that it has a one-to-one
4715 // mapping to the provided indices.
4716 if (indicesTypes.size() != indicesInfo.size())
4717 return emitError(state.location)
4718 << opName
4719 << " indices types' count must be equal to indices info count";
4720
4721 if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands))
4722 return failure();
4723
4724 auto resultType = getElementPtrType(
4725 type, llvm::ArrayRef(state.operands).drop_front(2), state.location);
4726 if (!resultType)
4727 return failure();
4728
4729 state.addTypes(resultType);
4730 return success();
4731}
4732
4733template <typename Op>
4734static auto concatElemAndIndices(Op op) {
4735 SmallVector<Value> ret(op.getIndices().size() + 1);
4736 ret[0] = op.getElement();
4737 llvm::copy(op.getIndices(), ret.begin() + 1);
4738 return ret;
4739}
4740
4741//===----------------------------------------------------------------------===//
4742// spirv.InBoundsPtrAccessChainOp
4743//===----------------------------------------------------------------------===//
4744
4745void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder,
4746 OperationState &state,
4747 Value basePtr, Value element,
4748 ValueRange indices) {
4749 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4750 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4750, __extension__
__PRETTY_FUNCTION__))
;
4751 build(builder, state, type, basePtr, element, indices);
4752}
4753
4754ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser,
4755 OperationState &result) {
4756 return parsePtrAccessChainOpImpl(
4757 spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result);
4758}
4759
4760void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) {
4761 printAccessChain(*this, concatElemAndIndices(*this), printer);
4762}
4763
4764LogicalResult spirv::InBoundsPtrAccessChainOp::verify() {
4765 return verifyAccessChain(*this, getIndices());
4766}
4767
4768//===----------------------------------------------------------------------===//
4769// spirv.PtrAccessChainOp
4770//===----------------------------------------------------------------------===//
4771
4772void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state,
4773 Value basePtr, Value element,
4774 ValueRange indices) {
4775 auto type = getElementPtrType(basePtr.getType(), indices, state.location);
4776 assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices"
) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4776, __extension__
__PRETTY_FUNCTION__))
;
4777 build(builder, state, type, basePtr, element, indices);
4778}
4779
4780ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser,
4781 OperationState &result) {
4782 return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(),
4783 parser, result);
4784}
4785
4786void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) {
4787 printAccessChain(*this, concatElemAndIndices(*this), printer);
4788}
4789
4790LogicalResult spirv::PtrAccessChainOp::verify() {
4791 return verifyAccessChain(*this, getIndices());
4792}
4793
4794//===----------------------------------------------------------------------===//
4795// spirv.VectorTimesScalarOp
4796//===----------------------------------------------------------------------===//
4797
4798LogicalResult spirv::VectorTimesScalarOp::verify() {
4799 if (getVector().getType() != getType())
4800 return emitOpError("vector operand and result type mismatch");
4801 auto scalarType = getType().cast<VectorType>().getElementType();
4802 if (getScalar().getType() != scalarType)
4803 return emitOpError("scalar operand and result element type match");
4804 return success();
4805}
4806
4807//===----------------------------------------------------------------------===//
4808// Group ops
4809//===----------------------------------------------------------------------===//
4810
4811template <typename Op>
4812static LogicalResult verifyGroupOp(Op op) {
4813 spirv::Scope scope = op.getExecutionScope();
4814 if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup)
4815 return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'");
4816
4817 return success();
4818}
4819
4820LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); }
4821
4822LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); }
4823
4824LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); }
4825
4826LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); }
4827
4828LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); }
4829
4830LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); }
4831
4832LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); }
4833
4834LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); }
4835
4836LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); }
4837
4838LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); }
4839
4840//===----------------------------------------------------------------------===//
4841// Integer Dot Product ops
4842//===----------------------------------------------------------------------===//
4843
4844static LogicalResult verifyIntegerDotProduct(Operation *op) {
4845 assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&(static_cast <bool> (llvm::is_contained({2u, 3u}, op->
getNumOperands()) && "Not an integer dot product op?"
) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4846, __extension__
__PRETTY_FUNCTION__))
4846 "Not an integer dot product op?")(static_cast <bool> (llvm::is_contained({2u, 3u}, op->
getNumOperands()) && "Not an integer dot product op?"
) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4846, __extension__
__PRETTY_FUNCTION__))
;
4847 assert(op->getNumResults() == 1 && "Expected a single result")(static_cast <bool> (op->getNumResults() == 1 &&
"Expected a single result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"Expected a single result\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4847, __extension__
__PRETTY_FUNCTION__))
;
4848
4849 Type factorTy = op->getOperand(0).getType();
4850 if (op->getOperand(1).getType() != factorTy)
4851 return op->emitOpError("requires the same type for both vector operands");
4852
4853 unsigned expectedNumAttrs = 0;
4854 if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4855 ++expectedNumAttrs;
4856 auto packedVectorFormat =
4857 op->getAttr(kPackedVectorFormatAttrName)
4858 .dyn_cast_or_null<spirv::PackedVectorFormatAttr>();
4859 if (!packedVectorFormat)
4860 return op->emitOpError("requires Packed Vector Format attribute for "
4861 "integer vector operands");
4862
4863 assert(packedVectorFormat.getValue() ==(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__
__PRETTY_FUNCTION__))
4864 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__
__PRETTY_FUNCTION__))
4865 "Unknown Packed Vector Format")(static_cast <bool> (packedVectorFormat.getValue() == spirv
::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format"
) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\""
, "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__
__PRETTY_FUNCTION__))
;
4866 if (intTy.getWidth() != 32)
4867 return op->emitOpError(
4868 llvm::formatv("with specified Packed Vector Format ({0}) requires "
4869 "integer vector operands to be 32-bits wide",
4870 packedVectorFormat.getValue()));
4871 } else {
4872 if (op->hasAttr(kPackedVectorFormatAttrName))
4873 return op->emitOpError(llvm::formatv(
4874 "with invalid format attribute for vector operands of type '{0}'",
4875 factorTy));
4876 }
4877
4878 if (op->getAttrs().size() > expectedNumAttrs)
4879 return op->emitError(
4880 "op only supports the 'format' #spirv.packed_vector_format attribute");
4881
4882 Type resultTy = op->getResultTypes().front();
4883 bool hasAccumulator = op->getNumOperands() == 3;
4884 if (hasAccumulator && op->getOperand(2).getType() != resultTy)
4885 return op->emitOpError(
4886 "requires the same accumulator operand and result types");
4887
4888 unsigned factorBitWidth = getBitWidth(factorTy);
4889 unsigned resultBitWidth = getBitWidth(resultTy);
4890 if (factorBitWidth > resultBitWidth)
4891 return op->emitOpError(
4892 llvm::formatv("result type has insufficient bit-width ({0} bits) "
4893 "for the specified vector operand type ({1} bits)",
4894 resultBitWidth, factorBitWidth));
4895
4896 return success();
4897}
4898
4899static std::optional<spirv::Version> getIntegerDotProductMinVersion() {
4900 return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0.
4901}
4902
4903static std::optional<spirv::Version> getIntegerDotProductMaxVersion() {
4904 return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6.
4905}
4906
4907static SmallVector<ArrayRef<spirv::Extension>, 1>
4908getIntegerDotProductExtensions() {
4909 // Requires the SPV_KHR_integer_dot_product extension, specified either
4910 // explicitly or implied by target env's SPIR-V version >= 1.6.
4911 static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product;
4912 return {extension};
4913}
4914
4915static SmallVector<ArrayRef<spirv::Capability>, 1>
4916getIntegerDotProductCapabilities(Operation *op) {
4917 // Requires the the DotProduct capability and capabilities that depend on
4918 // exact op types.
4919 static const auto dotProductCap = spirv::Capability::DotProduct;
4920 static const auto dotProductInput4x8BitPackedCap =
4921 spirv::Capability::DotProductInput4x8BitPacked;
4922 static const auto dotProductInput4x8BitCap =
4923 spirv::Capability::DotProductInput4x8Bit;
4924 static const auto dotProductInputAllCap =
4925 spirv::Capability::DotProductInputAll;
4926
4927 SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap};
4928
4929 Type factorTy = op->getOperand(0).getType();
4930 if (auto intTy = factorTy.dyn_cast<IntegerType>()) {
4931 auto formatAttr = op->getAttr(kPackedVectorFormatAttrName)
4932 .cast<spirv::PackedVectorFormatAttr>();
4933 if (formatAttr.getValue() ==
4934 spirv::PackedVectorFormat::PackedVectorFormat4x8Bit)
4935 capabilities.push_back(dotProductInput4x8BitPackedCap);
4936
4937 return capabilities;
4938 }
4939
4940 auto vecTy = factorTy.cast<VectorType>();
4941 if (vecTy.getElementTypeBitWidth() == 8) {
4942 capabilities.push_back(dotProductInput4x8BitCap);
4943 return capabilities;
4944 }
4945
4946 capabilities.push_back(dotProductInputAllCap);
4947 return capabilities;
4948}
4949
4950#define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \
4951 LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \
4952 SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \
4953 return getIntegerDotProductExtensions(); \
4954 } \
4955 SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \
4956 return getIntegerDotProductCapabilities(*this); \
4957 } \
4958 std::optional<spirv::Version> OpName::getMinVersion() { \
4959 return getIntegerDotProductMinVersion(); \
4960 } \
4961 std::optional<spirv::Version> OpName::getMaxVersion() { \
4962 return getIntegerDotProductMaxVersion(); \
4963 }
4964
4965SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp)
4966SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp)
4967SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp)
4968SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp)
4969SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp)
4970SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp)
4971
4972#undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP
4973
4974// TableGen'erated operation interfaces for querying versions, extensions, and
4975// capabilities.
4976#include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc"
4977
// TableGen'erated operation definitions.
4979#define GET_OP_CLASSES
4980#include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc"
4981
4982namespace mlir {
4983namespace spirv {
4984// TableGen'erated operation availability interface implementations.
4985#include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc"
4986} // namespace spirv
4987} // namespace mlir

/build/source/mlir/include/mlir/IR/Builders.h

1//===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_BUILDERS_H
10#define MLIR_IR_BUILDERS_H
11
12#include "mlir/IR/OpDefinition.h"
13#include "llvm/Support/Compiler.h"
14#include <optional>
15
16namespace mlir {
17
18class AffineExpr;
19class IRMapping;
20class UnknownLoc;
21class FileLineColLoc;
22class Type;
23class PrimitiveType;
24class IntegerType;
25class FloatType;
26class FunctionType;
27class IndexType;
28class MemRefType;
29class VectorType;
30class RankedTensorType;
31class UnrankedTensorType;
32class TupleType;
33class NoneType;
34class BoolAttr;
35class IntegerAttr;
36class FloatAttr;
37class StringAttr;
38class TypeAttr;
39class ArrayAttr;
40class SymbolRefAttr;
41class ElementsAttr;
42class DenseElementsAttr;
43class DenseIntElementsAttr;
44class AffineMapAttr;
45class AffineMap;
46class UnitAttr;
47
/// This class is a general helper class for creating context-global objects
/// like types, attributes, and affine expressions.
///
/// A Builder only stores the MLIRContext pointer it was constructed with;
/// the generic getType/getAttr helpers pass that context to the requested
/// class's static `get`.
class Builder {
public:
  /// Creates a builder bound to `context`.
  explicit Builder(MLIRContext *context) : context(context) {}
  /// Creates a builder bound to the context that owns `op`.
  explicit Builder(Operation *op) : Builder(op->getContext()) {}

  /// Returns the context this builder creates objects in.
  MLIRContext *getContext() const { return context; }

  // Locations.
  Location getUnknownLoc();
  /// Returns a location fusing `locs`, optionally tagged with `metadata`.
  Location getFusedLoc(ArrayRef<Location> locs,
                       Attribute metadata = Attribute());

  // Types.
  FloatType getFloat8E5M2Type();
  FloatType getFloat8E4M3FNType();
  FloatType getFloat8E5M2FNUZType();
  FloatType getFloat8E4M3FNUZType();
  FloatType getFloat8E4M3B11FNUZType();
  FloatType getBF16Type();
  FloatType getF16Type();
  FloatType getF32Type();
  FloatType getF64Type();
  FloatType getF80Type();
  FloatType getF128Type();

  IndexType getIndexType();

  IntegerType getI1Type();
  IntegerType getI2Type();
  IntegerType getI4Type();
  IntegerType getI8Type();
  IntegerType getI16Type();
  IntegerType getI32Type();
  IntegerType getI64Type();
  IntegerType getIntegerType(unsigned width);
  IntegerType getIntegerType(unsigned width, bool isSigned);
  FunctionType getFunctionType(TypeRange inputs, TypeRange results);
  TupleType getTupleType(TypeRange elementTypes);
  NoneType getNoneType();

  /// Get or construct an instance of the type `Ty` with provided arguments.
  /// Equivalent to `Ty::get(getContext(), args...)`.
  template <typename Ty, typename... Args>
  Ty getType(Args &&...args) {
    return Ty::get(context, std::forward<Args>(args)...);
  }

  /// Get or construct an instance of the attribute `Attr` with provided
  /// arguments. Equivalent to `Attr::get(getContext(), args...)`.
  ///
  /// The arguments are forwarded unchanged and are not inspected here, so
  /// callers must pass fully-initialized values. (NOTE(review): static
  /// analysis reported a caller forwarding an uninitialized enum value
  /// through this function; any fix belongs at that call site.)
  template <typename Attr, typename... Args>
  Attr getAttr(Args &&...args) {
    return Attr::get(context, std::forward<Args>(args)...);
  }

  // Attributes.
  NamedAttribute getNamedAttr(StringRef name, Attribute val);

  UnitAttr getUnitAttr();
  BoolAttr getBoolAttr(bool value);
  DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value);
  IntegerAttr getIntegerAttr(Type type, int64_t value);
  IntegerAttr getIntegerAttr(Type type, const APInt &value);
  FloatAttr getFloatAttr(Type type, double value);
  FloatAttr getFloatAttr(Type type, const APFloat &value);
  StringAttr getStringAttr(const Twine &bytes);
  ArrayAttr getArrayAttr(ArrayRef<Attribute> value);

  // Returns a 0-valued attribute of the given `type`. This function only
  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
  // ranked tensor of them. Returns null attribute otherwise.
  TypedAttr getZeroAttr(Type type);

  // Convenience methods for fixed types.
  FloatAttr getF16FloatAttr(float value);
  FloatAttr getF32FloatAttr(float value);
  FloatAttr getF64FloatAttr(double value);

  IntegerAttr getI8IntegerAttr(int8_t value);
  IntegerAttr getI16IntegerAttr(int16_t value);
  IntegerAttr getI32IntegerAttr(int32_t value);
  IntegerAttr getI64IntegerAttr(int64_t value);
  IntegerAttr getIndexAttr(int64_t value);

  /// Signed and unsigned integer attribute getters.
  IntegerAttr getSI32IntegerAttr(int32_t value);
  IntegerAttr getUI32IntegerAttr(uint32_t value);

  /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty.
  DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values);
  DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values);
  DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values);
  DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values);

  /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty.
  /// These are generally preferable for representing general lists of integers
  /// as attributes.
  DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values);
  DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values);
  DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values);

  /// Tensor-typed DenseArrayAttr getters.
  DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef<bool> values);
  DenseI8ArrayAttr getDenseI8ArrayAttr(ArrayRef<int8_t> values);
  DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef<int16_t> values);
  DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef<int32_t> values);
  DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef<int64_t> values);
  DenseF32ArrayAttr getDenseF32ArrayAttr(ArrayRef<float> values);
  DenseF64ArrayAttr getDenseF64ArrayAttr(ArrayRef<double> values);

  /// ArrayAttr getters; each element of `values` becomes one attribute in
  /// the resulting array.
  ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
  ArrayAttr getBoolArrayAttr(ArrayRef<bool> values);
  ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
  ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
  ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values);
  ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
  ArrayAttr getF64ArrayAttr(ArrayRef<double> values);
  ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values);
  ArrayAttr getTypeArrayAttr(TypeRange values);

  // Affine expressions and affine maps.
  AffineExpr getAffineDimExpr(unsigned position);
  AffineExpr getAffineSymbolExpr(unsigned position);
  AffineExpr getAffineConstantExpr(int64_t constant);

  // Special cases of affine maps and integer sets
  /// Returns a zero result affine map with no dimensions or symbols: () -> ().
  AffineMap getEmptyAffineMap();
  /// Returns a single constant result affine map with 0 dimensions and 0
  /// symbols. One constant result: () -> (val).
  AffineMap getConstantAffineMap(int64_t val);
  // One-dimensional identity map: (i) -> (i).
  AffineMap getDimIdentityMap();
  // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2).
  AffineMap getMultiDimIdentityMap(unsigned rank);
  // One symbol identity map: ()[s] -> (s).
  AffineMap getSymbolIdentityMap();

  /// Returns a map that shifts its (single) input dimension by 'shift'.
  /// (d0) -> (d0 + shift)
  AffineMap getSingleDimShiftAffineMap(int64_t shift);

  /// Returns an affine map that is a translation (shift) of all result
  /// expressions in 'map' by 'shift'.
  /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2
  ///   returns:    (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2)
  AffineMap getShiftedAffineMap(AffineMap map, int64_t shift);

protected:
  /// The context passed at construction; used by getType/getAttr above.
  MLIRContext *context;
};
199
/// This class helps build Operations. Operations that are created are
/// automatically inserted at an insertion point. The builder is copyable.
///
/// The insertion point is a (block, iterator) pair: new operations are
/// inserted into `block` immediately before `insertPoint`. An optional
/// Listener can be attached to observe insertions and block creation.
class OpBuilder : public Builder {
public:
  struct Listener;

  /// Create a builder with the given context. No insertion point is set.
  explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr)
      : Builder(ctx), listener(listener) {}

  /// Create a builder and set the insertion point to the start of the region.
  /// If the region has no blocks, the insertion point is left unset.
  explicit OpBuilder(Region *region, Listener *listener = nullptr)
      : OpBuilder(region->getContext(), listener) {
    if (!region->empty())
      setInsertionPoint(&region->front(), region->front().begin());
  }
  explicit OpBuilder(Region &region, Listener *listener = nullptr)
      : OpBuilder(&region, listener) {}

  /// Create a builder and set insertion point to the given operation, which
  /// will cause subsequent insertions to go right before it.
  explicit OpBuilder(Operation *op, Listener *listener = nullptr)
      : OpBuilder(op->getContext(), listener) {
    setInsertionPoint(op);
  }

  /// Create a builder inserting into `block` before `insertPoint`. The context
  /// is taken from the block's parent region, so `block` must already be
  /// attached to a region.
  OpBuilder(Block *block, Block::iterator insertPoint,
            Listener *listener = nullptr)
      : OpBuilder(block->getParent()->getContext(), listener) {
    setInsertionPoint(block, insertPoint);
  }

  /// Create a builder and set the insertion point to before the first operation
  /// in the block but still inside the block.
  static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) {
    return OpBuilder(block, block->begin(), listener);
  }

  /// Create a builder and set the insertion point to after the last operation
  /// in the block but still inside the block.
  static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) {
    return OpBuilder(block, block->end(), listener);
  }

  /// Create a builder and set the insertion point to before the block
  /// terminator. The block must have a terminator (asserted).
  static OpBuilder atBlockTerminator(Block *block,
                                     Listener *listener = nullptr) {
    auto *terminator = block->getTerminator();
    assert(terminator != nullptr && "the block has no terminator");
    return OpBuilder(block, Block::iterator(terminator), listener);
  }

  //===--------------------------------------------------------------------===//
  // Listeners
  //===--------------------------------------------------------------------===//

  /// Base class for listeners.
  struct ListenerBase {
    /// The kind of listener.
    enum class Kind {
      /// OpBuilder::Listener or user-derived class.
      OpBuilderListener = 0,

      /// RewriterBase::Listener or user-derived class.
      RewriterBaseListener = 1
    };

    Kind getKind() const { return kind; }

  protected:
    ListenerBase(Kind kind) : kind(kind) {}

  private:
    /// Fixed at construction; lets code distinguish listener families.
    const Kind kind;
  };

  /// This class represents a listener that may be used to hook into various
  /// actions within an OpBuilder.
  struct Listener : public ListenerBase {
    Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {}

    virtual ~Listener() = default;

    /// Notification handler for when an operation is inserted into the builder.
    /// `op` is the operation that was inserted.
    virtual void notifyOperationInserted(Operation *op) {}

    /// Notification handler for when a block is created using the builder.
    /// `block` is the block that was created.
    virtual void notifyBlockCreated(Block *block) {}

  protected:
    /// Lets derived listener families set a different Kind.
    Listener(Kind kind) : ListenerBase(kind) {}
  };

  /// Sets the listener of this builder to the one provided.
  void setListener(Listener *newListener) { listener = newListener; }

  /// Returns the current listener of this builder, or nullptr if this builder
  /// doesn't have a listener.
  Listener *getListener() const { return listener; }

  //===--------------------------------------------------------------------===//
  // Insertion Point Management
  //===--------------------------------------------------------------------===//

  /// This class represents a saved insertion point.
  class InsertPoint {
  public:
    /// Creates a new insertion point which doesn't point to anything.
    InsertPoint() = default;

    /// Creates a new insertion point at the given location.
    InsertPoint(Block *insertBlock, Block::iterator insertPt)
        : block(insertBlock), point(insertPt) {}

    /// Returns true if this insert point is set (i.e. has a block).
    bool isSet() const { return (block != nullptr); }

    Block *getBlock() const { return block; }
    Block::iterator getPoint() const { return point; }

  private:
    Block *block = nullptr;
    Block::iterator point;
  };

  /// RAII guard to reset the insertion point of the builder when destroyed.
  class InsertionGuard {
  public:
    /// Saves `builder`'s current insertion point for restoration.
    InsertionGuard(OpBuilder &builder)
        : builder(&builder), ip(builder.saveInsertionPoint()) {}

    /// Restores the saved insertion point, unless this guard was moved from.
    ~InsertionGuard() {
      if (builder)
        builder->restoreInsertionPoint(ip);
    }

    InsertionGuard(const InsertionGuard &) = delete;
    InsertionGuard &operator=(const InsertionGuard &) = delete;

    /// Implement the move constructor to clear the builder field of `other`.
    /// That way it does not restore the insertion point upon destruction as
    /// that should be done exclusively by the just constructed InsertionGuard.
    InsertionGuard(InsertionGuard &&other) noexcept
        : builder(other.builder), ip(other.ip) {
      other.builder = nullptr;
    }

    InsertionGuard &operator=(InsertionGuard &&other) = delete;

  private:
    /// Builder to restore; null after this guard has been moved from.
    OpBuilder *builder;
    OpBuilder::InsertPoint ip;
  };

  /// Reset the insertion point to no location. Creating an operation without a
  /// set insertion point is an error, but this can still be useful when the
  /// current insertion point a builder refers to is being removed.
  void clearInsertionPoint() {
    this->block = nullptr;
    insertPoint = Block::iterator();
  }

  /// Return a saved insertion point.
  InsertPoint saveInsertionPoint() const {
    return InsertPoint(getInsertionBlock(), getInsertionPoint());
  }

  /// Restore the insert point to a previously saved point. An unset saved
  /// point clears the current insertion point.
  void restoreInsertionPoint(InsertPoint ip) {
    if (ip.isSet())
      setInsertionPoint(ip.getBlock(), ip.getPoint());
    else
      clearInsertionPoint();
  }

  /// Set the insertion point to the specified location.
  void setInsertionPoint(Block *block, Block::iterator insertPoint) {
    // TODO: check that insertPoint is in this rather than some other block.
    this->block = block;
    this->insertPoint = insertPoint;
  }

  /// Sets the insertion point to the specified operation, which will cause
  /// subsequent insertions to go right before it.
  void setInsertionPoint(Operation *op) {
    setInsertionPoint(op->getBlock(), Block::iterator(op));
  }

  /// Sets the insertion point to the node after the specified operation, which
  /// will cause subsequent insertions to go right after it.
  void setInsertionPointAfter(Operation *op) {
    setInsertionPoint(op->getBlock(), ++Block::iterator(op));
  }

  /// Sets the insertion point to the node after the specified value. If value
  /// has a defining operation, sets the insertion point to the node after such
  /// defining operation. This will cause subsequent insertions to go right
  /// after it. Otherwise, value is a BlockArgument. Sets the insertion point to
  /// the start of its block.
  void setInsertionPointAfterValue(Value val) {
    if (Operation *op = val.getDefiningOp()) {
      setInsertionPointAfter(op);
    } else {
      auto blockArg = val.cast<BlockArgument>();
      setInsertionPointToStart(blockArg.getOwner());
    }
  }

  /// Sets the insertion point to the start of the specified block.
  void setInsertionPointToStart(Block *block) {
    setInsertionPoint(block, block->begin());
  }

  /// Sets the insertion point to the end of the specified block.
  void setInsertionPointToEnd(Block *block) {
    setInsertionPoint(block, block->end());
  }

  /// Return the block the current insertion point belongs to.  Note that the
  /// insertion point is not necessarily the end of the block.
  Block *getInsertionBlock() const { return block; }

  /// Returns the current insertion point of the builder.
  Block::iterator getInsertionPoint() const { return insertPoint; }

  /// Returns the current block of the builder (same as getInsertionBlock()).
  Block *getBlock() const { return block; }

  //===--------------------------------------------------------------------===//
  // Block Creation
  //===--------------------------------------------------------------------===//

  /// Add new block with 'argTypes' arguments and set the insertion point to the
  /// end of it. The block is inserted at the provided insertion point of
  /// 'parent'. `locs` contains the locations of the inserted arguments, and
  /// should match the size of `argTypes`.
  Block *createBlock(Region *parent, Region::iterator insertPt = {},
                     TypeRange argTypes = std::nullopt,
                     ArrayRef<Location> locs = std::nullopt);

  /// Add new block with 'argTypes' arguments and set the insertion point to the
  /// end of it. The block is placed before 'insertBefore'. `locs` contains the
  /// locations of the inserted arguments, and should match the size of
  /// `argTypes`.
  Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt,
                     ArrayRef<Location> locs = std::nullopt);

  //===--------------------------------------------------------------------===//
  // Operation Creation
  //===--------------------------------------------------------------------===//

  /// Insert the given operation at the current insertion point and return it.
  Operation *insert(Operation *op);

  /// Creates an operation given the fields represented as an OperationState.
  Operation *create(const OperationState &state);

  /// Creates an operation with the given fields.
  Operation *create(Location loc, StringAttr opName, ValueRange operands,
                    TypeRange types = {},
                    ArrayRef<NamedAttribute> attributes = {},
                    BlockRange successors = {},
                    MutableArrayRef<std::unique_ptr<Region>> regions = {});

private:
  /// Helper for sanity checking preconditions for create* methods below.
  /// Aborts via report_fatal_error when `OpT` is not registered in `ctx`.
  template <typename OpT>
  RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) {
    std::optional<RegisteredOperationName> opName =
        RegisteredOperationName::lookup(OpT::getOperationName(), ctx);
    if (LLVM_UNLIKELY(!opName)) {
      llvm::report_fatal_error(
          "Building op `" + OpT::getOperationName() +
          "` but it isn't registered in this MLIRContext: the dialect may not "
          "be loaded or this operation isn't registered by the dialect. See "
          "also https://mlir.llvm.org/getting_started/Faq/"
          "#registered-loaded-dependent-whats-up-with-dialects-management");
    }
    return *opName;
  }

public:
  /// Create an operation of specific op type at the current insertion point.
  /// `args` are forwarded to `OpTy::build` after the builder and state.
  template <typename OpTy, typename... Args>
  OpTy create(Location location, Args &&...args) {
    OperationState state(location,
                         getCheckRegisteredInfo<OpTy>(location.getContext()));
    OpTy::build(*this, state, std::forward<Args>(args)...);
    auto *op = create(state);
    auto result = dyn_cast<OpTy>(op);
    assert(result && "builder didn't return the right type");
    return result;
  }

  /// Create an operation of specific op type at the current insertion point,
  /// and immediately try to fold it. This functions populates 'results' with
  /// the results after folding the operation.
  template <typename OpTy, typename... Args>
  void createOrFold(SmallVectorImpl<Value> &results, Location location,
                    Args &&...args) {
    // Create the operation without using 'create' as we don't want to
    // insert it yet.
    OperationState state(location,
                         getCheckRegisteredInfo<OpTy>(location.getContext()));
    OpTy::build(*this, state, std::forward<Args>(args)...);
    Operation *op = Operation::create(state);

    // Fold the operation. If successful destroy it, otherwise insert it.
    if (succeeded(tryFold(op, results)))
      op->destroy();
    else
      insert(op);
  }

  /// Overload to create or fold a single result operation.
  template <typename OpTy, typename... Args>
  std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value>
  createOrFold(Location location, Args &&...args) {
    SmallVector<Value, 1> results;
    createOrFold<OpTy>(results, location, std::forward<Args>(args)...);
    return results.front();
  }

  /// Overload to create or fold a zero result operation.
  template <typename OpTy, typename... Args>
  std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy>
  createOrFold(Location location, Args &&...args) {
    auto op = create<OpTy>(location, std::forward<Args>(args)...);
    SmallVector<Value, 0> unused;
    (void)tryFold(op.getOperation(), unused);

    // Folding cannot remove a zero-result operation, so for convenience we
    // continue to return it.
    return op;
  }

  /// Attempts to fold the given operation and places new results within
  /// 'results'. Returns success if the operation was folded, failure otherwise.
  /// Note: This function does not erase the operation on a successful fold.
  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);

  /// Creates a deep copy of the specified operation, remapping any operands
  /// that use values outside of the operation using the map that is provided
  /// ( leaving them alone if no entry is present).  Replaces references to
  /// cloned sub-operations to the corresponding operation that is copied,
  /// and adds those mappings to the map.
  Operation *clone(Operation &op, IRMapping &mapper);
  Operation *clone(Operation &op);

  /// Creates a deep copy of this operation but keep the operation regions
  /// empty. Operands are remapped using `mapper` (if present), and `mapper` is
  /// updated to contain the results. The copy is inserted at the current
  /// insertion point.
  Operation *cloneWithoutRegions(Operation &op, IRMapping &mapper) {
    return insert(op.cloneWithoutRegions(mapper));
  }
  Operation *cloneWithoutRegions(Operation &op) {
    return insert(op.cloneWithoutRegions());
  }
  template <typename OpT>
  OpT cloneWithoutRegions(OpT op) {
    return cast<OpT>(cloneWithoutRegions(*op.getOperation()));
  }

protected:
  /// The optional listener for events of this builder.
  Listener *listener;

private:
  /// The current block this builder is inserting into.
  Block *block = nullptr;
  /// The insertion point within the block that this builder is inserting
  /// before.
  Block::iterator insertPoint;
};
577
578} // namespace mlir
579
580#endif