File: | build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp |
Warning: | line 4209, column 24 2nd function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file defines the operations in the SPIR-V dialect. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" | |||
14 | ||||
15 | #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" | |||
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" | |||
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | |||
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" | |||
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" | |||
20 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" | |||
21 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" | |||
22 | #include "mlir/IR/Builders.h" | |||
23 | #include "mlir/IR/BuiltinTypes.h" | |||
24 | #include "mlir/IR/FunctionImplementation.h" | |||
25 | #include "mlir/IR/OpDefinition.h" | |||
26 | #include "mlir/IR/OpImplementation.h" | |||
27 | #include "mlir/IR/Operation.h" | |||
28 | #include "mlir/IR/TypeUtilities.h" | |||
29 | #include "mlir/Interfaces/CallInterfaces.h" | |||
30 | #include "llvm/ADT/APFloat.h" | |||
31 | #include "llvm/ADT/APInt.h" | |||
32 | #include "llvm/ADT/ArrayRef.h" | |||
33 | #include "llvm/ADT/STLExtras.h" | |||
34 | #include "llvm/ADT/StringExtras.h" | |||
35 | #include "llvm/Support/FormatVariadic.h" | |||
36 | #include <cassert> | |||
37 | #include <numeric> | |||
38 | ||||
39 | using namespace mlir; | |||
40 | ||||
// String keys under which op attributes are stored and looked up; several of
// them are used by the parse/print/verify helpers defined below.
41 | // TODO: generate these strings using ODS. | |||
42 | constexpr char kAlignmentAttrName[] = "alignment"; | |||
43 | constexpr char kBranchWeightAttrName[] = "branch_weights"; | |||
44 | constexpr char kCallee[] = "callee"; | |||
45 | constexpr char kClusterSize[] = "cluster_size"; | |||
46 | constexpr char kControl[] = "control"; | |||
47 | constexpr char kDefaultValueAttrName[] = "default_value"; | |||
48 | constexpr char kEqualSemanticsAttrName[] = "equal_semantics"; | |||
49 | constexpr char kExecutionScopeAttrName[] = "execution_scope"; | |||
50 | constexpr char kFnNameAttrName[] = "fn"; | |||
51 | constexpr char kGroupOperationAttrName[] = "group_operation"; | |||
52 | constexpr char kIndicesAttrName[] = "indices"; | |||
53 | constexpr char kInitializerAttrName[] = "initializer"; | |||
54 | constexpr char kInterfaceAttrName[] = "interface"; | |||
55 | constexpr char kMemoryAccessAttrName[] = "memory_access"; | |||
56 | constexpr char kMemoryScopeAttrName[] = "memory_scope"; | |||
57 | constexpr char kPackedVectorFormatAttrName[] = "format"; | |||
58 | constexpr char kSemanticsAttrName[] = "semantics"; | |||
59 | constexpr char kSourceAlignmentAttrName[] = "source_alignment"; | |||
60 | constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access"; | |||
61 | constexpr char kSpecIdAttrName[] = "spec_id"; | |||
62 | constexpr char kTypeAttrName[] = "type"; | |||
63 | constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics"; | |||
64 | constexpr char kValueAttrName[] = "value"; | |||
65 | constexpr char kValuesAttrName[] = "values"; | |||
66 | constexpr char kCompositeSpecConstituentsName[] = "constituents"; | |||
67 | ||||
68 | //===----------------------------------------------------------------------===// | |||
69 | // Common utility functions | |||
70 | //===----------------------------------------------------------------------===// | |||
71 | ||||
72 | static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, | |||
73 | OperationState &result) { | |||
74 | SmallVector<OpAsmParser::UnresolvedOperand, 2> ops; | |||
75 | Type type; | |||
76 | // If the operand list is in-between parentheses, then we have a generic form. | |||
77 | // (see the fallback in `printOneResultOp`). | |||
78 | SMLoc loc = parser.getCurrentLocation(); | |||
79 | if (!parser.parseOptionalLParen()) { | |||
80 | if (parser.parseOperandList(ops) || parser.parseRParen() || | |||
81 | parser.parseOptionalAttrDict(result.attributes) || | |||
82 | parser.parseColon() || parser.parseType(type)) | |||
83 | return failure(); | |||
84 | auto fnType = type.dyn_cast<FunctionType>(); | |||
85 | if (!fnType) { | |||
86 | parser.emitError(loc, "expected function type"); | |||
87 | return failure(); | |||
88 | } | |||
89 | if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) | |||
90 | return failure(); | |||
91 | result.addTypes(fnType.getResults()); | |||
92 | return success(); | |||
93 | } | |||
94 | return failure(parser.parseOperandList(ops) || | |||
95 | parser.parseOptionalAttrDict(result.attributes) || | |||
96 | parser.parseColonType(type) || | |||
97 | parser.resolveOperands(ops, type, result.operands) || | |||
98 | parser.addTypeToList(type, result.types)); | |||
99 | } | |||
100 | ||||
101 | static void printOneResultOp(Operation *op, OpAsmPrinter &p) { | |||
102 | assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 && "op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 102, __extension__ __PRETTY_FUNCTION__)); | |||
103 | ||||
104 | // If not all the operand and result types are the same, just use the | |||
105 | // generic assembly form to avoid omitting information in printing. | |||
106 | auto resultType = op->getResult(0).getType(); | |||
107 | if (llvm::any_of(op->getOperandTypes(), | |||
108 | [&](Type type) { return type != resultType; })) { | |||
109 | p.printGenericOp(op, /*printOpName=*/false); | |||
110 | return; | |||
111 | } | |||
112 | ||||
113 | p << ' '; | |||
114 | p.printOperands(op->getOperands()); | |||
115 | p.printOptionalAttrDict(op->getAttrs()); | |||
116 | // Now we can output only one type for all operands and the result. | |||
117 | p << " : " << resultType; | |||
118 | } | |||
119 | ||||
120 | /// Returns true if the given op is a function-like op or nested in a | |||
121 | /// function-like op without a module-like op in the middle. | |||
122 | static bool isNestedInFunctionOpInterface(Operation *op) { | |||
123 | if (!op) | |||
124 | return false; | |||
125 | if (op->hasTrait<OpTrait::SymbolTable>()) | |||
126 | return false; | |||
127 | if (isa<FunctionOpInterface>(op)) | |||
128 | return true; | |||
129 | return isNestedInFunctionOpInterface(op->getParentOp()); | |||
130 | } | |||
131 | ||||
132 | /// Returns true if the given op is an module-like op that maintains a symbol | |||
133 | /// table. | |||
134 | static bool isDirectInModuleLikeOp(Operation *op) { | |||
135 | return op && op->hasTrait<OpTrait::SymbolTable>(); | |||
136 | } | |||
137 | ||||
138 | static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { | |||
139 | auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op); | |||
140 | if (!constOp) { | |||
141 | return failure(); | |||
142 | } | |||
143 | auto valueAttr = constOp.getValue(); | |||
144 | auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>(); | |||
145 | if (!integerValueAttr) { | |||
146 | return failure(); | |||
147 | } | |||
148 | ||||
149 | if (integerValueAttr.getType().isSignlessInteger()) | |||
150 | value = integerValueAttr.getInt(); | |||
151 | else | |||
152 | value = integerValueAttr.getSInt(); | |||
153 | ||||
154 | return success(); | |||
155 | } | |||
156 | ||||
157 | template <typename Ty> | |||
158 | static ArrayAttr | |||
159 | getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, | |||
160 | function_ref<StringRef(Ty)> stringifyFn) { | |||
161 | if (enumValues.empty()) { | |||
162 | return nullptr; | |||
163 | } | |||
164 | SmallVector<StringRef, 1> enumValStrs; | |||
165 | enumValStrs.reserve(enumValues.size()); | |||
166 | for (auto val : enumValues) { | |||
167 | enumValStrs.emplace_back(stringifyFn(val)); | |||
168 | } | |||
169 | return builder.getStrArrayAttr(enumValStrs); | |||
170 | } | |||
171 | ||||
172 | /// Parses the next string attribute in `parser` as an enumerant of the given | |||
173 | /// `EnumClass`. | |||
174 | template <typename EnumClass> | |||
175 | static ParseResult | |||
176 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, | |||
177 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
178 | Attribute attrVal; | |||
179 | NamedAttrList attr; | |||
180 | auto loc = parser.getCurrentLocation(); | |||
181 | if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), | |||
182 | attrName, attr)) | |||
183 | return failure(); | |||
184 | if (!attrVal.isa<StringAttr>()) | |||
185 | return parser.emitError(loc, "expected ") | |||
186 | << attrName << " attribute specified as string"; | |||
187 | auto attrOptional = | |||
188 | spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue()); | |||
189 | if (!attrOptional) | |||
190 | return parser.emitError(loc, "invalid ") | |||
191 | << attrName << " attribute specification: " << attrVal; | |||
192 | value = *attrOptional; | |||
193 | return success(); | |||
194 | } | |||
195 | ||||
196 | /// Parses the next string attribute in `parser` as an enumerant of the given | |||
197 | /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer | |||
198 | /// attribute with the enum class's name as attribute name. | |||
199 | template <typename EnumAttrClass, | |||
200 | typename EnumClass = typename EnumAttrClass::ValueType> | |||
201 | static ParseResult | |||
202 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, | |||
203 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
204 | if (parseEnumStrAttr(value, parser)) | |||
205 | return failure(); | |||
206 | state.addAttribute(attrName, | |||
207 | parser.getBuilder().getAttr<EnumAttrClass>(value)); | |||
208 | return success(); | |||
209 | } | |||
210 | ||||
211 | /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` | |||
212 | /// and inserts the enumerant into `state` as an 32-bit integer attribute with | |||
213 | /// the enum class's name as attribute name. | |||
214 | template <typename EnumAttrClass, | |||
215 | typename EnumClass = typename EnumAttrClass::ValueType> | |||
216 | static ParseResult | |||
217 | parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, | |||
218 | OperationState &state, | |||
219 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
220 | if (parseEnumKeywordAttr(value, parser)) | |||
221 | return failure(); | |||
222 | state.addAttribute(attrName, | |||
223 | parser.getBuilder().getAttr<EnumAttrClass>(value)); | |||
224 | return success(); | |||
225 | } | |||
226 | ||||
227 | /// Parses Function, Selection and Loop control attributes. If no control is | |||
228 | /// specified, "None" is used as a default. | |||
229 | template <typename EnumAttrClass, typename EnumClass> | |||
230 | static ParseResult | |||
231 | parseControlAttribute(OpAsmParser &parser, OperationState &state, | |||
232 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
233 | if (succeeded(parser.parseOptionalKeyword(kControl))) { | |||
234 | EnumClass control; | |||
235 | if (parser.parseLParen() || | |||
236 | parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) || | |||
237 | parser.parseRParen()) | |||
238 | return failure(); | |||
239 | return success(); | |||
240 | } | |||
241 | // Set control to "None" otherwise. | |||
242 | Builder builder = parser.getBuilder(); | |||
243 | state.addAttribute(attrName, | |||
244 | builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0))); | |||
245 | return success(); | |||
246 | } | |||
247 | ||||
248 | /// Parses optional memory access attributes attached to a memory access | |||
249 | /// operand/pointer. Specifically, parses the following syntax: | |||
250 | /// (`[` memory-access `]`)? | |||
251 | /// where: | |||
252 | /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` | |||
253 | /// integer-literal | `"NonTemporal"` | |||
254 | static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, | |||
255 | OperationState &state) { | |||
256 | // Parse an optional list of attributes staring with '[' | |||
257 | if (parser.parseOptionalLSquare()) { | |||
258 | // Nothing to do | |||
259 | return success(); | |||
260 | } | |||
261 | ||||
262 | spirv::MemoryAccess memoryAccessAttr; | |||
263 | if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state, | |||
264 | kMemoryAccessAttrName)) | |||
265 | return failure(); | |||
266 | ||||
267 | if (spirv::bitEnumContainsAll(memoryAccessAttr, | |||
268 | spirv::MemoryAccess::Aligned)) { | |||
269 | // Parse integer attribute for alignment. | |||
270 | Attribute alignmentAttr; | |||
271 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
272 | if (parser.parseComma() || | |||
273 | parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, | |||
274 | state.attributes)) { | |||
275 | return failure(); | |||
276 | } | |||
277 | } | |||
278 | return parser.parseRSquare(); | |||
279 | } | |||
280 | ||||
281 | // TODO Make sure to merge this and the previous function into one template | |||
282 | // parameterized by memory access attribute name and alignment. Doing so now | |||
283 | // results in VS2017 in producing an internal error (at the call site) that's | |||
284 | // not detailed enough to understand what is happening. | |||
285 | static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, | |||
286 | OperationState &state) { | |||
287 | // Parse an optional list of attributes staring with '[' | |||
288 | if (parser.parseOptionalLSquare()) { | |||
289 | // Nothing to do | |||
290 | return success(); | |||
291 | } | |||
292 | ||||
293 | spirv::MemoryAccess memoryAccessAttr; | |||
294 | if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state, | |||
295 | kSourceMemoryAccessAttrName)) | |||
296 | return failure(); | |||
297 | ||||
298 | if (spirv::bitEnumContainsAll(memoryAccessAttr, | |||
299 | spirv::MemoryAccess::Aligned)) { | |||
300 | // Parse integer attribute for alignment. | |||
301 | Attribute alignmentAttr; | |||
302 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
303 | if (parser.parseComma() || | |||
304 | parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, | |||
305 | state.attributes)) { | |||
306 | return failure(); | |||
307 | } | |||
308 | } | |||
309 | return parser.parseRSquare(); | |||
310 | } | |||
311 | ||||
312 | template <typename MemoryOpTy> | |||
313 | static void printMemoryAccessAttribute( | |||
314 | MemoryOpTy memoryOp, OpAsmPrinter &printer, | |||
315 | SmallVectorImpl<StringRef> &elidedAttrs, | |||
316 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, | |||
317 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { | |||
318 | // Print optional memory access attribute. | |||
319 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue | |||
320 | : memoryOp.getMemoryAccess())) { | |||
321 | elidedAttrs.push_back(kMemoryAccessAttrName); | |||
322 | ||||
323 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; | |||
324 | ||||
325 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { | |||
326 | // Print integer alignment attribute. | |||
327 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue | |||
328 | : memoryOp.getAlignment())) { | |||
329 | elidedAttrs.push_back(kAlignmentAttrName); | |||
330 | printer << ", " << *alignment; | |||
331 | } | |||
332 | } | |||
333 | printer << "]"; | |||
334 | } | |||
335 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); | |||
336 | } | |||
337 | ||||
338 | // TODO Make sure to merge this and the previous function into one template | |||
339 | // parameterized by memory access attribute name and alignment. Doing so now | |||
340 | // results in VS2017 in producing an internal error (at the call site) that's | |||
341 | // not detailed enough to understand what is happening. | |||
342 | template <typename MemoryOpTy> | |||
343 | static void printSourceMemoryAccessAttribute( | |||
344 | MemoryOpTy memoryOp, OpAsmPrinter &printer, | |||
345 | SmallVectorImpl<StringRef> &elidedAttrs, | |||
346 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, | |||
347 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { | |||
348 | ||||
349 | printer << ", "; | |||
350 | ||||
351 | // Print optional memory access attribute. | |||
352 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue | |||
353 | : memoryOp.getMemoryAccess())) { | |||
354 | elidedAttrs.push_back(kSourceMemoryAccessAttrName); | |||
355 | ||||
356 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; | |||
357 | ||||
358 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { | |||
359 | // Print integer alignment attribute. | |||
360 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue | |||
361 | : memoryOp.getAlignment())) { | |||
362 | elidedAttrs.push_back(kSourceAlignmentAttrName); | |||
363 | printer << ", " << *alignment; | |||
364 | } | |||
365 | } | |||
366 | printer << "]"; | |||
367 | } | |||
368 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); | |||
369 | } | |||
370 | ||||
371 | static ParseResult parseImageOperands(OpAsmParser &parser, | |||
372 | spirv::ImageOperandsAttr &attr) { | |||
373 | // Expect image operands | |||
374 | if (parser.parseOptionalLSquare()) | |||
375 | return success(); | |||
376 | ||||
377 | spirv::ImageOperands imageOperands; | |||
378 | if (parseEnumStrAttr(imageOperands, parser)) | |||
379 | return failure(); | |||
380 | ||||
381 | attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands); | |||
382 | ||||
383 | return parser.parseRSquare(); | |||
384 | } | |||
385 | ||||
386 | static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, | |||
387 | spirv::ImageOperandsAttr attr) { | |||
388 | if (attr) { | |||
389 | auto strImageOperands = stringifyImageOperands(attr.getValue()); | |||
390 | printer << "[\"" << strImageOperands << "\"]"; | |||
391 | } | |||
392 | } | |||
393 | ||||
394 | template <typename Op> | |||
395 | static LogicalResult verifyImageOperands(Op imageOp, | |||
396 | spirv::ImageOperandsAttr attr, | |||
397 | Operation::operand_range operands) { | |||
398 | if (!attr) { | |||
399 | if (operands.empty()) | |||
400 | return success(); | |||
401 | ||||
402 | return imageOp.emitError("the Image Operands should encode what operands " | |||
403 | "follow, as per Image Operands"); | |||
404 | } | |||
405 | ||||
406 | // TODO: Add the validation rules for the following Image Operands. | |||
407 | spirv::ImageOperands noSupportOperands = | |||
408 | spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | | |||
409 | spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | | |||
410 | spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | | |||
411 | spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | | |||
412 | spirv::ImageOperands::MakeTexelAvailable | | |||
413 | spirv::ImageOperands::MakeTexelVisible | | |||
414 | spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; | |||
415 | ||||
416 | if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands)) | |||
417 | llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 417); | |||
418 | ||||
419 | return success(); | |||
420 | } | |||
421 | ||||
422 | static LogicalResult verifyCastOp(Operation *op, | |||
423 | bool requireSameBitWidth = true, | |||
424 | bool skipBitWidthCheck = false) { | |||
425 | // Some CastOps have no limit on bit widths for result and operand type. | |||
426 | if (skipBitWidthCheck) | |||
427 | return success(); | |||
428 | ||||
429 | Type operandType = op->getOperand(0).getType(); | |||
430 | Type resultType = op->getResult(0).getType(); | |||
431 | ||||
432 | // ODS checks that result type and operand type have the same shape. | |||
433 | if (auto vectorType = operandType.dyn_cast<VectorType>()) { | |||
434 | operandType = vectorType.getElementType(); | |||
435 | resultType = resultType.cast<VectorType>().getElementType(); | |||
436 | } | |||
437 | ||||
438 | if (auto coopMatrixType = | |||
439 | operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
440 | operandType = coopMatrixType.getElementType(); | |||
441 | resultType = | |||
442 | resultType.cast<spirv::CooperativeMatrixNVType>().getElementType(); | |||
443 | } | |||
444 | ||||
445 | if (auto jointMatrixType = | |||
446 | operandType.dyn_cast<spirv::JointMatrixINTELType>()) { | |||
447 | operandType = jointMatrixType.getElementType(); | |||
448 | resultType = | |||
449 | resultType.cast<spirv::JointMatrixINTELType>().getElementType(); | |||
450 | } | |||
451 | ||||
452 | auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); | |||
453 | auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); | |||
454 | auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; | |||
455 | ||||
456 | if (requireSameBitWidth) { | |||
457 | if (!isSameBitWidth) { | |||
458 | return op->emitOpError( | |||
459 | "expected the same bit widths for operand type and result " | |||
460 | "type, but provided ") | |||
461 | << operandType << " and " << resultType; | |||
462 | } | |||
463 | return success(); | |||
464 | } | |||
465 | ||||
466 | if (isSameBitWidth) { | |||
467 | return op->emitOpError( | |||
468 | "expected the different bit widths for operand type and result " | |||
469 | "type, but provided ") | |||
470 | << operandType << " and " << resultType; | |||
471 | } | |||
472 | return success(); | |||
473 | } | |||
474 | ||||
475 | template <typename MemoryOpTy> | |||
476 | static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { | |||
477 | // ODS checks for attributes values. Just need to verify that if the | |||
478 | // memory-access attribute is Aligned, then the alignment attribute must be | |||
479 | // present. | |||
480 | auto *op = memoryOp.getOperation(); | |||
481 | auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); | |||
482 | if (!memAccessAttr) { | |||
483 | // Alignment attribute shouldn't be present if memory access attribute is | |||
484 | // not present. | |||
485 | if (op->getAttr(kAlignmentAttrName)) { | |||
486 | return memoryOp.emitOpError( | |||
487 | "invalid alignment specification without aligned memory access " | |||
488 | "specification"); | |||
489 | } | |||
490 | return success(); | |||
491 | } | |||
492 | ||||
493 | auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>(); | |||
494 | ||||
495 | if (!memAccess) { | |||
496 | return memoryOp.emitOpError("invalid memory access specifier: ") | |||
497 | << memAccessAttr; | |||
498 | } | |||
499 | ||||
500 | if (spirv::bitEnumContainsAll(memAccess.getValue(), | |||
501 | spirv::MemoryAccess::Aligned)) { | |||
502 | if (!op->getAttr(kAlignmentAttrName)) { | |||
503 | return memoryOp.emitOpError("missing alignment value"); | |||
504 | } | |||
505 | } else { | |||
506 | if (op->getAttr(kAlignmentAttrName)) { | |||
507 | return memoryOp.emitOpError( | |||
508 | "invalid alignment specification with non-aligned memory access " | |||
509 | "specification"); | |||
510 | } | |||
511 | } | |||
512 | return success(); | |||
513 | } | |||
514 | ||||
515 | // TODO Make sure to merge this and the previous function into one template | |||
516 | // parameterized by memory access attribute name and alignment. Doing so now | |||
517 | // results in VS2017 in producing an internal error (at the call site) that's | |||
518 | // not detailed enough to understand what is happening. | |||
519 | template <typename MemoryOpTy> | |||
520 | static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { | |||
521 | // ODS checks for attributes values. Just need to verify that if the | |||
522 | // memory-access attribute is Aligned, then the alignment attribute must be | |||
523 | // present. | |||
524 | auto *op = memoryOp.getOperation(); | |||
525 | auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); | |||
526 | if (!memAccessAttr) { | |||
527 | // Alignment attribute shouldn't be present if memory access attribute is | |||
528 | // not present. | |||
529 | if (op->getAttr(kSourceAlignmentAttrName)) { | |||
530 | return memoryOp.emitOpError( | |||
531 | "invalid alignment specification without aligned memory access " | |||
532 | "specification"); | |||
533 | } | |||
534 | return success(); | |||
535 | } | |||
536 | ||||
537 | auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>(); | |||
538 | ||||
539 | if (!memAccess) { | |||
540 | return memoryOp.emitOpError("invalid memory access specifier: ") | |||
541 | << memAccess; | |||
542 | } | |||
543 | ||||
544 | if (spirv::bitEnumContainsAll(memAccess.getValue(), | |||
545 | spirv::MemoryAccess::Aligned)) { | |||
546 | if (!op->getAttr(kSourceAlignmentAttrName)) { | |||
547 | return memoryOp.emitOpError("missing alignment value"); | |||
548 | } | |||
549 | } else { | |||
550 | if (op->getAttr(kSourceAlignmentAttrName)) { | |||
551 | return memoryOp.emitOpError( | |||
552 | "invalid alignment specification with non-aligned memory access " | |||
553 | "specification"); | |||
554 | } | |||
555 | } | |||
556 | return success(); | |||
557 | } | |||
558 | ||||
559 | static LogicalResult | |||
560 | verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) { | |||
561 | // According to the SPIR-V specification: | |||
562 | // "Despite being a mask and allowing multiple bits to be combined, it is | |||
563 | // invalid for more than one of these four bits to be set: Acquire, Release, | |||
564 | // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and | |||
565 | // Release semantics is done by setting the AcquireRelease bit, not by setting | |||
566 | // two bits." | |||
567 | auto atMostOneInSet = spirv::MemorySemantics::Acquire | | |||
568 | spirv::MemorySemantics::Release | | |||
569 | spirv::MemorySemantics::AcquireRelease | | |||
570 | spirv::MemorySemantics::SequentiallyConsistent; | |||
571 | ||||
572 | auto bitCount = llvm::countPopulation( | |||
573 | static_cast<uint32_t>(memorySemantics & atMostOneInSet)); | |||
574 | if (bitCount > 1) { | |||
575 | return op->emitError( | |||
576 | "expected at most one of these four memory constraints " | |||
577 | "to be set: `Acquire`, `Release`," | |||
578 | "`AcquireRelease` or `SequentiallyConsistent`"); | |||
579 | } | |||
580 | return success(); | |||
581 | } | |||
582 | ||||
583 | template <typename LoadStoreOpTy> | |||
584 | static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, | |||
585 | Value val) { | |||
586 | // ODS already checks ptr is spirv::PointerType. Just check that the pointee | |||
587 | // type of the pointer and the type of the value are the same | |||
588 | // | |||
589 | // TODO: Check that the value type satisfies restrictions of | |||
590 | // SPIR-V OpLoad/OpStore operations | |||
591 | if (val.getType() != | |||
592 | ptr.getType().cast<spirv::PointerType>().getPointeeType()) { | |||
593 | return op.emitOpError("mismatch in result type and pointer type"); | |||
594 | } | |||
595 | return success(); | |||
596 | } | |||
597 | ||||
598 | template <typename BlockReadWriteOpTy> | |||
599 | static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, | |||
600 | Value ptr, Value val) { | |||
601 | auto valType = val.getType(); | |||
602 | if (auto valVecTy = valType.dyn_cast<VectorType>()) | |||
603 | valType = valVecTy.getElementType(); | |||
604 | ||||
605 | if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) { | |||
606 | return op.emitOpError("mismatch in result type and pointer type"); | |||
607 | } | |||
608 | return success(); | |||
609 | } | |||
610 | ||||
611 | static ParseResult parseVariableDecorations(OpAsmParser &parser, | |||
612 | OperationState &state) { | |||
613 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
614 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
615 | if (succeeded(parser.parseOptionalKeyword("bind"))) { | |||
616 | Attribute set, binding; | |||
617 | // Parse optional descriptor binding | |||
618 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
619 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
620 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
621 | stringifyDecoration(spirv::Decoration::Binding)); | |||
622 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
623 | if (parser.parseLParen() || | |||
624 | parser.parseAttribute(set, i32Type, descriptorSetName, | |||
625 | state.attributes) || | |||
626 | parser.parseComma() || | |||
627 | parser.parseAttribute(binding, i32Type, bindingName, | |||
628 | state.attributes) || | |||
629 | parser.parseRParen()) { | |||
630 | return failure(); | |||
631 | } | |||
632 | } else if (succeeded(parser.parseOptionalKeyword(builtInName))) { | |||
633 | StringAttr builtIn; | |||
634 | if (parser.parseLParen() || | |||
635 | parser.parseAttribute(builtIn, builtInName, state.attributes) || | |||
636 | parser.parseRParen()) { | |||
637 | return failure(); | |||
638 | } | |||
639 | } | |||
640 | ||||
641 | // Parse other attributes | |||
642 | if (parser.parseOptionalAttrDict(state.attributes)) | |||
643 | return failure(); | |||
644 | ||||
645 | return success(); | |||
646 | } | |||
647 | ||||
648 | static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, | |||
649 | SmallVectorImpl<StringRef> &elidedAttrs) { | |||
650 | // Print optional descriptor binding | |||
651 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
652 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
653 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
654 | stringifyDecoration(spirv::Decoration::Binding)); | |||
655 | auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); | |||
656 | auto binding = op->getAttrOfType<IntegerAttr>(bindingName); | |||
657 | if (descriptorSet && binding) { | |||
658 | elidedAttrs.push_back(descriptorSetName); | |||
659 | elidedAttrs.push_back(bindingName); | |||
660 | printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() | |||
661 | << ")"; | |||
662 | } | |||
663 | ||||
664 | // Print BuiltIn attribute if present | |||
665 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
666 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
667 | if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { | |||
668 | printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; | |||
669 | elidedAttrs.push_back(builtInName); | |||
670 | } | |||
671 | ||||
672 | printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); | |||
673 | } | |||
674 | ||||
675 | // Get bit width of types. | |||
676 | static unsigned getBitWidth(Type type) { | |||
677 | if (type.isa<spirv::PointerType>()) { | |||
678 | // Just return 64 bits for pointer types for now. | |||
679 | // TODO: Make sure not caller relies on the actual pointer width value. | |||
680 | return 64; | |||
681 | } | |||
682 | ||||
683 | if (type.isIntOrFloat()) | |||
684 | return type.getIntOrFloatBitWidth(); | |||
685 | ||||
686 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
687 | assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat ()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 687, __extension__ __PRETTY_FUNCTION__)); | |||
688 | return vectorType.getNumElements() * | |||
689 | vectorType.getElementType().getIntOrFloatBitWidth(); | |||
690 | } | |||
691 | llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 691); | |||
692 | } | |||
693 | ||||
694 | /// Walks the given type hierarchy with the given indices, potentially down | |||
695 | /// to component granularity, to select an element type. Returns null type and | |||
696 | /// emits errors with the given loc on failure. | |||
697 | static Type | |||
698 | getElementType(Type type, ArrayRef<int32_t> indices, | |||
699 | function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { | |||
700 | if (indices.empty()) { | |||
701 | emitErrorFn("expected at least one index for spirv.CompositeExtract"); | |||
702 | return nullptr; | |||
703 | } | |||
704 | ||||
705 | for (auto index : indices) { | |||
706 | if (auto cType = type.dyn_cast<spirv::CompositeType>()) { | |||
707 | if (cType.hasCompileTimeKnownNumElements() && | |||
708 | (index < 0 || | |||
709 | static_cast<uint64_t>(index) >= cType.getNumElements())) { | |||
710 | emitErrorFn("index ") << index << " out of bounds for " << type; | |||
711 | return nullptr; | |||
712 | } | |||
713 | type = cType.getElementType(index); | |||
714 | } else { | |||
715 | emitErrorFn("cannot extract from non-composite type ") | |||
716 | << type << " with index " << index; | |||
717 | return nullptr; | |||
718 | } | |||
719 | } | |||
720 | return type; | |||
721 | } | |||
722 | ||||
723 | static Type | |||
724 | getElementType(Type type, Attribute indices, | |||
725 | function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { | |||
726 | auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>(); | |||
727 | if (!indicesArrayAttr) { | |||
728 | emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); | |||
729 | return nullptr; | |||
730 | } | |||
731 | if (indicesArrayAttr.empty()) { | |||
732 | emitErrorFn("expected at least one index for spirv.CompositeExtract"); | |||
733 | return nullptr; | |||
734 | } | |||
735 | ||||
736 | SmallVector<int32_t, 2> indexVals; | |||
737 | for (auto indexAttr : indicesArrayAttr) { | |||
738 | auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>(); | |||
739 | if (!indexIntAttr) { | |||
740 | emitErrorFn("expected an 32-bit integer for index, but found '") | |||
741 | << indexAttr << "'"; | |||
742 | return nullptr; | |||
743 | } | |||
744 | indexVals.push_back(indexIntAttr.getInt()); | |||
745 | } | |||
746 | return getElementType(type, indexVals, emitErrorFn); | |||
747 | } | |||
748 | ||||
749 | static Type getElementType(Type type, Attribute indices, Location loc) { | |||
750 | auto errorFn = [&](StringRef err) -> InFlightDiagnostic { | |||
751 | return ::mlir::emitError(loc, err); | |||
752 | }; | |||
753 | return getElementType(type, indices, errorFn); | |||
754 | } | |||
755 | ||||
756 | static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, | |||
757 | SMLoc loc) { | |||
758 | auto errorFn = [&](StringRef err) -> InFlightDiagnostic { | |||
759 | return parser.emitError(loc, err); | |||
760 | }; | |||
761 | return getElementType(type, indices, errorFn); | |||
762 | } | |||
763 | ||||
764 | /// Returns true if the given `block` only contains one `spirv.mlir.merge` op. | |||
765 | static inline bool isMergeBlock(Block &block) { | |||
766 | return !block.empty() && std::next(block.begin()) == block.end() && | |||
767 | isa<spirv::MergeOp>(block.front()); | |||
768 | } | |||
769 | ||||
770 | template <typename ExtendedBinaryOp> | |||
771 | static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { | |||
772 | auto resultType = op.getType().template cast<spirv::StructType>(); | |||
773 | if (resultType.getNumElements() != 2) | |||
774 | return op.emitOpError("expected result struct type containing two members"); | |||
775 | ||||
776 | if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), | |||
777 | resultType.getElementType(0), | |||
778 | resultType.getElementType(1)})) | |||
779 | return op.emitOpError( | |||
780 | "expected all operand types and struct member types are the same"); | |||
781 | ||||
782 | return success(); | |||
783 | } | |||
784 | ||||
785 | static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, | |||
786 | OperationState &result) { | |||
787 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; | |||
788 | if (parser.parseOptionalAttrDict(result.attributes) || | |||
789 | parser.parseOperandList(operands) || parser.parseColon()) | |||
790 | return failure(); | |||
791 | ||||
792 | Type resultType; | |||
793 | SMLoc loc = parser.getCurrentLocation(); | |||
794 | if (parser.parseType(resultType)) | |||
795 | return failure(); | |||
796 | ||||
797 | auto structType = resultType.dyn_cast<spirv::StructType>(); | |||
798 | if (!structType || structType.getNumElements() != 2) | |||
799 | return parser.emitError(loc, "expected spirv.struct type with two members"); | |||
800 | ||||
801 | SmallVector<Type, 2> operandTypes(2, structType.getElementType(0)); | |||
802 | if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) | |||
803 | return failure(); | |||
804 | ||||
805 | result.addTypes(resultType); | |||
806 | return success(); | |||
807 | } | |||
808 | ||||
809 | static void printArithmeticExtendedBinaryOp(Operation *op, | |||
810 | OpAsmPrinter &printer) { | |||
811 | printer << ' '; | |||
812 | printer.printOptionalAttrDict(op->getAttrs()); | |||
813 | printer.printOperands(op->getOperands()); | |||
814 | printer << " : " << op->getResultTypes().front(); | |||
815 | } | |||
816 | ||||
817 | //===----------------------------------------------------------------------===// | |||
818 | // Common parsers and printers | |||
819 | //===----------------------------------------------------------------------===// | |||
820 | ||||
821 | // Parses an atomic update op. If the update op does not take a value (like | |||
822 | // AtomicIIncrement) `hasValue` must be false. | |||
823 | static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, | |||
824 | OperationState &state, bool hasValue) { | |||
825 | spirv::Scope scope; | |||
826 | spirv::MemorySemantics memoryScope; | |||
827 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
828 | OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; | |||
829 | Type type; | |||
830 | SMLoc loc; | |||
831 | if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state, | |||
832 | kMemoryScopeAttrName) || | |||
833 | parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state, | |||
834 | kSemanticsAttrName) || | |||
835 | parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || | |||
836 | parser.getCurrentLocation(&loc) || parser.parseColonType(type)) | |||
837 | return failure(); | |||
838 | ||||
839 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
840 | if (!ptrType) | |||
841 | return parser.emitError(loc, "expected pointer type"); | |||
842 | ||||
843 | SmallVector<Type, 2> operandTypes; | |||
844 | operandTypes.push_back(ptrType); | |||
845 | if (hasValue) | |||
846 | operandTypes.push_back(ptrType.getPointeeType()); | |||
847 | if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), | |||
848 | state.operands)) | |||
849 | return failure(); | |||
850 | return parser.addTypeToList(ptrType.getPointeeType(), state.types); | |||
851 | } | |||
852 | ||||
853 | // Prints an atomic update op. | |||
854 | static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { | |||
855 | printer << " \""; | |||
856 | auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName); | |||
857 | printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \""; | |||
858 | auto memorySemanticsAttr = | |||
859 | op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName); | |||
860 | printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue()) | |||
861 | << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); | |||
862 | } | |||
863 | ||||
864 | template <typename T> | |||
865 | static StringRef stringifyTypeName(); | |||
866 | ||||
867 | template <> | |||
868 | StringRef stringifyTypeName<IntegerType>() { | |||
869 | return "integer"; | |||
870 | } | |||
871 | ||||
872 | template <> | |||
873 | StringRef stringifyTypeName<FloatType>() { | |||
874 | return "float"; | |||
875 | } | |||
876 | ||||
877 | // Verifies an atomic update op. | |||
878 | template <typename ExpectedElementType> | |||
879 | static LogicalResult verifyAtomicUpdateOp(Operation *op) { | |||
880 | auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>(); | |||
881 | auto elementType = ptrType.getPointeeType(); | |||
882 | if (!elementType.isa<ExpectedElementType>()) | |||
883 | return op->emitOpError() << "pointer operand must point to an " | |||
884 | << stringifyTypeName<ExpectedElementType>() | |||
885 | << " value, found " << elementType; | |||
886 | ||||
887 | if (op->getNumOperands() > 1) { | |||
888 | auto valueType = op->getOperand(1).getType(); | |||
889 | if (valueType != elementType) | |||
890 | return op->emitOpError("expected value to have the same type as the " | |||
891 | "pointer operand's pointee type ") | |||
892 | << elementType << ", but found " << valueType; | |||
893 | } | |||
894 | auto memorySemantics = | |||
895 | op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName) | |||
896 | .getValue(); | |||
897 | if (failed(verifyMemorySemantics(op, memorySemantics))) { | |||
898 | return failure(); | |||
899 | } | |||
900 | return success(); | |||
901 | } | |||
902 | ||||
903 | static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, | |||
904 | OperationState &state) { | |||
905 | spirv::Scope executionScope; | |||
906 | spirv::GroupOperation groupOperation; | |||
907 | OpAsmParser::UnresolvedOperand valueInfo; | |||
908 | if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state, | |||
909 | kExecutionScopeAttrName) || | |||
910 | parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state, | |||
911 | kGroupOperationAttrName) || | |||
912 | parser.parseOperand(valueInfo)) | |||
913 | return failure(); | |||
914 | ||||
915 | std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo; | |||
916 | if (succeeded(parser.parseOptionalKeyword(kClusterSize))) { | |||
917 | clusterSizeInfo = OpAsmParser::UnresolvedOperand(); | |||
918 | if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) || | |||
919 | parser.parseRParen()) | |||
920 | return failure(); | |||
921 | } | |||
922 | ||||
923 | Type resultType; | |||
924 | if (parser.parseColonType(resultType)) | |||
925 | return failure(); | |||
926 | ||||
927 | if (parser.resolveOperand(valueInfo, resultType, state.operands)) | |||
928 | return failure(); | |||
929 | ||||
930 | if (clusterSizeInfo) { | |||
931 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
932 | if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands)) | |||
933 | return failure(); | |||
934 | } | |||
935 | ||||
936 | return parser.addTypeToList(resultType, state.types); | |||
937 | } | |||
938 | ||||
939 | static void printGroupNonUniformArithmeticOp(Operation *groupOp, | |||
940 | OpAsmPrinter &printer) { | |||
941 | printer | |||
942 | << " \"" | |||
943 | << stringifyScope( | |||
944 | groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName) | |||
945 | .getValue()) | |||
946 | << "\" \"" | |||
947 | << stringifyGroupOperation(groupOp | |||
948 | ->getAttrOfType<spirv::GroupOperationAttr>( | |||
949 | kGroupOperationAttrName) | |||
950 | .getValue()) | |||
951 | << "\" " << groupOp->getOperand(0); | |||
952 | ||||
953 | if (groupOp->getNumOperands() > 1) | |||
954 | printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; | |||
955 | printer << " : " << groupOp->getResult(0).getType(); | |||
956 | } | |||
957 | ||||
958 | static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { | |||
959 | spirv::Scope scope = | |||
960 | groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName) | |||
961 | .getValue(); | |||
962 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
963 | return groupOp->emitOpError( | |||
964 | "execution scope must be 'Workgroup' or 'Subgroup'"); | |||
965 | ||||
966 | spirv::GroupOperation operation = | |||
967 | groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName) | |||
968 | .getValue(); | |||
969 | if (operation == spirv::GroupOperation::ClusteredReduce && | |||
970 | groupOp->getNumOperands() == 1) | |||
971 | return groupOp->emitOpError("cluster size operand must be provided for " | |||
972 | "'ClusteredReduce' group operation"); | |||
973 | if (groupOp->getNumOperands() > 1) { | |||
974 | Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); | |||
975 | int32_t clusterSize = 0; | |||
976 | ||||
977 | // TODO: support specialization constant here. | |||
978 | if (failed(extractValueFromConstOp(sizeOp, clusterSize))) | |||
979 | return groupOp->emitOpError( | |||
980 | "cluster size operand must come from a constant op"); | |||
981 | ||||
982 | if (!llvm::isPowerOf2_32(clusterSize)) | |||
983 | return groupOp->emitOpError( | |||
984 | "cluster size operand must be a power of two"); | |||
985 | } | |||
986 | return success(); | |||
987 | } | |||
988 | ||||
989 | /// Result of a logical op must be a scalar or vector of boolean type. | |||
990 | static Type getUnaryOpResultType(Type operandType) { | |||
991 | Builder builder(operandType.getContext()); | |||
992 | Type resultType = builder.getIntegerType(1); | |||
993 | if (auto vecType = operandType.dyn_cast<VectorType>()) | |||
994 | return VectorType::get(vecType.getNumElements(), resultType); | |||
995 | return resultType; | |||
996 | } | |||
997 | ||||
998 | static LogicalResult verifyShiftOp(Operation *op) { | |||
999 | if (op->getOperand(0).getType() != op->getResult(0).getType()) { | |||
1000 | return op->emitError("expected the same type for the first operand and " | |||
1001 | "result, but provided ") | |||
1002 | << op->getOperand(0).getType() << " and " | |||
1003 | << op->getResult(0).getType(); | |||
1004 | } | |||
1005 | return success(); | |||
1006 | } | |||
1007 | ||||
1008 | static void buildLogicalBinaryOp(OpBuilder &builder, OperationState &state, | |||
1009 | Value lhs, Value rhs) { | |||
1010 | assert(lhs.getType() == rhs.getType())(static_cast <bool> (lhs.getType() == rhs.getType()) ? void (0) : __assert_fail ("lhs.getType() == rhs.getType()", "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp" , 1010, __extension__ __PRETTY_FUNCTION__)); | |||
1011 | ||||
1012 | Type boolType = builder.getI1Type(); | |||
1013 | if (auto vecType = lhs.getType().dyn_cast<VectorType>()) | |||
1014 | boolType = VectorType::get(vecType.getShape(), boolType); | |||
1015 | state.addTypes(boolType); | |||
1016 | ||||
1017 | state.addOperands({lhs, rhs}); | |||
1018 | } | |||
1019 | ||||
1020 | static void buildLogicalUnaryOp(OpBuilder &builder, OperationState &state, | |||
1021 | Value value) { | |||
1022 | Type boolType = builder.getI1Type(); | |||
1023 | if (auto vecType = value.getType().dyn_cast<VectorType>()) | |||
1024 | boolType = VectorType::get(vecType.getShape(), boolType); | |||
1025 | state.addTypes(boolType); | |||
1026 | ||||
1027 | state.addOperands(value); | |||
1028 | } | |||
1029 | ||||
1030 | //===----------------------------------------------------------------------===// | |||
1031 | // spirv.AccessChainOp | |||
1032 | //===----------------------------------------------------------------------===// | |||
1033 | ||||
1034 | static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { | |||
1035 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1036 | if (!ptrType) { | |||
1037 | emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " | |||
1038 | "to composite type, but provided ") | |||
1039 | << type; | |||
1040 | return nullptr; | |||
1041 | } | |||
1042 | ||||
1043 | auto resultType = ptrType.getPointeeType(); | |||
1044 | auto resultStorageClass = ptrType.getStorageClass(); | |||
1045 | int32_t index = 0; | |||
1046 | ||||
1047 | for (auto indexSSA : indices) { | |||
1048 | auto cType = resultType.dyn_cast<spirv::CompositeType>(); | |||
1049 | if (!cType) { | |||
1050 | emitError( | |||
1051 | baseLoc, | |||
1052 | "'spirv.AccessChain' op cannot extract from non-composite type ") | |||
1053 | << resultType << " with index " << index; | |||
1054 | return nullptr; | |||
1055 | } | |||
1056 | index = 0; | |||
1057 | if (resultType.isa<spirv::StructType>()) { | |||
1058 | Operation *op = indexSSA.getDefiningOp(); | |||
1059 | if (!op) { | |||
1060 | emitError(baseLoc, "'spirv.AccessChain' op index must be an " | |||
1061 | "integer spirv.Constant to access " | |||
1062 | "element of spirv.struct"); | |||
1063 | return nullptr; | |||
1064 | } | |||
1065 | ||||
1066 | // TODO: this should be relaxed to allow | |||
1067 | // integer literals of other bitwidths. | |||
1068 | if (failed(extractValueFromConstOp(op, index))) { | |||
1069 | emitError( | |||
1070 | baseLoc, | |||
1071 | "'spirv.AccessChain' index must be an integer spirv.Constant to " | |||
1072 | "access element of spirv.struct, but provided ") | |||
1073 | << op->getName(); | |||
1074 | return nullptr; | |||
1075 | } | |||
1076 | if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { | |||
1077 | emitError(baseLoc, "'spirv.AccessChain' op index ") | |||
1078 | << index << " out of bounds for " << resultType; | |||
1079 | return nullptr; | |||
1080 | } | |||
1081 | } | |||
1082 | resultType = cType.getElementType(index); | |||
1083 | } | |||
1084 | return spirv::PointerType::get(resultType, resultStorageClass); | |||
1085 | } | |||
1086 | ||||
1087 | void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state, | |||
1088 | Value basePtr, ValueRange indices) { | |||
1089 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
1090 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1090, __extension__ __PRETTY_FUNCTION__)); | |||
1091 | build(builder, state, type, basePtr, indices); | |||
1092 | } | |||
1093 | ||||
1094 | ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser, | |||
1095 | OperationState &result) { | |||
1096 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
1097 | SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo; | |||
1098 | Type type; | |||
1099 | auto loc = parser.getCurrentLocation(); | |||
1100 | SmallVector<Type, 4> indicesTypes; | |||
1101 | ||||
1102 | if (parser.parseOperand(ptrInfo) || | |||
1103 | parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || | |||
1104 | parser.parseColonType(type) || | |||
1105 | parser.resolveOperand(ptrInfo, type, result.operands)) { | |||
1106 | return failure(); | |||
1107 | } | |||
1108 | ||||
1109 | // Check that the provided indices list is not empty before parsing their | |||
1110 | // type list. | |||
1111 | if (indicesInfo.empty()) { | |||
1112 | return mlir::emitError(result.location, | |||
1113 | "'spirv.AccessChain' op expected at " | |||
1114 | "least one index "); | |||
1115 | } | |||
1116 | ||||
1117 | if (parser.parseComma() || parser.parseTypeList(indicesTypes)) | |||
1118 | return failure(); | |||
1119 | ||||
1120 | // Check that the indices types list is not empty and that it has a one-to-one | |||
1121 | // mapping to the provided indices. | |||
1122 | if (indicesTypes.size() != indicesInfo.size()) { | |||
1123 | return mlir::emitError( | |||
1124 | result.location, "'spirv.AccessChain' op indices types' count must be " | |||
1125 | "equal to indices info count"); | |||
1126 | } | |||
1127 | ||||
1128 | if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands)) | |||
1129 | return failure(); | |||
1130 | ||||
1131 | auto resultType = getElementPtrType( | |||
1132 | type, llvm::makeArrayRef(result.operands).drop_front(), result.location); | |||
1133 | if (!resultType) { | |||
1134 | return failure(); | |||
1135 | } | |||
1136 | ||||
1137 | result.addTypes(resultType); | |||
1138 | return success(); | |||
1139 | } | |||
1140 | ||||
1141 | template <typename Op> | |||
1142 | static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { | |||
1143 | printer << ' ' << op.getBasePtr() << '[' << indices | |||
1144 | << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); | |||
1145 | } | |||
1146 | ||||
1147 | void spirv::AccessChainOp::print(OpAsmPrinter &printer) { | |||
1148 | printAccessChain(*this, getIndices(), printer); | |||
1149 | } | |||
1150 | ||||
1151 | template <typename Op> | |||
1152 | static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { | |||
1153 | auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), | |||
1154 | indices, accessChainOp.getLoc()); | |||
1155 | if (!resultType) | |||
1156 | return failure(); | |||
1157 | ||||
1158 | auto providedResultType = | |||
1159 | accessChainOp.getType().template dyn_cast<spirv::PointerType>(); | |||
1160 | if (!providedResultType) | |||
1161 | return accessChainOp.emitOpError( | |||
1162 | "result type must be a pointer, but provided") | |||
1163 | << providedResultType; | |||
1164 | ||||
1165 | if (resultType != providedResultType) | |||
1166 | return accessChainOp.emitOpError("invalid result type: expected ") | |||
1167 | << resultType << ", but provided " << providedResultType; | |||
1168 | ||||
1169 | return success(); | |||
1170 | } | |||
1171 | ||||
1172 | LogicalResult spirv::AccessChainOp::verify() { | |||
1173 | return verifyAccessChain(*this, getIndices()); | |||
1174 | } | |||
1175 | ||||
1176 | //===----------------------------------------------------------------------===// | |||
1177 | // spirv.mlir.addressof | |||
1178 | //===----------------------------------------------------------------------===// | |||
1179 | ||||
1180 | void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, | |||
1181 | spirv::GlobalVariableOp var) { | |||
1182 | build(builder, state, var.getType(), SymbolRefAttr::get(var)); | |||
1183 | } | |||
1184 | ||||
1185 | LogicalResult spirv::AddressOfOp::verify() { | |||
1186 | auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>( | |||
1187 | SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), | |||
1188 | getVariableAttr())); | |||
1189 | if (!varOp) { | |||
1190 | return emitOpError("expected spirv.GlobalVariable symbol"); | |||
1191 | } | |||
1192 | if (getPointer().getType() != varOp.getType()) { | |||
1193 | return emitOpError( | |||
1194 | "result type mismatch with the referenced global variable's type"); | |||
1195 | } | |||
1196 | return success(); | |||
1197 | } | |||
1198 | ||||
1199 | template <typename T> | |||
1200 | static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { | |||
1201 | printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \"" | |||
1202 | << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \"" | |||
1203 | << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" " | |||
1204 | << atomOp.getOperands() << " : " << atomOp.getPointer().getType(); | |||
1205 | } | |||
1206 | ||||
1207 | static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, | |||
1208 | OperationState &state) { | |||
1209 | spirv::Scope memoryScope; | |||
1210 | spirv::MemorySemantics equalSemantics, unequalSemantics; | |||
1211 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; | |||
1212 | Type type; | |||
1213 | if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state, | |||
1214 | kMemoryScopeAttrName) || | |||
1215 | parseEnumStrAttr<spirv::MemorySemanticsAttr>( | |||
1216 | equalSemantics, parser, state, kEqualSemanticsAttrName) || | |||
1217 | parseEnumStrAttr<spirv::MemorySemanticsAttr>( | |||
1218 | unequalSemantics, parser, state, kUnequalSemanticsAttrName) || | |||
1219 | parser.parseOperandList(operandInfo, 3)) | |||
1220 | return failure(); | |||
1221 | ||||
1222 | auto loc = parser.getCurrentLocation(); | |||
1223 | if (parser.parseColonType(type)) | |||
1224 | return failure(); | |||
1225 | ||||
1226 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1227 | if (!ptrType) | |||
1228 | return parser.emitError(loc, "expected pointer type"); | |||
1229 | ||||
1230 | if (parser.resolveOperands( | |||
1231 | operandInfo, | |||
1232 | {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()}, | |||
1233 | parser.getNameLoc(), state.operands)) | |||
1234 | return failure(); | |||
1235 | ||||
1236 | return parser.addTypeToList(ptrType.getPointeeType(), state.types); | |||
1237 | } | |||
1238 | ||||
1239 | template <typename T> | |||
1240 | static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { | |||
1241 | // According to the spec: | |||
1242 | // "The type of Value must be the same as Result Type. The type of the value | |||
1243 | // pointed to by Pointer must be the same as Result Type. This type must also | |||
1244 | // match the type of Comparator." | |||
1245 | if (atomOp.getType() != atomOp.getValue().getType()) | |||
1246 | return atomOp.emitOpError("value operand must have the same type as the op " | |||
1247 | "result, but found ") | |||
1248 | << atomOp.getValue().getType() << " vs " << atomOp.getType(); | |||
1249 | ||||
1250 | if (atomOp.getType() != atomOp.getComparator().getType()) | |||
1251 | return atomOp.emitOpError( | |||
1252 | "comparator operand must have the same type as the op " | |||
1253 | "result, but found ") | |||
1254 | << atomOp.getComparator().getType() << " vs " << atomOp.getType(); | |||
1255 | ||||
1256 | Type pointeeType = atomOp.getPointer() | |||
1257 | .getType() | |||
1258 | .template cast<spirv::PointerType>() | |||
1259 | .getPointeeType(); | |||
1260 | if (atomOp.getType() != pointeeType) | |||
1261 | return atomOp.emitOpError( | |||
1262 | "pointer operand's pointee type must have the same " | |||
1263 | "as the op result type, but found ") | |||
1264 | << pointeeType << " vs " << atomOp.getType(); | |||
1265 | ||||
1266 | // TODO: Unequal cannot be set to Release or Acquire and Release. | |||
1267 | // In addition, Unequal cannot be set to a stronger memory-order then Equal. | |||
1268 | ||||
1269 | return success(); | |||
1270 | } | |||
1271 | ||||
1272 | //===----------------------------------------------------------------------===// | |||
1273 | // spirv.AtomicAndOp | |||
1274 | //===----------------------------------------------------------------------===// | |||
1275 | ||||
1276 | LogicalResult spirv::AtomicAndOp::verify() { | |||
1277 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1278 | } | |||
1279 | ||||
1280 | ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser, | |||
1281 | OperationState &result) { | |||
1282 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1283 | } | |||
1284 | void spirv::AtomicAndOp::print(OpAsmPrinter &p) { | |||
1285 | ::printAtomicUpdateOp(*this, p); | |||
1286 | } | |||
1287 | ||||
1288 | //===----------------------------------------------------------------------===// | |||
1289 | // spirv.AtomicCompareExchangeOp | |||
1290 | //===----------------------------------------------------------------------===// | |||
1291 | ||||
1292 | LogicalResult spirv::AtomicCompareExchangeOp::verify() { | |||
1293 | return ::verifyAtomicCompareExchangeImpl(*this); | |||
1294 | } | |||
1295 | ||||
1296 | ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser, | |||
1297 | OperationState &result) { | |||
1298 | return ::parseAtomicCompareExchangeImpl(parser, result); | |||
1299 | } | |||
1300 | void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) { | |||
1301 | ::printAtomicCompareExchangeImpl(*this, p); | |||
1302 | } | |||
1303 | ||||
1304 | //===----------------------------------------------------------------------===// | |||
1305 | // spirv.AtomicCompareExchangeWeakOp | |||
1306 | //===----------------------------------------------------------------------===// | |||
1307 | ||||
1308 | LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { | |||
1309 | return ::verifyAtomicCompareExchangeImpl(*this); | |||
1310 | } | |||
1311 | ||||
1312 | ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, | |||
1313 | OperationState &result) { | |||
1314 | return ::parseAtomicCompareExchangeImpl(parser, result); | |||
1315 | } | |||
1316 | void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { | |||
1317 | ::printAtomicCompareExchangeImpl(*this, p); | |||
1318 | } | |||
1319 | ||||
1320 | //===----------------------------------------------------------------------===// | |||
1321 | // spirv.AtomicExchange | |||
1322 | //===----------------------------------------------------------------------===// | |||
1323 | ||||
1324 | void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { | |||
1325 | printer << " \"" << stringifyScope(getMemoryScope()) << "\" \"" | |||
1326 | << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands() | |||
1327 | << " : " << getPointer().getType(); | |||
1328 | } | |||
1329 | ||||
1330 | ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, | |||
1331 | OperationState &result) { | |||
1332 | spirv::Scope memoryScope; | |||
1333 | spirv::MemorySemantics semantics; | |||
1334 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
1335 | Type type; | |||
1336 | if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result, | |||
1337 | kMemoryScopeAttrName) || | |||
1338 | parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result, | |||
1339 | kSemanticsAttrName) || | |||
1340 | parser.parseOperandList(operandInfo, 2)) | |||
1341 | return failure(); | |||
1342 | ||||
1343 | auto loc = parser.getCurrentLocation(); | |||
1344 | if (parser.parseColonType(type)) | |||
1345 | return failure(); | |||
1346 | ||||
1347 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1348 | if (!ptrType) | |||
1349 | return parser.emitError(loc, "expected pointer type"); | |||
1350 | ||||
1351 | if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()}, | |||
1352 | parser.getNameLoc(), result.operands)) | |||
1353 | return failure(); | |||
1354 | ||||
1355 | return parser.addTypeToList(ptrType.getPointeeType(), result.types); | |||
1356 | } | |||
1357 | ||||
1358 | LogicalResult spirv::AtomicExchangeOp::verify() { | |||
1359 | if (getType() != getValue().getType()) | |||
1360 | return emitOpError("value operand must have the same type as the op " | |||
1361 | "result, but found ") | |||
1362 | << getValue().getType() << " vs " << getType(); | |||
1363 | ||||
1364 | Type pointeeType = | |||
1365 | getPointer().getType().cast<spirv::PointerType>().getPointeeType(); | |||
1366 | if (getType() != pointeeType) | |||
1367 | return emitOpError("pointer operand's pointee type must have the same " | |||
1368 | "as the op result type, but found ") | |||
1369 | << pointeeType << " vs " << getType(); | |||
1370 | ||||
1371 | return success(); | |||
1372 | } | |||
1373 | ||||
1374 | //===----------------------------------------------------------------------===// | |||
1375 | // spirv.AtomicIAddOp | |||
1376 | //===----------------------------------------------------------------------===// | |||
1377 | ||||
1378 | LogicalResult spirv::AtomicIAddOp::verify() { | |||
1379 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1380 | } | |||
1381 | ||||
1382 | ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser, | |||
1383 | OperationState &result) { | |||
1384 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1385 | } | |||
1386 | void spirv::AtomicIAddOp::print(OpAsmPrinter &p) { | |||
1387 | ::printAtomicUpdateOp(*this, p); | |||
1388 | } | |||
1389 | ||||
1390 | //===----------------------------------------------------------------------===// | |||
1391 | // spirv.EXT.AtomicFAddOp | |||
1392 | //===----------------------------------------------------------------------===// | |||
1393 | ||||
1394 | LogicalResult spirv::EXTAtomicFAddOp::verify() { | |||
1395 | return ::verifyAtomicUpdateOp<FloatType>(getOperation()); | |||
1396 | } | |||
1397 | ||||
1398 | ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser, | |||
1399 | OperationState &result) { | |||
1400 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1401 | } | |||
1402 | void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) { | |||
1403 | ::printAtomicUpdateOp(*this, p); | |||
1404 | } | |||
1405 | ||||
1406 | //===----------------------------------------------------------------------===// | |||
1407 | // spirv.AtomicIDecrementOp | |||
1408 | //===----------------------------------------------------------------------===// | |||
1409 | ||||
1410 | LogicalResult spirv::AtomicIDecrementOp::verify() { | |||
1411 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1412 | } | |||
1413 | ||||
1414 | ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser, | |||
1415 | OperationState &result) { | |||
1416 | return ::parseAtomicUpdateOp(parser, result, false); | |||
1417 | } | |||
1418 | void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) { | |||
1419 | ::printAtomicUpdateOp(*this, p); | |||
1420 | } | |||
1421 | ||||
1422 | //===----------------------------------------------------------------------===// | |||
1423 | // spirv.AtomicIIncrementOp | |||
1424 | //===----------------------------------------------------------------------===// | |||
1425 | ||||
1426 | LogicalResult spirv::AtomicIIncrementOp::verify() { | |||
1427 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1428 | } | |||
1429 | ||||
1430 | ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser, | |||
1431 | OperationState &result) { | |||
1432 | return ::parseAtomicUpdateOp(parser, result, false); | |||
1433 | } | |||
1434 | void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) { | |||
1435 | ::printAtomicUpdateOp(*this, p); | |||
1436 | } | |||
1437 | ||||
1438 | //===----------------------------------------------------------------------===// | |||
1439 | // spirv.AtomicISubOp | |||
1440 | //===----------------------------------------------------------------------===// | |||
1441 | ||||
1442 | LogicalResult spirv::AtomicISubOp::verify() { | |||
1443 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1444 | } | |||
1445 | ||||
1446 | ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser, | |||
1447 | OperationState &result) { | |||
1448 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1449 | } | |||
1450 | void spirv::AtomicISubOp::print(OpAsmPrinter &p) { | |||
1451 | ::printAtomicUpdateOp(*this, p); | |||
1452 | } | |||
1453 | ||||
1454 | //===----------------------------------------------------------------------===// | |||
1455 | // spirv.AtomicOrOp | |||
1456 | //===----------------------------------------------------------------------===// | |||
1457 | ||||
1458 | LogicalResult spirv::AtomicOrOp::verify() { | |||
1459 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1460 | } | |||
1461 | ||||
1462 | ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser, | |||
1463 | OperationState &result) { | |||
1464 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1465 | } | |||
1466 | void spirv::AtomicOrOp::print(OpAsmPrinter &p) { | |||
1467 | ::printAtomicUpdateOp(*this, p); | |||
1468 | } | |||
1469 | ||||
1470 | //===----------------------------------------------------------------------===// | |||
1471 | // spirv.AtomicSMaxOp | |||
1472 | //===----------------------------------------------------------------------===// | |||
1473 | ||||
1474 | LogicalResult spirv::AtomicSMaxOp::verify() { | |||
1475 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1476 | } | |||
1477 | ||||
1478 | ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser, | |||
1479 | OperationState &result) { | |||
1480 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1481 | } | |||
1482 | void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) { | |||
1483 | ::printAtomicUpdateOp(*this, p); | |||
1484 | } | |||
1485 | ||||
1486 | //===----------------------------------------------------------------------===// | |||
1487 | // spirv.AtomicSMinOp | |||
1488 | //===----------------------------------------------------------------------===// | |||
1489 | ||||
1490 | LogicalResult spirv::AtomicSMinOp::verify() { | |||
1491 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1492 | } | |||
1493 | ||||
1494 | ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser, | |||
1495 | OperationState &result) { | |||
1496 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1497 | } | |||
1498 | void spirv::AtomicSMinOp::print(OpAsmPrinter &p) { | |||
1499 | ::printAtomicUpdateOp(*this, p); | |||
1500 | } | |||
1501 | ||||
1502 | //===----------------------------------------------------------------------===// | |||
1503 | // spirv.AtomicUMaxOp | |||
1504 | //===----------------------------------------------------------------------===// | |||
1505 | ||||
1506 | LogicalResult spirv::AtomicUMaxOp::verify() { | |||
1507 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1508 | } | |||
1509 | ||||
1510 | ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser, | |||
1511 | OperationState &result) { | |||
1512 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1513 | } | |||
1514 | void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) { | |||
1515 | ::printAtomicUpdateOp(*this, p); | |||
1516 | } | |||
1517 | ||||
1518 | //===----------------------------------------------------------------------===// | |||
1519 | // spirv.AtomicUMinOp | |||
1520 | //===----------------------------------------------------------------------===// | |||
1521 | ||||
1522 | LogicalResult spirv::AtomicUMinOp::verify() { | |||
1523 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1524 | } | |||
1525 | ||||
1526 | ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser, | |||
1527 | OperationState &result) { | |||
1528 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1529 | } | |||
1530 | void spirv::AtomicUMinOp::print(OpAsmPrinter &p) { | |||
1531 | ::printAtomicUpdateOp(*this, p); | |||
1532 | } | |||
1533 | ||||
1534 | //===----------------------------------------------------------------------===// | |||
1535 | // spirv.AtomicXorOp | |||
1536 | //===----------------------------------------------------------------------===// | |||
1537 | ||||
1538 | LogicalResult spirv::AtomicXorOp::verify() { | |||
1539 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1540 | } | |||
1541 | ||||
1542 | ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser, | |||
1543 | OperationState &result) { | |||
1544 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1545 | } | |||
1546 | void spirv::AtomicXorOp::print(OpAsmPrinter &p) { | |||
1547 | ::printAtomicUpdateOp(*this, p); | |||
1548 | } | |||
1549 | ||||
1550 | //===----------------------------------------------------------------------===// | |||
1551 | // spirv.BitcastOp | |||
1552 | //===----------------------------------------------------------------------===// | |||
1553 | ||||
1554 | LogicalResult spirv::BitcastOp::verify() { | |||
1555 | // TODO: The SPIR-V spec validation rules are different for different | |||
1556 | // versions. | |||
1557 | auto operandType = getOperand().getType(); | |||
1558 | auto resultType = getResult().getType(); | |||
1559 | if (operandType == resultType) { | |||
1560 | return emitError("result type must be different from operand type"); | |||
1561 | } | |||
1562 | if (operandType.isa<spirv::PointerType>() && | |||
1563 | !resultType.isa<spirv::PointerType>()) { | |||
1564 | return emitError( | |||
1565 | "unhandled bit cast conversion from pointer type to non-pointer type"); | |||
1566 | } | |||
1567 | if (!operandType.isa<spirv::PointerType>() && | |||
1568 | resultType.isa<spirv::PointerType>()) { | |||
1569 | return emitError( | |||
1570 | "unhandled bit cast conversion from non-pointer type to pointer type"); | |||
1571 | } | |||
1572 | auto operandBitWidth = getBitWidth(operandType); | |||
1573 | auto resultBitWidth = getBitWidth(resultType); | |||
1574 | if (operandBitWidth != resultBitWidth) { | |||
1575 | return emitOpError("mismatch in result type bitwidth ") | |||
1576 | << resultBitWidth << " and operand type bitwidth " | |||
1577 | << operandBitWidth; | |||
1578 | } | |||
1579 | return success(); | |||
1580 | } | |||
1581 | ||||
1582 | //===----------------------------------------------------------------------===// | |||
1583 | // spirv.PtrCastToGenericOp | |||
1584 | //===----------------------------------------------------------------------===// | |||
1585 | ||||
1586 | LogicalResult spirv::PtrCastToGenericOp::verify() { | |||
1587 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1588 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1589 | ||||
1590 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1591 | if (operandStorage != spirv::StorageClass::Workgroup && | |||
1592 | operandStorage != spirv::StorageClass::CrossWorkgroup && | |||
1593 | operandStorage != spirv::StorageClass::Function) | |||
1594 | return emitError("pointer must point to the Workgroup, CrossWorkgroup" | |||
1595 | ", or Function Storage Class"); | |||
1596 | ||||
1597 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1598 | if (resultStorage != spirv::StorageClass::Generic) | |||
1599 | return emitError("result type must be of storage class Generic"); | |||
1600 | ||||
1601 | Type operandPointeeType = operandType.getPointeeType(); | |||
1602 | Type resultPointeeType = resultType.getPointeeType(); | |||
1603 | if (operandPointeeType != resultPointeeType) | |||
1604 | return emitOpError("pointer operand's pointee type must have the same " | |||
1605 | "as the op result type, but found ") | |||
1606 | << operandPointeeType << " vs " << resultPointeeType; | |||
1607 | return success(); | |||
1608 | } | |||
1609 | ||||
1610 | //===----------------------------------------------------------------------===// | |||
1611 | // spirv.GenericCastToPtrOp | |||
1612 | //===----------------------------------------------------------------------===// | |||
1613 | ||||
1614 | LogicalResult spirv::GenericCastToPtrOp::verify() { | |||
1615 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1616 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1617 | ||||
1618 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1619 | if (operandStorage != spirv::StorageClass::Generic) | |||
1620 | return emitError("pointer type must be of storage class Generic"); | |||
1621 | ||||
1622 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1623 | if (resultStorage != spirv::StorageClass::Workgroup && | |||
1624 | resultStorage != spirv::StorageClass::CrossWorkgroup && | |||
1625 | resultStorage != spirv::StorageClass::Function) | |||
1626 | return emitError("result must point to the Workgroup, CrossWorkgroup, " | |||
1627 | "or Function Storage Class"); | |||
1628 | ||||
1629 | Type operandPointeeType = operandType.getPointeeType(); | |||
1630 | Type resultPointeeType = resultType.getPointeeType(); | |||
1631 | if (operandPointeeType != resultPointeeType) | |||
1632 | return emitOpError("pointer operand's pointee type must have the same " | |||
1633 | "as the op result type, but found ") | |||
1634 | << operandPointeeType << " vs " << resultPointeeType; | |||
1635 | return success(); | |||
1636 | } | |||
1637 | ||||
1638 | //===----------------------------------------------------------------------===// | |||
1639 | // spirv.GenericCastToPtrExplicitOp | |||
1640 | //===----------------------------------------------------------------------===// | |||
1641 | ||||
1642 | LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { | |||
1643 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1644 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1645 | ||||
1646 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1647 | if (operandStorage != spirv::StorageClass::Generic) | |||
1648 | return emitError("pointer type must be of storage class Generic"); | |||
1649 | ||||
1650 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1651 | if (resultStorage != spirv::StorageClass::Workgroup && | |||
1652 | resultStorage != spirv::StorageClass::CrossWorkgroup && | |||
1653 | resultStorage != spirv::StorageClass::Function) | |||
1654 | return emitError("result must point to the Workgroup, CrossWorkgroup, " | |||
1655 | "or Function Storage Class"); | |||
1656 | ||||
1657 | Type operandPointeeType = operandType.getPointeeType(); | |||
1658 | Type resultPointeeType = resultType.getPointeeType(); | |||
1659 | if (operandPointeeType != resultPointeeType) | |||
1660 | return emitOpError("pointer operand's pointee type must have the same " | |||
1661 | "as the op result type, but found ") | |||
1662 | << operandPointeeType << " vs " << resultPointeeType; | |||
1663 | return success(); | |||
1664 | } | |||
1665 | ||||
1666 | //===----------------------------------------------------------------------===// | |||
1667 | // spirv.BranchOp | |||
1668 | //===----------------------------------------------------------------------===// | |||
1669 | ||||
1670 | SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { | |||
1671 | assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index" ) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1671, __extension__ __PRETTY_FUNCTION__)); | |||
1672 | return SuccessorOperands(0, getTargetOperandsMutable()); | |||
1673 | } | |||
1674 | ||||
1675 | //===----------------------------------------------------------------------===// | |||
1676 | // spirv.BranchConditionalOp | |||
1677 | //===----------------------------------------------------------------------===// | |||
1678 | ||||
1679 | SuccessorOperands | |||
1680 | spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { | |||
1681 | assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index" ) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1681, __extension__ __PRETTY_FUNCTION__)); | |||
1682 | return SuccessorOperands(index == kTrueIndex | |||
1683 | ? getTrueTargetOperandsMutable() | |||
1684 | : getFalseTargetOperandsMutable()); | |||
1685 | } | |||
1686 | ||||
1687 | ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, | |||
1688 | OperationState &result) { | |||
1689 | auto &builder = parser.getBuilder(); | |||
1690 | OpAsmParser::UnresolvedOperand condInfo; | |||
1691 | Block *dest; | |||
1692 | ||||
1693 | // Parse the condition. | |||
1694 | Type boolTy = builder.getI1Type(); | |||
1695 | if (parser.parseOperand(condInfo) || | |||
1696 | parser.resolveOperand(condInfo, boolTy, result.operands)) | |||
1697 | return failure(); | |||
1698 | ||||
1699 | // Parse the optional branch weights. | |||
1700 | if (succeeded(parser.parseOptionalLSquare())) { | |||
1701 | IntegerAttr trueWeight, falseWeight; | |||
1702 | NamedAttrList weights; | |||
1703 | ||||
1704 | auto i32Type = builder.getIntegerType(32); | |||
1705 | if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || | |||
1706 | parser.parseComma() || | |||
1707 | parser.parseAttribute(falseWeight, i32Type, "weight", weights) || | |||
1708 | parser.parseRSquare()) | |||
1709 | return failure(); | |||
1710 | ||||
1711 | result.addAttribute(kBranchWeightAttrName, | |||
1712 | builder.getArrayAttr({trueWeight, falseWeight})); | |||
1713 | } | |||
1714 | ||||
1715 | // Parse the true branch. | |||
1716 | SmallVector<Value, 4> trueOperands; | |||
1717 | if (parser.parseComma() || | |||
1718 | parser.parseSuccessorAndUseList(dest, trueOperands)) | |||
1719 | return failure(); | |||
1720 | result.addSuccessors(dest); | |||
1721 | result.addOperands(trueOperands); | |||
1722 | ||||
1723 | // Parse the false branch. | |||
1724 | SmallVector<Value, 4> falseOperands; | |||
1725 | if (parser.parseComma() || | |||
1726 | parser.parseSuccessorAndUseList(dest, falseOperands)) | |||
1727 | return failure(); | |||
1728 | result.addSuccessors(dest); | |||
1729 | result.addOperands(falseOperands); | |||
1730 | result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), | |||
1731 | builder.getDenseI32ArrayAttr( | |||
1732 | {1, static_cast<int32_t>(trueOperands.size()), | |||
1733 | static_cast<int32_t>(falseOperands.size())})); | |||
1734 | ||||
1735 | return success(); | |||
1736 | } | |||
1737 | ||||
1738 | void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { | |||
1739 | printer << ' ' << getCondition(); | |||
1740 | ||||
1741 | if (auto weights = getBranchWeights()) { | |||
1742 | printer << " ["; | |||
1743 | llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { | |||
1744 | printer << a.cast<IntegerAttr>().getInt(); | |||
1745 | }); | |||
1746 | printer << "]"; | |||
1747 | } | |||
1748 | ||||
1749 | printer << ", "; | |||
1750 | printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); | |||
1751 | printer << ", "; | |||
1752 | printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); | |||
1753 | } | |||
1754 | ||||
1755 | LogicalResult spirv::BranchConditionalOp::verify() { | |||
1756 | if (auto weights = getBranchWeights()) { | |||
1757 | if (weights->getValue().size() != 2) { | |||
1758 | return emitOpError("must have exactly two branch weights"); | |||
1759 | } | |||
1760 | if (llvm::all_of(*weights, [](Attribute attr) { | |||
1761 | return attr.cast<IntegerAttr>().getValue().isNullValue(); | |||
1762 | })) | |||
1763 | return emitOpError("branch weights cannot both be zero"); | |||
1764 | } | |||
1765 | ||||
1766 | return success(); | |||
1767 | } | |||
1768 | ||||
1769 | //===----------------------------------------------------------------------===// | |||
1770 | // spirv.CompositeConstruct | |||
1771 | //===----------------------------------------------------------------------===// | |||
1772 | ||||
1773 | LogicalResult spirv::CompositeConstructOp::verify() { | |||
1774 | auto cType = getType().cast<spirv::CompositeType>(); | |||
1775 | operand_range constituents = this->getConstituents(); | |||
1776 | ||||
1777 | if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
1778 | if (constituents.size() != 1) | |||
1779 | return emitOpError("has incorrect number of operands: expected ") | |||
1780 | << "1, but provided " << constituents.size(); | |||
1781 | if (coopType.getElementType() != constituents.front().getType()) | |||
1782 | return emitOpError("operand type mismatch: expected operand type ") | |||
1783 | << coopType.getElementType() << ", but provided " | |||
1784 | << constituents.front().getType(); | |||
1785 | return success(); | |||
1786 | } | |||
1787 | ||||
1788 | if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) { | |||
1789 | if (constituents.size() != 1) | |||
1790 | return emitOpError("has incorrect number of operands: expected ") | |||
1791 | << "1, but provided " << constituents.size(); | |||
1792 | if (jointType.getElementType() != constituents.front().getType()) | |||
1793 | return emitOpError("operand type mismatch: expected operand type ") | |||
1794 | << jointType.getElementType() << ", but provided " | |||
1795 | << constituents.front().getType(); | |||
1796 | return success(); | |||
1797 | } | |||
1798 | ||||
1799 | if (constituents.size() == cType.getNumElements()) { | |||
1800 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { | |||
1801 | if (constituents[index].getType() != cType.getElementType(index)) { | |||
1802 | return emitOpError("operand type mismatch: expected operand type ") | |||
1803 | << cType.getElementType(index) << ", but provided " | |||
1804 | << constituents[index].getType(); | |||
1805 | } | |||
1806 | } | |||
1807 | return success(); | |||
1808 | } | |||
1809 | ||||
1810 | // If not constructing a cooperative matrix type, then we must be constructing | |||
1811 | // a vector type. | |||
1812 | auto resultType = cType.dyn_cast<VectorType>(); | |||
1813 | if (!resultType) | |||
1814 | return emitOpError( | |||
1815 | "expected to return a vector or cooperative matrix when the number of " | |||
1816 | "constituents is less than what the result needs"); | |||
1817 | ||||
1818 | SmallVector<unsigned> sizes; | |||
1819 | for (Value component : constituents) { | |||
1820 | if (!component.getType().isa<VectorType>() && | |||
1821 | !component.getType().isIntOrFloat()) | |||
1822 | return emitOpError("operand type mismatch: expected operand to have " | |||
1823 | "a scalar or vector type, but provided ") | |||
1824 | << component.getType(); | |||
1825 | ||||
1826 | Type elementType = component.getType(); | |||
1827 | if (auto vectorType = component.getType().dyn_cast<VectorType>()) { | |||
1828 | sizes.push_back(vectorType.getNumElements()); | |||
1829 | elementType = vectorType.getElementType(); | |||
1830 | } else { | |||
1831 | sizes.push_back(1); | |||
1832 | } | |||
1833 | ||||
1834 | if (elementType != resultType.getElementType()) | |||
1835 | return emitOpError("operand element type mismatch: expected to be ") | |||
1836 | << resultType.getElementType() << ", but provided " << elementType; | |||
1837 | } | |||
1838 | unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); | |||
1839 | if (totalCount != cType.getNumElements()) | |||
1840 | return emitOpError("has incorrect number of operands: expected ") | |||
1841 | << cType.getNumElements() << ", but provided " << totalCount; | |||
1842 | return success(); | |||
1843 | } | |||
1844 | ||||
1845 | //===----------------------------------------------------------------------===// | |||
1846 | // spirv.CompositeExtractOp | |||
1847 | //===----------------------------------------------------------------------===// | |||
1848 | ||||
1849 | void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, | |||
1850 | Value composite, | |||
1851 | ArrayRef<int32_t> indices) { | |||
1852 | auto indexAttr = builder.getI32ArrayAttr(indices); | |||
1853 | auto elementType = | |||
1854 | getElementType(composite.getType(), indexAttr, state.location); | |||
1855 | if (!elementType) { | |||
1856 | return; | |||
1857 | } | |||
1858 | build(builder, state, elementType, composite, indexAttr); | |||
1859 | } | |||
1860 | ||||
1861 | ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, | |||
1862 | OperationState &result) { | |||
1863 | OpAsmParser::UnresolvedOperand compositeInfo; | |||
1864 | Attribute indicesAttr; | |||
1865 | Type compositeType; | |||
1866 | SMLoc attrLocation; | |||
1867 | ||||
1868 | if (parser.parseOperand(compositeInfo) || | |||
1869 | parser.getCurrentLocation(&attrLocation) || | |||
1870 | parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || | |||
1871 | parser.parseColonType(compositeType) || | |||
1872 | parser.resolveOperand(compositeInfo, compositeType, result.operands)) { | |||
1873 | return failure(); | |||
1874 | } | |||
1875 | ||||
1876 | Type resultType = | |||
1877 | getElementType(compositeType, indicesAttr, parser, attrLocation); | |||
1878 | if (!resultType) { | |||
1879 | return failure(); | |||
1880 | } | |||
1881 | result.addTypes(resultType); | |||
1882 | return success(); | |||
1883 | } | |||
1884 | ||||
1885 | void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { | |||
1886 | printer << ' ' << getComposite() << getIndices() << " : " | |||
1887 | << getComposite().getType(); | |||
1888 | } | |||
1889 | ||||
1890 | LogicalResult spirv::CompositeExtractOp::verify() { | |||
1891 | auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>(); | |||
1892 | auto resultType = | |||
1893 | getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); | |||
1894 | if (!resultType) | |||
1895 | return failure(); | |||
1896 | ||||
1897 | if (resultType != getType()) { | |||
1898 | return emitOpError("invalid result type: expected ") | |||
1899 | << resultType << " but provided " << getType(); | |||
1900 | } | |||
1901 | ||||
1902 | return success(); | |||
1903 | } | |||
1904 | ||||
1905 | //===----------------------------------------------------------------------===// | |||
1906 | // spirv.CompositeInsert | |||
1907 | //===----------------------------------------------------------------------===// | |||
1908 | ||||
1909 | void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, | |||
1910 | Value object, Value composite, | |||
1911 | ArrayRef<int32_t> indices) { | |||
1912 | auto indexAttr = builder.getI32ArrayAttr(indices); | |||
1913 | build(builder, state, composite.getType(), object, composite, indexAttr); | |||
1914 | } | |||
1915 | ||||
1916 | ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, | |||
1917 | OperationState &result) { | |||
1918 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; | |||
1919 | Type objectType, compositeType; | |||
1920 | Attribute indicesAttr; | |||
1921 | auto loc = parser.getCurrentLocation(); | |||
1922 | ||||
1923 | return failure( | |||
1924 | parser.parseOperandList(operands, 2) || | |||
1925 | parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || | |||
1926 | parser.parseColonType(objectType) || | |||
1927 | parser.parseKeywordType("into", compositeType) || | |||
1928 | parser.resolveOperands(operands, {objectType, compositeType}, loc, | |||
1929 | result.operands) || | |||
1930 | parser.addTypesToList(compositeType, result.types)); | |||
1931 | } | |||
1932 | ||||
1933 | LogicalResult spirv::CompositeInsertOp::verify() { | |||
1934 | auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>(); | |||
1935 | auto objectType = | |||
1936 | getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); | |||
1937 | if (!objectType) | |||
1938 | return failure(); | |||
1939 | ||||
1940 | if (objectType != getObject().getType()) { | |||
1941 | return emitOpError("object operand type should be ") | |||
1942 | << objectType << ", but found " << getObject().getType(); | |||
1943 | } | |||
1944 | ||||
1945 | if (getComposite().getType() != getType()) { | |||
1946 | return emitOpError("result type should be the same as " | |||
1947 | "the composite type, but found ") | |||
1948 | << getComposite().getType() << " vs " << getType(); | |||
1949 | } | |||
1950 | ||||
1951 | return success(); | |||
1952 | } | |||
1953 | ||||
1954 | void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { | |||
1955 | printer << " " << getObject() << ", " << getComposite() << getIndices() | |||
1956 | << " : " << getObject().getType() << " into " | |||
1957 | << getComposite().getType(); | |||
1958 | } | |||
1959 | ||||
1960 | //===----------------------------------------------------------------------===// | |||
1961 | // spirv.Constant | |||
1962 | //===----------------------------------------------------------------------===// | |||
1963 | ||||
1964 | ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, | |||
1965 | OperationState &result) { | |||
1966 | Attribute value; | |||
1967 | if (parser.parseAttribute(value, kValueAttrName, result.attributes)) | |||
1968 | return failure(); | |||
1969 | ||||
1970 | Type type = NoneType::get(parser.getContext()); | |||
1971 | if (auto typedAttr = value.dyn_cast<TypedAttr>()) | |||
1972 | type = typedAttr.getType(); | |||
1973 | if (type.isa<NoneType, TensorType>()) { | |||
1974 | if (parser.parseColonType(type)) | |||
1975 | return failure(); | |||
1976 | } | |||
1977 | ||||
1978 | return parser.addTypeToList(type, result.types); | |||
1979 | } | |||
1980 | ||||
1981 | void spirv::ConstantOp::print(OpAsmPrinter &printer) { | |||
1982 | printer << ' ' << getValue(); | |||
1983 | if (getType().isa<spirv::ArrayType>()) | |||
1984 | printer << " : " << getType(); | |||
1985 | } | |||
1986 | ||||
1987 | static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, | |||
1988 | Type opType) { | |||
1989 | if (value.isa<IntegerAttr, FloatAttr>()) { | |||
1990 | auto valueType = value.cast<TypedAttr>().getType(); | |||
1991 | if (valueType != opType) | |||
1992 | return op.emitOpError("result type (") | |||
1993 | << opType << ") does not match value type (" << valueType << ")"; | |||
1994 | return success(); | |||
1995 | } | |||
1996 | if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) { | |||
1997 | auto valueType = value.cast<TypedAttr>().getType(); | |||
1998 | if (valueType == opType) | |||
1999 | return success(); | |||
2000 | auto arrayType = opType.dyn_cast<spirv::ArrayType>(); | |||
2001 | auto shapedType = valueType.dyn_cast<ShapedType>(); | |||
2002 | if (!arrayType) | |||
2003 | return op.emitOpError("result or element type (") | |||
2004 | << opType << ") does not match value type (" << valueType | |||
2005 | << "), must be the same or spirv.array"; | |||
2006 | ||||
2007 | int numElements = arrayType.getNumElements(); | |||
2008 | auto opElemType = arrayType.getElementType(); | |||
2009 | while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) { | |||
2010 | numElements *= t.getNumElements(); | |||
2011 | opElemType = t.getElementType(); | |||
2012 | } | |||
2013 | if (!opElemType.isIntOrFloat()) | |||
2014 | return op.emitOpError("only support nested array result type"); | |||
2015 | ||||
2016 | auto valueElemType = shapedType.getElementType(); | |||
2017 | if (valueElemType != opElemType) { | |||
2018 | return op.emitOpError("result element type (") | |||
2019 | << opElemType << ") does not match value element type (" | |||
2020 | << valueElemType << ")"; | |||
2021 | } | |||
2022 | ||||
2023 | if (numElements != shapedType.getNumElements()) { | |||
2024 | return op.emitOpError("result number of elements (") | |||
2025 | << numElements << ") does not match value number of elements (" | |||
2026 | << shapedType.getNumElements() << ")"; | |||
2027 | } | |||
2028 | return success(); | |||
2029 | } | |||
2030 | if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) { | |||
2031 | auto arrayType = opType.dyn_cast<spirv::ArrayType>(); | |||
2032 | if (!arrayType) | |||
2033 | return op.emitOpError( | |||
2034 | "must have spirv.array result type for array value"); | |||
2035 | Type elemType = arrayType.getElementType(); | |||
2036 | for (Attribute element : arrayAttr.getValue()) { | |||
2037 | // Verify array elements recursively. | |||
2038 | if (failed(verifyConstantType(op, element, elemType))) | |||
2039 | return failure(); | |||
2040 | } | |||
2041 | return success(); | |||
2042 | } | |||
2043 | return op.emitOpError("cannot have attribute: ") << value; | |||
2044 | } | |||
2045 | ||||
2046 | LogicalResult spirv::ConstantOp::verify() { | |||
2047 | // ODS already generates checks to make sure the result type is valid. We just | |||
2048 | // need to additionally check that the value's attribute type is consistent | |||
2049 | // with the result type. | |||
2050 | return verifyConstantType(*this, getValueAttr(), getType()); | |||
2051 | } | |||
2052 | ||||
2053 | bool spirv::ConstantOp::isBuildableWith(Type type) { | |||
2054 | // Must be valid SPIR-V type first. | |||
2055 | if (!type.isa<spirv::SPIRVType>()) | |||
2056 | return false; | |||
2057 | ||||
2058 | if (isa<SPIRVDialect>(type.getDialect())) { | |||
2059 | // TODO: support constant struct | |||
2060 | return type.isa<spirv::ArrayType>(); | |||
2061 | } | |||
2062 | ||||
2063 | return true; | |||
2064 | } | |||
2065 | ||||
2066 | spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, | |||
2067 | OpBuilder &builder) { | |||
2068 | if (auto intType = type.dyn_cast<IntegerType>()) { | |||
2069 | unsigned width = intType.getWidth(); | |||
2070 | if (width == 1) | |||
2071 | return builder.create<spirv::ConstantOp>(loc, type, | |||
2072 | builder.getBoolAttr(false)); | |||
2073 | return builder.create<spirv::ConstantOp>( | |||
2074 | loc, type, builder.getIntegerAttr(type, APInt(width, 0))); | |||
2075 | } | |||
2076 | if (auto floatType = type.dyn_cast<FloatType>()) { | |||
2077 | return builder.create<spirv::ConstantOp>( | |||
2078 | loc, type, builder.getFloatAttr(floatType, 0.0)); | |||
2079 | } | |||
2080 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
2081 | Type elemType = vectorType.getElementType(); | |||
2082 | if (elemType.isa<IntegerType>()) { | |||
2083 | return builder.create<spirv::ConstantOp>( | |||
2084 | loc, type, | |||
2085 | DenseElementsAttr::get(vectorType, | |||
2086 | IntegerAttr::get(elemType, 0).getValue())); | |||
2087 | } | |||
2088 | if (elemType.isa<FloatType>()) { | |||
2089 | return builder.create<spirv::ConstantOp>( | |||
2090 | loc, type, | |||
2091 | DenseFPElementsAttr::get(vectorType, | |||
2092 | FloatAttr::get(elemType, 0.0).getValue())); | |||
2093 | } | |||
2094 | } | |||
2095 | ||||
2096 | llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2096); | |||
2097 | } | |||
2098 | ||||
2099 | spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, | |||
2100 | OpBuilder &builder) { | |||
2101 | if (auto intType = type.dyn_cast<IntegerType>()) { | |||
2102 | unsigned width = intType.getWidth(); | |||
2103 | if (width == 1) | |||
2104 | return builder.create<spirv::ConstantOp>(loc, type, | |||
2105 | builder.getBoolAttr(true)); | |||
2106 | return builder.create<spirv::ConstantOp>( | |||
2107 | loc, type, builder.getIntegerAttr(type, APInt(width, 1))); | |||
2108 | } | |||
2109 | if (auto floatType = type.dyn_cast<FloatType>()) { | |||
2110 | return builder.create<spirv::ConstantOp>( | |||
2111 | loc, type, builder.getFloatAttr(floatType, 1.0)); | |||
2112 | } | |||
2113 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
2114 | Type elemType = vectorType.getElementType(); | |||
2115 | if (elemType.isa<IntegerType>()) { | |||
2116 | return builder.create<spirv::ConstantOp>( | |||
2117 | loc, type, | |||
2118 | DenseElementsAttr::get(vectorType, | |||
2119 | IntegerAttr::get(elemType, 1).getValue())); | |||
2120 | } | |||
2121 | if (elemType.isa<FloatType>()) { | |||
2122 | return builder.create<spirv::ConstantOp>( | |||
2123 | loc, type, | |||
2124 | DenseFPElementsAttr::get(vectorType, | |||
2125 | FloatAttr::get(elemType, 1.0).getValue())); | |||
2126 | } | |||
2127 | } | |||
2128 | ||||
2129 | llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2129); | |||
2130 | } | |||
2131 | ||||
2132 | void mlir::spirv::ConstantOp::getAsmResultNames( | |||
2133 | llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { | |||
2134 | Type type = getType(); | |||
2135 | ||||
2136 | SmallString<32> specialNameBuffer; | |||
2137 | llvm::raw_svector_ostream specialName(specialNameBuffer); | |||
2138 | specialName << "cst"; | |||
2139 | ||||
2140 | IntegerType intTy = type.dyn_cast<IntegerType>(); | |||
2141 | ||||
2142 | if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) { | |||
2143 | if (intTy && intTy.getWidth() == 1) { | |||
2144 | return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); | |||
2145 | } | |||
2146 | ||||
2147 | if (intTy.isSignless()) { | |||
2148 | specialName << intCst.getInt(); | |||
2149 | } else { | |||
2150 | specialName << intCst.getSInt(); | |||
2151 | } | |||
2152 | } | |||
2153 | ||||
2154 | if (intTy || type.isa<FloatType>()) { | |||
2155 | specialName << '_' << type; | |||
2156 | } | |||
2157 | ||||
2158 | if (auto vecType = type.dyn_cast<VectorType>()) { | |||
2159 | specialName << "_vec_"; | |||
2160 | specialName << vecType.getDimSize(0); | |||
2161 | ||||
2162 | Type elementType = vecType.getElementType(); | |||
2163 | ||||
2164 | if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) { | |||
2165 | specialName << "x" << elementType; | |||
2166 | } | |||
2167 | } | |||
2168 | ||||
2169 | setNameFn(getResult(), specialName.str()); | |||
2170 | } | |||
2171 | ||||
2172 | void mlir::spirv::AddressOfOp::getAsmResultNames( | |||
2173 | llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { | |||
2174 | SmallString<32> specialNameBuffer; | |||
2175 | llvm::raw_svector_ostream specialName(specialNameBuffer); | |||
2176 | specialName << getVariable() << "_addr"; | |||
2177 | setNameFn(getResult(), specialName.str()); | |||
2178 | } | |||
2179 | ||||
2180 | //===----------------------------------------------------------------------===// | |||
2181 | // spirv.ControlBarrierOp | |||
2182 | //===----------------------------------------------------------------------===// | |||
2183 | ||||
2184 | LogicalResult spirv::ControlBarrierOp::verify() { | |||
2185 | return verifyMemorySemantics(getOperation(), getMemorySemantics()); | |||
2186 | } | |||
2187 | ||||
2188 | //===----------------------------------------------------------------------===// | |||
2189 | // spirv.ConvertFToSOp | |||
2190 | //===----------------------------------------------------------------------===// | |||
2191 | ||||
2192 | LogicalResult spirv::ConvertFToSOp::verify() { | |||
2193 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2194 | /*skipBitWidthCheck=*/true); | |||
2195 | } | |||
2196 | ||||
2197 | //===----------------------------------------------------------------------===// | |||
2198 | // spirv.ConvertFToUOp | |||
2199 | //===----------------------------------------------------------------------===// | |||
2200 | ||||
2201 | LogicalResult spirv::ConvertFToUOp::verify() { | |||
2202 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2203 | /*skipBitWidthCheck=*/true); | |||
2204 | } | |||
2205 | ||||
2206 | //===----------------------------------------------------------------------===// | |||
2207 | // spirv.ConvertSToFOp | |||
2208 | //===----------------------------------------------------------------------===// | |||
2209 | ||||
2210 | LogicalResult spirv::ConvertSToFOp::verify() { | |||
2211 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2212 | /*skipBitWidthCheck=*/true); | |||
2213 | } | |||
2214 | ||||
2215 | //===----------------------------------------------------------------------===// | |||
2216 | // spirv.ConvertUToFOp | |||
2217 | //===----------------------------------------------------------------------===// | |||
2218 | ||||
2219 | LogicalResult spirv::ConvertUToFOp::verify() { | |||
2220 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2221 | /*skipBitWidthCheck=*/true); | |||
2222 | } | |||
2223 | ||||
2224 | //===----------------------------------------------------------------------===// | |||
2225 | // spirv.EntryPoint | |||
2226 | //===----------------------------------------------------------------------===// | |||
2227 | ||||
2228 | void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, | |||
2229 | spirv::ExecutionModel executionModel, | |||
2230 | spirv::FuncOp function, | |||
2231 | ArrayRef<Attribute> interfaceVars) { | |||
2232 | build(builder, state, | |||
2233 | spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), | |||
2234 | SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); | |||
2235 | } | |||
2236 | ||||
2237 | ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, | |||
2238 | OperationState &result) { | |||
2239 | spirv::ExecutionModel execModel; | |||
2240 | SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers; | |||
2241 | SmallVector<Type, 0> idTypes; | |||
2242 | SmallVector<Attribute, 4> interfaceVars; | |||
2243 | ||||
2244 | FlatSymbolRefAttr fn; | |||
2245 | if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) || | |||
2246 | parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { | |||
2247 | return failure(); | |||
2248 | } | |||
2249 | ||||
2250 | if (!parser.parseOptionalComma()) { | |||
2251 | // Parse the interface variables | |||
2252 | if (parser.parseCommaSeparatedList([&]() -> ParseResult { | |||
2253 | // The name of the interface variable attribute isnt important | |||
2254 | FlatSymbolRefAttr var; | |||
2255 | NamedAttrList attrs; | |||
2256 | if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) | |||
2257 | return failure(); | |||
2258 | interfaceVars.push_back(var); | |||
2259 | return success(); | |||
2260 | })) | |||
2261 | return failure(); | |||
2262 | } | |||
2263 | result.addAttribute(kInterfaceAttrName, | |||
2264 | parser.getBuilder().getArrayAttr(interfaceVars)); | |||
2265 | return success(); | |||
2266 | } | |||
2267 | ||||
2268 | void spirv::EntryPointOp::print(OpAsmPrinter &printer) { | |||
2269 | printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; | |||
2270 | printer.printSymbolName(getFn()); | |||
2271 | auto interfaceVars = getInterface().getValue(); | |||
2272 | if (!interfaceVars.empty()) { | |||
2273 | printer << ", "; | |||
2274 | llvm::interleaveComma(interfaceVars, printer); | |||
2275 | } | |||
2276 | } | |||
2277 | ||||
2278 | LogicalResult spirv::EntryPointOp::verify() { | |||
2279 | // Checks for fn and interface symbol reference are done in spirv::ModuleOp | |||
2280 | // verification. | |||
2281 | return success(); | |||
2282 | } | |||
2283 | ||||
2284 | //===----------------------------------------------------------------------===// | |||
2285 | // spirv.ExecutionMode | |||
2286 | //===----------------------------------------------------------------------===// | |||
2287 | ||||
2288 | void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, | |||
2289 | spirv::FuncOp function, | |||
2290 | spirv::ExecutionMode executionMode, | |||
2291 | ArrayRef<int32_t> params) { | |||
2292 | build(builder, state, SymbolRefAttr::get(function), | |||
2293 | spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), | |||
2294 | builder.getI32ArrayAttr(params)); | |||
2295 | } | |||
2296 | ||||
2297 | ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, | |||
2298 | OperationState &result) { | |||
2299 | spirv::ExecutionMode execMode; | |||
2300 | Attribute fn; | |||
2301 | if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) || | |||
2302 | parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) { | |||
2303 | return failure(); | |||
2304 | } | |||
2305 | ||||
2306 | SmallVector<int32_t, 4> values; | |||
2307 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
2308 | while (!parser.parseOptionalComma()) { | |||
2309 | NamedAttrList attr; | |||
2310 | Attribute value; | |||
2311 | if (parser.parseAttribute(value, i32Type, "value", attr)) { | |||
2312 | return failure(); | |||
2313 | } | |||
2314 | values.push_back(value.cast<IntegerAttr>().getInt()); | |||
2315 | } | |||
2316 | result.addAttribute(kValuesAttrName, | |||
2317 | parser.getBuilder().getI32ArrayAttr(values)); | |||
2318 | return success(); | |||
2319 | } | |||
2320 | ||||
2321 | void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { | |||
2322 | printer << " "; | |||
2323 | printer.printSymbolName(getFn()); | |||
2324 | printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; | |||
2325 | auto values = this->getValues(); | |||
2326 | if (values.empty()) | |||
2327 | return; | |||
2328 | printer << ", "; | |||
2329 | llvm::interleaveComma(values, printer, [&](Attribute a) { | |||
2330 | printer << a.cast<IntegerAttr>().getInt(); | |||
2331 | }); | |||
2332 | } | |||
2333 | ||||
2334 | //===----------------------------------------------------------------------===// | |||
2335 | // spirv.FConvertOp | |||
2336 | //===----------------------------------------------------------------------===// | |||
2337 | ||||
2338 | LogicalResult spirv::FConvertOp::verify() { | |||
2339 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2340 | } | |||
2341 | ||||
2342 | //===----------------------------------------------------------------------===// | |||
2343 | // spirv.SConvertOp | |||
2344 | //===----------------------------------------------------------------------===// | |||
2345 | ||||
2346 | LogicalResult spirv::SConvertOp::verify() { | |||
2347 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2348 | } | |||
2349 | ||||
2350 | //===----------------------------------------------------------------------===// | |||
2351 | // spirv.UConvertOp | |||
2352 | //===----------------------------------------------------------------------===// | |||
2353 | ||||
2354 | LogicalResult spirv::UConvertOp::verify() { | |||
2355 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2356 | } | |||
2357 | ||||
2358 | //===----------------------------------------------------------------------===// | |||
2359 | // spirv.func | |||
2360 | //===----------------------------------------------------------------------===// | |||
2361 | ||||
2362 | ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { | |||
2363 | SmallVector<OpAsmParser::Argument> entryArgs; | |||
2364 | SmallVector<DictionaryAttr> resultAttrs; | |||
2365 | SmallVector<Type> resultTypes; | |||
2366 | auto &builder = parser.getBuilder(); | |||
2367 | ||||
2368 | // Parse the name as a symbol. | |||
2369 | StringAttr nameAttr; | |||
2370 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
2371 | result.attributes)) | |||
2372 | return failure(); | |||
2373 | ||||
2374 | // Parse the function signature. | |||
2375 | bool isVariadic = false; | |||
2376 | if (function_interface_impl::parseFunctionSignature( | |||
2377 | parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, | |||
2378 | resultAttrs)) | |||
2379 | return failure(); | |||
2380 | ||||
2381 | SmallVector<Type> argTypes; | |||
2382 | for (auto &arg : entryArgs) | |||
2383 | argTypes.push_back(arg.type); | |||
2384 | auto fnType = builder.getFunctionType(argTypes, resultTypes); | |||
2385 | result.addAttribute(getFunctionTypeAttrName(result.name), | |||
2386 | TypeAttr::get(fnType)); | |||
2387 | ||||
2388 | // Parse the optional function control keyword. | |||
2389 | spirv::FunctionControl fnControl; | |||
2390 | if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result)) | |||
2391 | return failure(); | |||
2392 | ||||
2393 | // If additional attributes are present, parse them. | |||
2394 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) | |||
2395 | return failure(); | |||
2396 | ||||
2397 | // Add the attributes to the function arguments. | |||
2398 | assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes. size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2398, __extension__ __PRETTY_FUNCTION__)); | |||
2399 | function_interface_impl::addArgAndResultAttrs( | |||
2400 | builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), | |||
2401 | getResAttrsAttrName(result.name)); | |||
2402 | ||||
2403 | // Parse the optional function body. | |||
2404 | auto *body = result.addRegion(); | |||
2405 | OptionalParseResult parseResult = | |||
2406 | parser.parseOptionalRegion(*body, entryArgs); | |||
2407 | return failure(parseResult.has_value() && failed(*parseResult)); | |||
2408 | } | |||
2409 | ||||
2410 | void spirv::FuncOp::print(OpAsmPrinter &printer) { | |||
2411 | // Print function name, signature, and control. | |||
2412 | printer << " "; | |||
2413 | printer.printSymbolName(getSymName()); | |||
2414 | auto fnType = getFunctionType(); | |||
2415 | function_interface_impl::printFunctionSignature( | |||
2416 | printer, *this, fnType.getInputs(), | |||
2417 | /*isVariadic=*/false, fnType.getResults()); | |||
2418 | printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) | |||
2419 | << "\""; | |||
2420 | function_interface_impl::printFunctionAttributes( | |||
2421 | printer, *this, | |||
2422 | {spirv::attributeName<spirv::FunctionControl>(), | |||
2423 | getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), | |||
2424 | getFunctionControlAttrName()}); | |||
2425 | ||||
2426 | // Print the body if this is not an external function. | |||
2427 | Region &body = this->getBody(); | |||
2428 | if (!body.empty()) { | |||
2429 | printer << ' '; | |||
2430 | printer.printRegion(body, /*printEntryBlockArgs=*/false, | |||
2431 | /*printBlockTerminators=*/true); | |||
2432 | } | |||
2433 | } | |||
2434 | ||||
2435 | LogicalResult spirv::FuncOp::verifyType() { | |||
2436 | if (getFunctionType().getNumResults() > 1) | |||
2437 | return emitOpError("cannot have more than one result"); | |||
2438 | return success(); | |||
2439 | } | |||
2440 | ||||
2441 | LogicalResult spirv::FuncOp::verifyBody() { | |||
2442 | FunctionType fnType = getFunctionType(); | |||
2443 | ||||
2444 | auto walkResult = walk([fnType](Operation *op) -> WalkResult { | |||
2445 | if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) { | |||
2446 | if (fnType.getNumResults() != 0) | |||
2447 | return retOp.emitOpError("cannot be used in functions returning value"); | |||
2448 | } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) { | |||
2449 | if (fnType.getNumResults() != 1) | |||
2450 | return retOp.emitOpError( | |||
2451 | "returns 1 value but enclosing function requires ") | |||
2452 | << fnType.getNumResults() << " results"; | |||
2453 | ||||
2454 | auto retOperandType = retOp.getValue().getType(); | |||
2455 | auto fnResultType = fnType.getResult(0); | |||
2456 | if (retOperandType != fnResultType) | |||
2457 | return retOp.emitOpError(" return value's type (") | |||
2458 | << retOperandType << ") mismatch with function's result type (" | |||
2459 | << fnResultType << ")"; | |||
2460 | } | |||
2461 | return WalkResult::advance(); | |||
2462 | }); | |||
2463 | ||||
2464 | // TODO: verify other bits like linkage type. | |||
2465 | ||||
2466 | return failure(walkResult.wasInterrupted()); | |||
2467 | } | |||
2468 | ||||
2469 | void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, | |||
2470 | StringRef name, FunctionType type, | |||
2471 | spirv::FunctionControl control, | |||
2472 | ArrayRef<NamedAttribute> attrs) { | |||
2473 | state.addAttribute(SymbolTable::getSymbolAttrName(), | |||
2474 | builder.getStringAttr(name)); | |||
2475 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); | |||
2476 | state.addAttribute(spirv::attributeName<spirv::FunctionControl>(), | |||
2477 | builder.getAttr<spirv::FunctionControlAttr>(control)); | |||
2478 | state.attributes.append(attrs.begin(), attrs.end()); | |||
2479 | state.addRegion(); | |||
2480 | } | |||
2481 | ||||
2482 | // CallableOpInterface | |||
2483 | Region *spirv::FuncOp::getCallableRegion() { | |||
2484 | return isExternal() ? nullptr : &getBody(); | |||
2485 | } | |||
2486 | ||||
2487 | // CallableOpInterface | |||
2488 | ArrayRef<Type> spirv::FuncOp::getCallableResults() { | |||
2489 | return getFunctionType().getResults(); | |||
2490 | } | |||
2491 | ||||
2492 | //===----------------------------------------------------------------------===// | |||
2493 | // spirv.FunctionCall | |||
2494 | //===----------------------------------------------------------------------===// | |||
2495 | ||||
2496 | LogicalResult spirv::FunctionCallOp::verify() { | |||
2497 | auto fnName = getCalleeAttr(); | |||
2498 | ||||
2499 | auto funcOp = dyn_cast_or_null<spirv::FuncOp>( | |||
2500 | SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); | |||
2501 | if (!funcOp) { | |||
2502 | return emitOpError("callee function '") | |||
2503 | << fnName.getValue() << "' not found in nearest symbol table"; | |||
2504 | } | |||
2505 | ||||
2506 | auto functionType = funcOp.getFunctionType(); | |||
2507 | ||||
2508 | if (getNumResults() > 1) { | |||
2509 | return emitOpError( | |||
2510 | "expected callee function to have 0 or 1 result, but provided ") | |||
2511 | << getNumResults(); | |||
2512 | } | |||
2513 | ||||
2514 | if (functionType.getNumInputs() != getNumOperands()) { | |||
2515 | return emitOpError("has incorrect number of operands for callee: expected ") | |||
2516 | << functionType.getNumInputs() << ", but provided " | |||
2517 | << getNumOperands(); | |||
2518 | } | |||
2519 | ||||
2520 | for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { | |||
2521 | if (getOperand(i).getType() != functionType.getInput(i)) { | |||
2522 | return emitOpError("operand type mismatch: expected operand type ") | |||
2523 | << functionType.getInput(i) << ", but provided " | |||
2524 | << getOperand(i).getType() << " for operand number " << i; | |||
2525 | } | |||
2526 | } | |||
2527 | ||||
2528 | if (functionType.getNumResults() != getNumResults()) { | |||
2529 | return emitOpError( | |||
2530 | "has incorrect number of results has for callee: expected ") | |||
2531 | << functionType.getNumResults() << ", but provided " | |||
2532 | << getNumResults(); | |||
2533 | } | |||
2534 | ||||
2535 | if (getNumResults() && | |||
2536 | (getResult(0).getType() != functionType.getResult(0))) { | |||
2537 | return emitOpError("result type mismatch: expected ") | |||
2538 | << functionType.getResult(0) << ", but provided " | |||
2539 | << getResult(0).getType(); | |||
2540 | } | |||
2541 | ||||
2542 | return success(); | |||
2543 | } | |||
2544 | ||||
2545 | CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() { | |||
2546 | return (*this)->getAttrOfType<SymbolRefAttr>(kCallee); | |||
2547 | } | |||
2548 | ||||
2549 | Operation::operand_range spirv::FunctionCallOp::getArgOperands() { | |||
2550 | return getArguments(); | |||
2551 | } | |||
2552 | ||||
2553 | //===----------------------------------------------------------------------===// | |||
2554 | // spirv.GLFClampOp | |||
2555 | //===----------------------------------------------------------------------===// | |||
2556 | ||||
2557 | ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser, | |||
2558 | OperationState &result) { | |||
2559 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2560 | } | |||
2561 | void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2562 | ||||
2563 | //===----------------------------------------------------------------------===// | |||
2564 | // spirv.GLUClampOp | |||
2565 | //===----------------------------------------------------------------------===// | |||
2566 | ||||
2567 | ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser, | |||
2568 | OperationState &result) { | |||
2569 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2570 | } | |||
2571 | void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2572 | ||||
2573 | //===----------------------------------------------------------------------===// | |||
2574 | // spirv.GLSClampOp | |||
2575 | //===----------------------------------------------------------------------===// | |||
2576 | ||||
2577 | ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser, | |||
2578 | OperationState &result) { | |||
2579 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2580 | } | |||
2581 | void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2582 | ||||
2583 | //===----------------------------------------------------------------------===// | |||
2584 | // spirv.GLFmaOp | |||
2585 | //===----------------------------------------------------------------------===// | |||
2586 | ||||
2587 | ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) { | |||
2588 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2589 | } | |||
2590 | void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2591 | ||||
2592 | //===----------------------------------------------------------------------===// | |||
2593 | // spirv.GlobalVariable | |||
2594 | //===----------------------------------------------------------------------===// | |||
2595 | ||||
2596 | void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, | |||
2597 | Type type, StringRef name, | |||
2598 | unsigned descriptorSet, unsigned binding) { | |||
2599 | build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); | |||
2600 | state.addAttribute( | |||
2601 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), | |||
2602 | builder.getI32IntegerAttr(descriptorSet)); | |||
2603 | state.addAttribute( | |||
2604 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), | |||
2605 | builder.getI32IntegerAttr(binding)); | |||
2606 | } | |||
2607 | ||||
2608 | void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, | |||
2609 | Type type, StringRef name, | |||
2610 | spirv::BuiltIn builtin) { | |||
2611 | build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); | |||
2612 | state.addAttribute( | |||
2613 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), | |||
2614 | builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); | |||
2615 | } | |||
2616 | ||||
2617 | ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, | |||
2618 | OperationState &result) { | |||
2619 | // Parse variable name. | |||
2620 | StringAttr nameAttr; | |||
2621 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
2622 | result.attributes)) { | |||
2623 | return failure(); | |||
2624 | } | |||
2625 | ||||
2626 | // Parse optional initializer | |||
2627 | if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { | |||
2628 | FlatSymbolRefAttr initSymbol; | |||
2629 | if (parser.parseLParen() || | |||
2630 | parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, | |||
2631 | result.attributes) || | |||
2632 | parser.parseRParen()) | |||
2633 | return failure(); | |||
2634 | } | |||
2635 | ||||
2636 | if (parseVariableDecorations(parser, result)) { | |||
2637 | return failure(); | |||
2638 | } | |||
2639 | ||||
2640 | Type type; | |||
2641 | auto loc = parser.getCurrentLocation(); | |||
2642 | if (parser.parseColonType(type)) { | |||
2643 | return failure(); | |||
2644 | } | |||
2645 | if (!type.isa<spirv::PointerType>()) { | |||
2646 | return parser.emitError(loc, "expected spirv.ptr type"); | |||
2647 | } | |||
2648 | result.addAttribute(kTypeAttrName, TypeAttr::get(type)); | |||
2649 | ||||
2650 | return success(); | |||
2651 | } | |||
2652 | ||||
2653 | void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { | |||
2654 | SmallVector<StringRef, 4> elidedAttrs{ | |||
2655 | spirv::attributeName<spirv::StorageClass>()}; | |||
2656 | ||||
2657 | // Print variable name. | |||
2658 | printer << ' '; | |||
2659 | printer.printSymbolName(getSymName()); | |||
2660 | elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); | |||
2661 | ||||
2662 | // Print optional initializer | |||
2663 | if (auto initializer = this->getInitializer()) { | |||
2664 | printer << " " << kInitializerAttrName << '('; | |||
2665 | printer.printSymbolName(*initializer); | |||
2666 | printer << ')'; | |||
2667 | elidedAttrs.push_back(kInitializerAttrName); | |||
2668 | } | |||
2669 | ||||
2670 | elidedAttrs.push_back(kTypeAttrName); | |||
2671 | printVariableDecorations(*this, printer, elidedAttrs); | |||
2672 | printer << " : " << getType(); | |||
2673 | } | |||
2674 | ||||
2675 | LogicalResult spirv::GlobalVariableOp::verify() { | |||
2676 | if (!getType().isa<spirv::PointerType>()) | |||
2677 | return emitOpError("result must be of a !spv.ptr type"); | |||
2678 | ||||
2679 | // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the | |||
2680 | // object. It cannot be Generic. It must be the same as the Storage Class | |||
2681 | // operand of the Result Type." | |||
2682 | // Also, Function storage class is reserved by spirv.Variable. | |||
2683 | auto storageClass = this->storageClass(); | |||
2684 | if (storageClass == spirv::StorageClass::Generic || | |||
2685 | storageClass == spirv::StorageClass::Function) { | |||
2686 | return emitOpError("storage class cannot be '") | |||
2687 | << stringifyStorageClass(storageClass) << "'"; | |||
2688 | } | |||
2689 | ||||
2690 | if (auto init = | |||
2691 | (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) { | |||
2692 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( | |||
2693 | (*this)->getParentOp(), init.getAttr()); | |||
2694 | // TODO: Currently only variable initialization with specialization | |||
2695 | // constants and other variables is supported. They could be normal | |||
2696 | // constants in the module scope as well. | |||
2697 | if (!initOp || | |||
2698 | !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) { | |||
2699 | return emitOpError("initializer must be result of a " | |||
2700 | "spirv.SpecConstant or spirv.GlobalVariable op"); | |||
2701 | } | |||
2702 | } | |||
2703 | ||||
2704 | return success(); | |||
2705 | } | |||
2706 | ||||
2707 | //===----------------------------------------------------------------------===// | |||
2708 | // spirv.GroupBroadcast | |||
2709 | //===----------------------------------------------------------------------===// | |||
2710 | ||||
2711 | LogicalResult spirv::GroupBroadcastOp::verify() { | |||
2712 | spirv::Scope scope = getExecutionScope(); | |||
2713 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2714 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2715 | ||||
2716 | if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>()) | |||
2717 | if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) | |||
2718 | return emitOpError("localid is a vector and can be with only " | |||
2719 | " 2 or 3 components, actual number is ") | |||
2720 | << localIdTy.getNumElements(); | |||
2721 | ||||
2722 | return success(); | |||
2723 | } | |||
2724 | ||||
2725 | //===----------------------------------------------------------------------===// | |||
2726 | // spirv.GroupNonUniformBallotOp | |||
2727 | //===----------------------------------------------------------------------===// | |||
2728 | ||||
2729 | LogicalResult spirv::GroupNonUniformBallotOp::verify() { | |||
2730 | spirv::Scope scope = getExecutionScope(); | |||
2731 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2732 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2733 | ||||
2734 | return success(); | |||
2735 | } | |||
2736 | ||||
2737 | //===----------------------------------------------------------------------===// | |||
2738 | // spirv.GroupNonUniformBroadcast | |||
2739 | //===----------------------------------------------------------------------===// | |||
2740 | ||||
2741 | LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { | |||
2742 | spirv::Scope scope = getExecutionScope(); | |||
2743 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2744 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2745 | ||||
2746 | // SPIR-V spec: "Before version 1.5, Id must come from a | |||
2747 | // constant instruction. | |||
2748 | auto targetEnv = spirv::getDefaultTargetEnv(getContext()); | |||
2749 | if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>()) | |||
2750 | targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); | |||
2751 | ||||
2752 | if (targetEnv.getVersion() < spirv::Version::V_1_5) { | |||
2753 | auto *idOp = getId().getDefiningOp(); | |||
2754 | if (!idOp || !isa<spirv::ConstantOp, // for normal constant | |||
2755 | spirv::ReferenceOfOp>(idOp)) // for spec constant | |||
2756 | return emitOpError("id must be the result of a constant op"); | |||
2757 | } | |||
2758 | ||||
2759 | return success(); | |||
2760 | } | |||
2761 | ||||
2762 | //===----------------------------------------------------------------------===// | |||
2763 | // spirv.GroupNonUniformShuffle* | |||
2764 | //===----------------------------------------------------------------------===// | |||
2765 | ||||
2766 | template <typename OpTy> | |||
2767 | static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { | |||
2768 | spirv::Scope scope = op.getExecutionScope(); | |||
2769 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2770 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2771 | ||||
2772 | if (op.getOperands().back().getType().isSignedInteger()) | |||
2773 | return op.emitOpError("second operand must be a singless/unsigned integer"); | |||
2774 | ||||
2775 | return success(); | |||
2776 | } | |||
2777 | ||||
2778 | LogicalResult spirv::GroupNonUniformShuffleOp::verify() { | |||
2779 | return verifyGroupNonUniformShuffleOp(*this); | |||
2780 | } | |||
2781 | LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() { | |||
2782 | return verifyGroupNonUniformShuffleOp(*this); | |||
2783 | } | |||
2784 | LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() { | |||
2785 | return verifyGroupNonUniformShuffleOp(*this); | |||
2786 | } | |||
2787 | LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() { | |||
2788 | return verifyGroupNonUniformShuffleOp(*this); | |||
2789 | } | |||
2790 | ||||
2791 | //===----------------------------------------------------------------------===// | |||
2792 | // spirv.INTEL.SubgroupBlockRead | |||
2793 | //===----------------------------------------------------------------------===// | |||
2794 | ||||
2795 | ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser, | |||
2796 | OperationState &result) { | |||
2797 | // Parse the storage class specification | |||
2798 | spirv::StorageClass storageClass; | |||
2799 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
2800 | Type elementType; | |||
2801 | if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || | |||
2802 | parser.parseColon() || parser.parseType(elementType)) { | |||
2803 | return failure(); | |||
2804 | } | |||
2805 | ||||
2806 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
2807 | if (auto valVecTy = elementType.dyn_cast<VectorType>()) | |||
2808 | ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); | |||
2809 | ||||
2810 | if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { | |||
2811 | return failure(); | |||
2812 | } | |||
2813 | ||||
2814 | result.addTypes(elementType); | |||
2815 | return success(); | |||
2816 | } | |||
2817 | ||||
2818 | void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) { | |||
2819 | printer << " " << getPtr() << " : " << getType(); | |||
2820 | } | |||
2821 | ||||
2822 | LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { | |||
2823 | if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) | |||
2824 | return failure(); | |||
2825 | ||||
2826 | return success(); | |||
2827 | } | |||
2828 | ||||
2829 | //===----------------------------------------------------------------------===// | |||
2830 | // spirv.INTEL.SubgroupBlockWrite | |||
2831 | //===----------------------------------------------------------------------===// | |||
2832 | ||||
2833 | ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, | |||
2834 | OperationState &result) { | |||
2835 | // Parse the storage class specification | |||
2836 | spirv::StorageClass storageClass; | |||
2837 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
2838 | auto loc = parser.getCurrentLocation(); | |||
2839 | Type elementType; | |||
2840 | if (parseEnumStrAttr(storageClass, parser) || | |||
2841 | parser.parseOperandList(operandInfo, 2) || parser.parseColon() || | |||
2842 | parser.parseType(elementType)) { | |||
2843 | return failure(); | |||
2844 | } | |||
2845 | ||||
2846 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
2847 | if (auto valVecTy = elementType.dyn_cast<VectorType>()) | |||
2848 | ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); | |||
2849 | ||||
2850 | if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, | |||
2851 | result.operands)) { | |||
2852 | return failure(); | |||
2853 | } | |||
2854 | return success(); | |||
2855 | } | |||
2856 | ||||
2857 | void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { | |||
2858 | printer << " " << getPtr() << ", " << getValue() << " : " | |||
2859 | << getValue().getType(); | |||
2860 | } | |||
2861 | ||||
2862 | LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { | |||
2863 | if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) | |||
2864 | return failure(); | |||
2865 | ||||
2866 | return success(); | |||
2867 | } | |||
2868 | ||||
2869 | //===----------------------------------------------------------------------===// | |||
2870 | // spirv.GroupNonUniformElectOp | |||
2871 | //===----------------------------------------------------------------------===// | |||
2872 | ||||
2873 | LogicalResult spirv::GroupNonUniformElectOp::verify() { | |||
2874 | spirv::Scope scope = getExecutionScope(); | |||
2875 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2876 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2877 | ||||
2878 | return success(); | |||
2879 | } | |||
2880 | ||||
2881 | //===----------------------------------------------------------------------===// | |||
2882 | // spirv.GroupNonUniformFAddOp | |||
2883 | //===----------------------------------------------------------------------===// | |||
2884 | ||||
2885 | LogicalResult spirv::GroupNonUniformFAddOp::verify() { | |||
2886 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2887 | } | |||
2888 | ||||
2889 | ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser, | |||
2890 | OperationState &result) { | |||
2891 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2892 | } | |||
2893 | void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) { | |||
2894 | printGroupNonUniformArithmeticOp(*this, p); | |||
2895 | } | |||
2896 | ||||
2897 | //===----------------------------------------------------------------------===// | |||
2898 | // spirv.GroupNonUniformFMaxOp | |||
2899 | //===----------------------------------------------------------------------===// | |||
2900 | ||||
2901 | LogicalResult spirv::GroupNonUniformFMaxOp::verify() { | |||
2902 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2903 | } | |||
2904 | ||||
2905 | ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser, | |||
2906 | OperationState &result) { | |||
2907 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2908 | } | |||
2909 | void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) { | |||
2910 | printGroupNonUniformArithmeticOp(*this, p); | |||
2911 | } | |||
2912 | ||||
2913 | //===----------------------------------------------------------------------===// | |||
2914 | // spirv.GroupNonUniformFMinOp | |||
2915 | //===----------------------------------------------------------------------===// | |||
2916 | ||||
2917 | LogicalResult spirv::GroupNonUniformFMinOp::verify() { | |||
2918 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2919 | } | |||
2920 | ||||
2921 | ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser, | |||
2922 | OperationState &result) { | |||
2923 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2924 | } | |||
2925 | void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) { | |||
2926 | printGroupNonUniformArithmeticOp(*this, p); | |||
2927 | } | |||
2928 | ||||
2929 | //===----------------------------------------------------------------------===// | |||
2930 | // spirv.GroupNonUniformFMulOp | |||
2931 | //===----------------------------------------------------------------------===// | |||
2932 | ||||
2933 | LogicalResult spirv::GroupNonUniformFMulOp::verify() { | |||
2934 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2935 | } | |||
2936 | ||||
2937 | ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser, | |||
2938 | OperationState &result) { | |||
2939 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2940 | } | |||
2941 | void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) { | |||
2942 | printGroupNonUniformArithmeticOp(*this, p); | |||
2943 | } | |||
2944 | ||||
2945 | //===----------------------------------------------------------------------===// | |||
2946 | // spirv.GroupNonUniformIAddOp | |||
2947 | //===----------------------------------------------------------------------===// | |||
2948 | ||||
2949 | LogicalResult spirv::GroupNonUniformIAddOp::verify() { | |||
2950 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2951 | } | |||
2952 | ||||
2953 | ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser, | |||
2954 | OperationState &result) { | |||
2955 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2956 | } | |||
2957 | void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) { | |||
2958 | printGroupNonUniformArithmeticOp(*this, p); | |||
2959 | } | |||
2960 | ||||
2961 | //===----------------------------------------------------------------------===// | |||
2962 | // spirv.GroupNonUniformIMulOp | |||
2963 | //===----------------------------------------------------------------------===// | |||
2964 | ||||
2965 | LogicalResult spirv::GroupNonUniformIMulOp::verify() { | |||
2966 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2967 | } | |||
2968 | ||||
2969 | ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser, | |||
2970 | OperationState &result) { | |||
2971 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2972 | } | |||
2973 | void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) { | |||
2974 | printGroupNonUniformArithmeticOp(*this, p); | |||
2975 | } | |||
2976 | ||||
2977 | //===----------------------------------------------------------------------===// | |||
2978 | // spirv.GroupNonUniformSMaxOp | |||
2979 | //===----------------------------------------------------------------------===// | |||
2980 | ||||
2981 | LogicalResult spirv::GroupNonUniformSMaxOp::verify() { | |||
2982 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2983 | } | |||
2984 | ||||
2985 | ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser, | |||
2986 | OperationState &result) { | |||
2987 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2988 | } | |||
2989 | void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) { | |||
2990 | printGroupNonUniformArithmeticOp(*this, p); | |||
2991 | } | |||
2992 | ||||
2993 | //===----------------------------------------------------------------------===// | |||
2994 | // spirv.GroupNonUniformSMinOp | |||
2995 | //===----------------------------------------------------------------------===// | |||
2996 | ||||
2997 | LogicalResult spirv::GroupNonUniformSMinOp::verify() { | |||
2998 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2999 | } | |||
3000 | ||||
3001 | ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser, | |||
3002 | OperationState &result) { | |||
3003 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3004 | } | |||
3005 | void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) { | |||
3006 | printGroupNonUniformArithmeticOp(*this, p); | |||
3007 | } | |||
3008 | ||||
3009 | //===----------------------------------------------------------------------===// | |||
3010 | // spirv.GroupNonUniformUMaxOp | |||
3011 | //===----------------------------------------------------------------------===// | |||
3012 | ||||
3013 | LogicalResult spirv::GroupNonUniformUMaxOp::verify() { | |||
3014 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3015 | } | |||
3016 | ||||
3017 | ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser, | |||
3018 | OperationState &result) { | |||
3019 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3020 | } | |||
3021 | void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { | |||
3022 | printGroupNonUniformArithmeticOp(*this, p); | |||
3023 | } | |||
3024 | ||||
3025 | //===----------------------------------------------------------------------===// | |||
3026 | // spirv.GroupNonUniformUMinOp | |||
3027 | //===----------------------------------------------------------------------===// | |||
3028 | ||||
3029 | LogicalResult spirv::GroupNonUniformUMinOp::verify() { | |||
3030 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3031 | } | |||
3032 | ||||
3033 | ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser, | |||
3034 | OperationState &result) { | |||
3035 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3036 | } | |||
3037 | void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { | |||
3038 | printGroupNonUniformArithmeticOp(*this, p); | |||
3039 | } | |||
3040 | ||||
3041 | //===----------------------------------------------------------------------===// | |||
3042 | // spirv.IAddCarryOp | |||
3043 | //===----------------------------------------------------------------------===// | |||
3044 | ||||
3045 | LogicalResult spirv::IAddCarryOp::verify() { | |||
3046 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3047 | } | |||
3048 | ||||
3049 | ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser, | |||
3050 | OperationState &result) { | |||
3051 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3052 | } | |||
3053 | ||||
3054 | void spirv::IAddCarryOp::print(OpAsmPrinter &printer) { | |||
3055 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3056 | } | |||
3057 | ||||
3058 | //===----------------------------------------------------------------------===// | |||
3059 | // spirv.ISubBorrowOp | |||
3060 | //===----------------------------------------------------------------------===// | |||
3061 | ||||
3062 | LogicalResult spirv::ISubBorrowOp::verify() { | |||
3063 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3064 | } | |||
3065 | ||||
3066 | ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, | |||
3067 | OperationState &result) { | |||
3068 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3069 | } | |||
3070 | ||||
3071 | void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { | |||
3072 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3073 | } | |||
3074 | ||||
3075 | //===----------------------------------------------------------------------===// | |||
3076 | // spirv.SMulExtended | |||
3077 | //===----------------------------------------------------------------------===// | |||
3078 | ||||
3079 | LogicalResult spirv::SMulExtendedOp::verify() { | |||
3080 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3081 | } | |||
3082 | ||||
3083 | ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser, | |||
3084 | OperationState &result) { | |||
3085 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3086 | } | |||
3087 | ||||
3088 | void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) { | |||
3089 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3090 | } | |||
3091 | ||||
3092 | //===----------------------------------------------------------------------===// | |||
3093 | // spirv.UMulExtended | |||
3094 | //===----------------------------------------------------------------------===// | |||
3095 | ||||
3096 | LogicalResult spirv::UMulExtendedOp::verify() { | |||
3097 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3098 | } | |||
3099 | ||||
3100 | ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, | |||
3101 | OperationState &result) { | |||
3102 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3103 | } | |||
3104 | ||||
3105 | void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { | |||
3106 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3107 | } | |||
3108 | ||||
3109 | //===----------------------------------------------------------------------===// | |||
3110 | // spirv.LoadOp | |||
3111 | //===----------------------------------------------------------------------===// | |||
3112 | ||||
3113 | void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, | |||
3114 | Value basePtr, MemoryAccessAttr memoryAccess, | |||
3115 | IntegerAttr alignment) { | |||
3116 | auto ptrType = basePtr.getType().cast<spirv::PointerType>(); | |||
3117 | build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, | |||
3118 | alignment); | |||
3119 | } | |||
3120 | ||||
3121 | ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3122 | // Parse the storage class specification | |||
3123 | spirv::StorageClass storageClass; | |||
3124 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
3125 | Type elementType; | |||
3126 | if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || | |||
3127 | parseMemoryAccessAttributes(parser, result) || | |||
3128 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || | |||
3129 | parser.parseType(elementType)) { | |||
3130 | return failure(); | |||
3131 | } | |||
3132 | ||||
3133 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
3134 | if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { | |||
3135 | return failure(); | |||
3136 | } | |||
3137 | ||||
3138 | result.addTypes(elementType); | |||
3139 | return success(); | |||
3140 | } | |||
3141 | ||||
3142 | void spirv::LoadOp::print(OpAsmPrinter &printer) { | |||
3143 | SmallVector<StringRef, 4> elidedAttrs; | |||
3144 | StringRef sc = stringifyStorageClass( | |||
3145 | getPtr().getType().cast<spirv::PointerType>().getStorageClass()); | |||
3146 | printer << " \"" << sc << "\" " << getPtr(); | |||
3147 | ||||
3148 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
3149 | ||||
3150 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
3151 | printer << " : " << getType(); | |||
3152 | } | |||
3153 | ||||
3154 | LogicalResult spirv::LoadOp::verify() { | |||
3155 | // SPIR-V spec : "Result Type is the type of the loaded object. It must be a | |||
3156 | // type with fixed size; i.e., it cannot be, nor include, any | |||
3157 | // OpTypeRuntimeArray types." | |||
3158 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { | |||
3159 | return failure(); | |||
3160 | } | |||
3161 | return verifyMemoryAccessAttribute(*this); | |||
3162 | } | |||
3163 | ||||
3164 | //===----------------------------------------------------------------------===// | |||
3165 | // spirv.mlir.loop | |||
3166 | //===----------------------------------------------------------------------===// | |||
3167 | ||||
3168 | void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { | |||
3169 | state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>( | |||
3170 | spirv::LoopControl::None)); | |||
3171 | state.addRegion(); | |||
3172 | } | |||
3173 | ||||
3174 | ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3175 | if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser, | |||
3176 | result)) | |||
3177 | return failure(); | |||
3178 | return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); | |||
3179 | } | |||
3180 | ||||
3181 | void spirv::LoopOp::print(OpAsmPrinter &printer) { | |||
3182 | auto control = getLoopControl(); | |||
3183 | if (control != spirv::LoopControl::None) | |||
3184 | printer << " control(" << spirv::stringifyLoopControl(control) << ")"; | |||
3185 | printer << ' '; | |||
3186 | printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, | |||
3187 | /*printBlockTerminators=*/true); | |||
3188 | } | |||
3189 | ||||
3190 | /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the | |||
3191 | /// given `dstBlock`. | |||
3192 | static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { | |||
3193 | // Check that there is only one op in the `srcBlock`. | |||
3194 | if (!llvm::hasSingleElement(srcBlock)) | |||
3195 | return false; | |||
3196 | ||||
3197 | auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back()); | |||
3198 | return branchOp && branchOp.getSuccessor() == &dstBlock; | |||
3199 | } | |||
3200 | ||||
3201 | LogicalResult spirv::LoopOp::verifyRegions() { | |||
3202 | auto *op = getOperation(); | |||
3203 | ||||
3204 | // We need to verify that the blocks follow the following layout: | |||
3205 | // | |||
3206 | // +-------------+ | |||
3207 | // | entry block | | |||
3208 | // +-------------+ | |||
3209 | // | | |||
3210 | // v | |||
3211 | // +-------------+ | |||
3212 | // | loop header | <-----+ | |||
3213 | // +-------------+ | | |||
3214 | // | | |||
3215 | // ... | | |||
3216 | // \ | / | | |||
3217 | // v | | |||
3218 | // +---------------+ | | |||
3219 | // | loop continue | -----+ | |||
3220 | // +---------------+ | |||
3221 | // | |||
3222 | // ... | |||
3223 | // \ | / | |||
3224 | // v | |||
3225 | // +-------------+ | |||
3226 | // | merge block | | |||
3227 | // +-------------+ | |||
3228 | ||||
3229 | auto ®ion = op->getRegion(0); | |||
3230 | // Allow empty region as a degenerated case, which can come from | |||
3231 | // optimizations. | |||
3232 | if (region.empty()) | |||
3233 | return success(); | |||
3234 | ||||
3235 | // The last block is the merge block. | |||
3236 | Block &merge = region.back(); | |||
3237 | if (!isMergeBlock(merge)) | |||
3238 | return emitOpError("last block must be the merge block with only one " | |||
3239 | "'spirv.mlir.merge' op"); | |||
3240 | ||||
3241 | if (std::next(region.begin()) == region.end()) | |||
3242 | return emitOpError( | |||
3243 | "must have an entry block branching to the loop header block"); | |||
3244 | // The first block is the entry block. | |||
3245 | Block &entry = region.front(); | |||
3246 | ||||
3247 | if (std::next(region.begin(), 2) == region.end()) | |||
3248 | return emitOpError( | |||
3249 | "must have a loop header block branched from the entry block"); | |||
3250 | // The second block is the loop header block. | |||
3251 | Block &header = *std::next(region.begin(), 1); | |||
3252 | ||||
3253 | if (!hasOneBranchOpTo(entry, header)) | |||
3254 | return emitOpError( | |||
3255 | "entry block must only have one 'spirv.Branch' op to the second block"); | |||
3256 | ||||
3257 | if (std::next(region.begin(), 3) == region.end()) | |||
3258 | return emitOpError( | |||
3259 | "requires a loop continue block branching to the loop header block"); | |||
3260 | // The second to last block is the loop continue block. | |||
3261 | Block &cont = *std::prev(region.end(), 2); | |||
3262 | ||||
3263 | // Make sure that we have a branch from the loop continue block to the loop | |||
3264 | // header block. | |||
3265 | if (llvm::none_of( | |||
3266 | llvm::seq<unsigned>(0, cont.getNumSuccessors()), | |||
3267 | [&](unsigned index) { return cont.getSuccessor(index) == &header; })) | |||
3268 | return emitOpError("second to last block must be the loop continue " | |||
3269 | "block that branches to the loop header block"); | |||
3270 | ||||
3271 | // Make sure that no other blocks (except the entry and loop continue block) | |||
3272 | // branches to the loop header block. | |||
3273 | for (auto &block : llvm::make_range(std::next(region.begin(), 2), | |||
3274 | std::prev(region.end(), 2))) { | |||
3275 | for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) { | |||
3276 | if (block.getSuccessor(i) == &header) { | |||
3277 | return emitOpError("can only have the entry and loop continue " | |||
3278 | "block branching to the loop header block"); | |||
3279 | } | |||
3280 | } | |||
3281 | } | |||
3282 | ||||
3283 | return success(); | |||
3284 | } | |||
3285 | ||||
3286 | Block *spirv::LoopOp::getEntryBlock() { | |||
3287 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3287, __extension__ __PRETTY_FUNCTION__)); | |||
3288 | return &getBody().front(); | |||
3289 | } | |||
3290 | ||||
3291 | Block *spirv::LoopOp::getHeaderBlock() { | |||
3292 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3292, __extension__ __PRETTY_FUNCTION__)); | |||
3293 | // The second block is the loop header block. | |||
3294 | return &*std::next(getBody().begin()); | |||
3295 | } | |||
3296 | ||||
3297 | Block *spirv::LoopOp::getContinueBlock() { | |||
3298 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3298, __extension__ __PRETTY_FUNCTION__)); | |||
3299 | // The second to last block is the loop continue block. | |||
3300 | return &*std::prev(getBody().end(), 2); | |||
3301 | } | |||
3302 | ||||
3303 | Block *spirv::LoopOp::getMergeBlock() { | |||
3304 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3304, __extension__ __PRETTY_FUNCTION__)); | |||
3305 | // The last block is the loop merge block. | |||
3306 | return &getBody().back(); | |||
3307 | } | |||
3308 | ||||
3309 | void spirv::LoopOp::addEntryAndMergeBlock() { | |||
3310 | assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist" ) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3310, __extension__ __PRETTY_FUNCTION__)); | |||
3311 | getBody().push_back(new Block()); | |||
3312 | auto *mergeBlock = new Block(); | |||
3313 | getBody().push_back(mergeBlock); | |||
3314 | OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); | |||
3315 | ||||
3316 | // Add a spirv.mlir.merge op into the merge block. | |||
3317 | builder.create<spirv::MergeOp>(getLoc()); | |||
3318 | } | |||
3319 | ||||
3320 | //===----------------------------------------------------------------------===// | |||
3321 | // spirv.MemoryBarrierOp | |||
3322 | //===----------------------------------------------------------------------===// | |||
3323 | ||||
3324 | LogicalResult spirv::MemoryBarrierOp::verify() { | |||
3325 | return verifyMemorySemantics(getOperation(), getMemorySemantics()); | |||
3326 | } | |||
3327 | ||||
3328 | //===----------------------------------------------------------------------===// | |||
3329 | // spirv.mlir.merge | |||
3330 | //===----------------------------------------------------------------------===// | |||
3331 | ||||
3332 | LogicalResult spirv::MergeOp::verify() { | |||
3333 | auto *parentOp = (*this)->getParentOp(); | |||
3334 | if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp)) | |||
3335 | return emitOpError( | |||
3336 | "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); | |||
3337 | ||||
3338 | // TODO: This check should be done in `verifyRegions` of parent op. | |||
3339 | Block &parentLastBlock = (*this)->getParentRegion()->back(); | |||
3340 | if (getOperation() != parentLastBlock.getTerminator()) | |||
3341 | return emitOpError("can only be used in the last block of " | |||
3342 | "'spirv.mlir.selection' or 'spirv.mlir.loop'"); | |||
3343 | return success(); | |||
3344 | } | |||
3345 | ||||
3346 | //===----------------------------------------------------------------------===// | |||
3347 | // spirv.module | |||
3348 | //===----------------------------------------------------------------------===// | |||
3349 | ||||
3350 | void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, | |||
3351 | std::optional<StringRef> name) { | |||
3352 | OpBuilder::InsertionGuard guard(builder); | |||
3353 | builder.createBlock(state.addRegion()); | |||
3354 | if (name) { | |||
3355 | state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), | |||
3356 | builder.getStringAttr(*name)); | |||
3357 | } | |||
3358 | } | |||
3359 | ||||
3360 | void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, | |||
3361 | spirv::AddressingModel addressingModel, | |||
3362 | spirv::MemoryModel memoryModel, | |||
3363 | std::optional<VerCapExtAttr> vceTriple, | |||
3364 | std::optional<StringRef> name) { | |||
3365 | state.addAttribute( | |||
3366 | "addressing_model", | |||
3367 | builder.getAttr<spirv::AddressingModelAttr>(addressingModel)); | |||
3368 | state.addAttribute("memory_model", | |||
3369 | builder.getAttr<spirv::MemoryModelAttr>(memoryModel)); | |||
3370 | OpBuilder::InsertionGuard guard(builder); | |||
3371 | builder.createBlock(state.addRegion()); | |||
3372 | if (vceTriple) | |||
3373 | state.addAttribute(getVCETripleAttrName(), *vceTriple); | |||
3374 | if (name) | |||
3375 | state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), | |||
3376 | builder.getStringAttr(*name)); | |||
3377 | } | |||
3378 | ||||
3379 | ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, | |||
3380 | OperationState &result) { | |||
3381 | Region *body = result.addRegion(); | |||
3382 | ||||
3383 | // If the name is present, parse it. | |||
3384 | StringAttr nameAttr; | |||
3385 | (void)parser.parseOptionalSymbolName( | |||
3386 | nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes); | |||
3387 | ||||
3388 | // Parse attributes | |||
3389 | spirv::AddressingModel addrModel; | |||
3390 | spirv::MemoryModel memoryModel; | |||
3391 | if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser, | |||
3392 | result) || | |||
3393 | ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser, | |||
3394 | result)) | |||
3395 | return failure(); | |||
3396 | ||||
3397 | if (succeeded(parser.parseOptionalKeyword("requires"))) { | |||
3398 | spirv::VerCapExtAttr vceTriple; | |||
3399 | if (parser.parseAttribute(vceTriple, | |||
3400 | spirv::ModuleOp::getVCETripleAttrName(), | |||
3401 | result.attributes)) | |||
3402 | return failure(); | |||
3403 | } | |||
3404 | ||||
3405 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes) || | |||
3406 | parser.parseRegion(*body, /*arguments=*/{})) | |||
3407 | return failure(); | |||
3408 | ||||
3409 | // Make sure we have at least one block. | |||
3410 | if (body->empty()) | |||
3411 | body->push_back(new Block()); | |||
3412 | ||||
3413 | return success(); | |||
3414 | } | |||
3415 | ||||
3416 | void spirv::ModuleOp::print(OpAsmPrinter &printer) { | |||
3417 | if (std::optional<StringRef> name = getName()) { | |||
3418 | printer << ' '; | |||
3419 | printer.printSymbolName(*name); | |||
3420 | } | |||
3421 | ||||
3422 | SmallVector<StringRef, 2> elidedAttrs; | |||
3423 | ||||
3424 | printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " " | |||
3425 | << spirv::stringifyMemoryModel(getMemoryModel()); | |||
3426 | auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); | |||
3427 | auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); | |||
3428 | elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, | |||
3429 | mlir::SymbolTable::getSymbolAttrName()}); | |||
3430 | ||||
3431 | if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) { | |||
3432 | printer << " requires " << *triple; | |||
3433 | elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); | |||
3434 | } | |||
3435 | ||||
3436 | printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); | |||
3437 | printer << ' '; | |||
3438 | printer.printRegion(getRegion()); | |||
3439 | } | |||
3440 | ||||
3441 | LogicalResult spirv::ModuleOp::verifyRegions() { | |||
3442 | Dialect *dialect = (*this)->getDialect(); | |||
3443 | DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> | |||
3444 | entryPoints; | |||
3445 | mlir::SymbolTable table(*this); | |||
3446 | ||||
3447 | for (auto &op : *getBody()) { | |||
3448 | if (op.getDialect() != dialect) | |||
3449 | return op.emitError("'spirv.module' can only contain spirv.* ops"); | |||
3450 | ||||
3451 | // For EntryPoint op, check that the function and execution model is not | |||
3452 | // duplicated in EntryPointOps. Also verify that the interface specified | |||
3453 | // comes from globalVariables here to make this check cheaper. | |||
3454 | if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) { | |||
3455 | auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn()); | |||
3456 | if (!funcOp) { | |||
3457 | return entryPointOp.emitError("function '") | |||
3458 | << entryPointOp.getFn() << "' not found in 'spirv.module'"; | |||
3459 | } | |||
3460 | if (auto interface = entryPointOp.getInterface()) { | |||
3461 | for (Attribute varRef : interface) { | |||
3462 | auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>(); | |||
3463 | if (!varSymRef) { | |||
3464 | return entryPointOp.emitError( | |||
3465 | "expected symbol reference for interface " | |||
3466 | "specification instead of '") | |||
3467 | << varRef; | |||
3468 | } | |||
3469 | auto variableOp = | |||
3470 | table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()); | |||
3471 | if (!variableOp) { | |||
3472 | return entryPointOp.emitError("expected spirv.GlobalVariable " | |||
3473 | "symbol reference instead of'") | |||
3474 | << varSymRef << "'"; | |||
3475 | } | |||
3476 | } | |||
3477 | } | |||
3478 | ||||
3479 | auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>( | |||
3480 | funcOp, entryPointOp.getExecutionModel()); | |||
3481 | auto entryPtIt = entryPoints.find(key); | |||
3482 | if (entryPtIt != entryPoints.end()) { | |||
3483 | return entryPointOp.emitError("duplicate of a previous EntryPointOp"); | |||
3484 | } | |||
3485 | entryPoints[key] = entryPointOp; | |||
3486 | } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) { | |||
3487 | if (funcOp.isExternal()) | |||
3488 | return op.emitError("'spirv.module' cannot contain external functions"); | |||
3489 | ||||
3490 | // TODO: move this check to spirv.func. | |||
3491 | for (auto &block : funcOp) | |||
3492 | for (auto &op : block) { | |||
3493 | if (op.getDialect() != dialect) | |||
3494 | return op.emitError( | |||
3495 | "functions in 'spirv.module' can only contain spirv.* ops"); | |||
3496 | } | |||
3497 | } | |||
3498 | } | |||
3499 | ||||
3500 | return success(); | |||
3501 | } | |||
3502 | ||||
3503 | //===----------------------------------------------------------------------===// | |||
3504 | // spirv.mlir.referenceof | |||
3505 | //===----------------------------------------------------------------------===// | |||
3506 | ||||
3507 | LogicalResult spirv::ReferenceOfOp::verify() { | |||
3508 | auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( | |||
3509 | (*this)->getParentOp(), getSpecConstAttr()); | |||
3510 | Type constType; | |||
3511 | ||||
3512 | auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym); | |||
3513 | if (specConstOp) | |||
3514 | constType = specConstOp.getDefaultValue().getType(); | |||
3515 | ||||
3516 | auto specConstCompositeOp = | |||
3517 | dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym); | |||
3518 | if (specConstCompositeOp) | |||
3519 | constType = specConstCompositeOp.getType(); | |||
3520 | ||||
3521 | if (!specConstOp && !specConstCompositeOp) | |||
3522 | return emitOpError( | |||
3523 | "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol"); | |||
3524 | ||||
3525 | if (getReference().getType() != constType) | |||
3526 | return emitOpError("result type mismatch with the referenced " | |||
3527 | "specialization constant's type"); | |||
3528 | ||||
3529 | return success(); | |||
3530 | } | |||
3531 | ||||
3532 | //===----------------------------------------------------------------------===// | |||
3533 | // spirv.Return | |||
3534 | //===----------------------------------------------------------------------===// | |||
3535 | ||||
3536 | LogicalResult spirv::ReturnOp::verify() { | |||
3537 | // Verification is performed in spirv.func op. | |||
3538 | return success(); | |||
3539 | } | |||
3540 | ||||
3541 | //===----------------------------------------------------------------------===// | |||
3542 | // spirv.ReturnValue | |||
3543 | //===----------------------------------------------------------------------===// | |||
3544 | ||||
3545 | LogicalResult spirv::ReturnValueOp::verify() { | |||
3546 | // Verification is performed in spirv.func op. | |||
3547 | return success(); | |||
3548 | } | |||
3549 | ||||
3550 | //===----------------------------------------------------------------------===// | |||
3551 | // spirv.Select | |||
3552 | //===----------------------------------------------------------------------===// | |||
3553 | ||||
3554 | LogicalResult spirv::SelectOp::verify() { | |||
3555 | if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) { | |||
3556 | auto resultVectorTy = getResult().getType().dyn_cast<VectorType>(); | |||
3557 | if (!resultVectorTy) { | |||
3558 | return emitOpError("result expected to be of vector type when " | |||
3559 | "condition is of vector type"); | |||
3560 | } | |||
3561 | if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { | |||
3562 | return emitOpError("result should have the same number of elements as " | |||
3563 | "the condition when condition is of vector type"); | |||
3564 | } | |||
3565 | } | |||
3566 | return success(); | |||
3567 | } | |||
3568 | ||||
3569 | //===----------------------------------------------------------------------===// | |||
3570 | // spirv.mlir.selection | |||
3571 | //===----------------------------------------------------------------------===// | |||
3572 | ||||
3573 | ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, | |||
3574 | OperationState &result) { | |||
3575 | if (parseControlAttribute<spirv::SelectionControlAttr, | |||
3576 | spirv::SelectionControl>(parser, result)) | |||
3577 | return failure(); | |||
3578 | return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); | |||
3579 | } | |||
3580 | ||||
3581 | void spirv::SelectionOp::print(OpAsmPrinter &printer) { | |||
3582 | auto control = getSelectionControl(); | |||
3583 | if (control != spirv::SelectionControl::None) | |||
3584 | printer << " control(" << spirv::stringifySelectionControl(control) << ")"; | |||
3585 | printer << ' '; | |||
3586 | printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, | |||
3587 | /*printBlockTerminators=*/true); | |||
3588 | } | |||
3589 | ||||
3590 | LogicalResult spirv::SelectionOp::verifyRegions() { | |||
3591 | auto *op = getOperation(); | |||
3592 | ||||
3593 | // We need to verify that the blocks follow the following layout: | |||
3594 | // | |||
3595 | // +--------------+ | |||
3596 | // | header block | | |||
3597 | // +--------------+ | |||
3598 | // / | \ | |||
3599 | // ... | |||
3600 | // | |||
3601 | // | |||
3602 | // +---------+ +---------+ +---------+ | |||
3603 | // | case #0 | | case #1 | | case #2 | ... | |||
3604 | // +---------+ +---------+ +---------+ | |||
3605 | // | |||
3606 | // | |||
3607 | // ... | |||
3608 | // \ | / | |||
3609 | // v | |||
3610 | // +-------------+ | |||
3611 | // | merge block | | |||
3612 | // +-------------+ | |||
3613 | ||||
3614 | auto ®ion = op->getRegion(0); | |||
3615 | // Allow empty region as a degenerated case, which can come from | |||
3616 | // optimizations. | |||
3617 | if (region.empty()) | |||
3618 | return success(); | |||
3619 | ||||
3620 | // The last block is the merge block. | |||
3621 | if (!isMergeBlock(region.back())) | |||
3622 | return emitOpError("last block must be the merge block with only one " | |||
3623 | "'spirv.mlir.merge' op"); | |||
3624 | ||||
3625 | if (std::next(region.begin()) == region.end()) | |||
3626 | return emitOpError("must have a selection header block"); | |||
3627 | ||||
3628 | return success(); | |||
3629 | } | |||
3630 | ||||
3631 | Block *spirv::SelectionOp::getHeaderBlock() { | |||
3632 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3632, __extension__ __PRETTY_FUNCTION__)); | |||
3633 | // The first block is the loop header block. | |||
3634 | return &getBody().front(); | |||
3635 | } | |||
3636 | ||||
3637 | Block *spirv::SelectionOp::getMergeBlock() { | |||
3638 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3638, __extension__ __PRETTY_FUNCTION__)); | |||
3639 | // The last block is the loop merge block. | |||
3640 | return &getBody().back(); | |||
3641 | } | |||
3642 | ||||
3643 | void spirv::SelectionOp::addMergeBlock() { | |||
3644 | assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist" ) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3644, __extension__ __PRETTY_FUNCTION__)); | |||
3645 | auto *mergeBlock = new Block(); | |||
3646 | getBody().push_back(mergeBlock); | |||
3647 | OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); | |||
3648 | ||||
3649 | // Add a spirv.mlir.merge op into the merge block. | |||
3650 | builder.create<spirv::MergeOp>(getLoc()); | |||
3651 | } | |||
3652 | ||||
3653 | spirv::SelectionOp spirv::SelectionOp::createIfThen( | |||
3654 | Location loc, Value condition, | |||
3655 | function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) { | |||
3656 | auto selectionOp = | |||
3657 | builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); | |||
3658 | ||||
3659 | selectionOp.addMergeBlock(); | |||
3660 | Block *mergeBlock = selectionOp.getMergeBlock(); | |||
3661 | Block *thenBlock = nullptr; | |||
3662 | ||||
3663 | // Build the "then" block. | |||
3664 | { | |||
3665 | OpBuilder::InsertionGuard guard(builder); | |||
3666 | thenBlock = builder.createBlock(mergeBlock); | |||
3667 | thenBody(builder); | |||
3668 | builder.create<spirv::BranchOp>(loc, mergeBlock); | |||
3669 | } | |||
3670 | ||||
3671 | // Build the header block. | |||
3672 | { | |||
3673 | OpBuilder::InsertionGuard guard(builder); | |||
3674 | builder.createBlock(thenBlock); | |||
3675 | builder.create<spirv::BranchConditionalOp>( | |||
3676 | loc, condition, thenBlock, | |||
3677 | /*trueArguments=*/ArrayRef<Value>(), mergeBlock, | |||
3678 | /*falseArguments=*/ArrayRef<Value>()); | |||
3679 | } | |||
3680 | ||||
3681 | return selectionOp; | |||
3682 | } | |||
3683 | ||||
3684 | //===----------------------------------------------------------------------===// | |||
3685 | // spirv.SpecConstant | |||
3686 | //===----------------------------------------------------------------------===// | |||
3687 | ||||
3688 | ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, | |||
3689 | OperationState &result) { | |||
3690 | StringAttr nameAttr; | |||
3691 | Attribute valueAttr; | |||
3692 | ||||
3693 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
3694 | result.attributes)) | |||
3695 | return failure(); | |||
3696 | ||||
3697 | // Parse optional spec_id. | |||
3698 | if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) { | |||
3699 | IntegerAttr specIdAttr; | |||
3700 | if (parser.parseLParen() || | |||
3701 | parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) || | |||
3702 | parser.parseRParen()) | |||
3703 | return failure(); | |||
3704 | } | |||
3705 | ||||
3706 | if (parser.parseEqual() || | |||
3707 | parser.parseAttribute(valueAttr, kDefaultValueAttrName, | |||
3708 | result.attributes)) | |||
3709 | return failure(); | |||
3710 | ||||
3711 | return success(); | |||
3712 | } | |||
3713 | ||||
3714 | void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { | |||
3715 | printer << ' '; | |||
3716 | printer.printSymbolName(getSymName()); | |||
3717 | if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) | |||
3718 | printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; | |||
3719 | printer << " = " << getDefaultValue(); | |||
3720 | } | |||
3721 | ||||
3722 | LogicalResult spirv::SpecConstantOp::verify() { | |||
3723 | if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) | |||
3724 | if (specID.getValue().isNegative()) | |||
3725 | return emitOpError("SpecId cannot be negative"); | |||
3726 | ||||
3727 | auto value = getDefaultValue(); | |||
3728 | if (value.isa<IntegerAttr, FloatAttr>()) { | |||
3729 | // Make sure bitwidth is allowed. | |||
3730 | if (!value.getType().isa<spirv::SPIRVType>()) | |||
3731 | return emitOpError("default value bitwidth disallowed"); | |||
3732 | return success(); | |||
3733 | } | |||
3734 | return emitOpError( | |||
3735 | "default value can only be a bool, integer, or float scalar"); | |||
3736 | } | |||
3737 | ||||
3738 | //===----------------------------------------------------------------------===// | |||
3739 | // spirv.StoreOp | |||
3740 | //===----------------------------------------------------------------------===// | |||
3741 | ||||
3742 | ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3743 | // Parse the storage class specification | |||
3744 | spirv::StorageClass storageClass; | |||
3745 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
3746 | auto loc = parser.getCurrentLocation(); | |||
3747 | Type elementType; | |||
3748 | if (parseEnumStrAttr(storageClass, parser) || | |||
3749 | parser.parseOperandList(operandInfo, 2) || | |||
3750 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
3751 | parser.parseType(elementType)) { | |||
3752 | return failure(); | |||
3753 | } | |||
3754 | ||||
3755 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
3756 | if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, | |||
3757 | result.operands)) { | |||
3758 | return failure(); | |||
3759 | } | |||
3760 | return success(); | |||
3761 | } | |||
3762 | ||||
3763 | void spirv::StoreOp::print(OpAsmPrinter &printer) { | |||
3764 | SmallVector<StringRef, 4> elidedAttrs; | |||
3765 | StringRef sc = stringifyStorageClass( | |||
3766 | getPtr().getType().cast<spirv::PointerType>().getStorageClass()); | |||
3767 | printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); | |||
3768 | ||||
3769 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
3770 | ||||
3771 | printer << " : " << getValue().getType(); | |||
3772 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
3773 | } | |||
3774 | ||||
3775 | LogicalResult spirv::StoreOp::verify() { | |||
3776 | // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an | |||
3777 | // OpTypePointer whose Type operand is the same as the type of Object." | |||
3778 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) | |||
3779 | return failure(); | |||
3780 | return verifyMemoryAccessAttribute(*this); | |||
3781 | } | |||
3782 | ||||
3783 | //===----------------------------------------------------------------------===// | |||
3784 | // spirv.Unreachable | |||
3785 | //===----------------------------------------------------------------------===// | |||
3786 | ||||
3787 | LogicalResult spirv::UnreachableOp::verify() { | |||
3788 | auto *block = (*this)->getBlock(); | |||
3789 | // Fast track: if this is in entry block, its invalid. Otherwise, if no | |||
3790 | // predecessors, it's valid. | |||
3791 | if (block->isEntryBlock()) | |||
3792 | return emitOpError("cannot be used in reachable block"); | |||
3793 | if (block->hasNoPredecessors()) | |||
3794 | return success(); | |||
3795 | ||||
3796 | // TODO: further verification needs to analyze reachability from | |||
3797 | // the entry block. | |||
3798 | ||||
3799 | return success(); | |||
3800 | } | |||
3801 | ||||
3802 | //===----------------------------------------------------------------------===// | |||
3803 | // spirv.Variable | |||
3804 | //===----------------------------------------------------------------------===// | |||
3805 | ||||
3806 | ParseResult spirv::VariableOp::parse(OpAsmParser &parser, | |||
3807 | OperationState &result) { | |||
3808 | // Parse optional initializer | |||
3809 | std::optional<OpAsmParser::UnresolvedOperand> initInfo; | |||
3810 | if (succeeded(parser.parseOptionalKeyword("init"))) { | |||
3811 | initInfo = OpAsmParser::UnresolvedOperand(); | |||
3812 | if (parser.parseLParen() || parser.parseOperand(*initInfo) || | |||
3813 | parser.parseRParen()) | |||
3814 | return failure(); | |||
3815 | } | |||
3816 | ||||
3817 | if (parseVariableDecorations(parser, result)) { | |||
3818 | return failure(); | |||
3819 | } | |||
3820 | ||||
3821 | // Parse result pointer type | |||
3822 | Type type; | |||
3823 | if (parser.parseColon()) | |||
3824 | return failure(); | |||
3825 | auto loc = parser.getCurrentLocation(); | |||
3826 | if (parser.parseType(type)) | |||
3827 | return failure(); | |||
3828 | ||||
3829 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
3830 | if (!ptrType) | |||
3831 | return parser.emitError(loc, "expected spirv.ptr type"); | |||
3832 | result.addTypes(ptrType); | |||
3833 | ||||
3834 | // Resolve the initializer operand | |||
3835 | if (initInfo) { | |||
3836 | if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), | |||
3837 | result.operands)) | |||
3838 | return failure(); | |||
3839 | } | |||
3840 | ||||
3841 | auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>( | |||
3842 | ptrType.getStorageClass()); | |||
3843 | result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr); | |||
3844 | ||||
3845 | return success(); | |||
3846 | } | |||
3847 | ||||
3848 | void spirv::VariableOp::print(OpAsmPrinter &printer) { | |||
3849 | SmallVector<StringRef, 4> elidedAttrs{ | |||
3850 | spirv::attributeName<spirv::StorageClass>()}; | |||
3851 | // Print optional initializer | |||
3852 | if (getNumOperands() != 0) | |||
3853 | printer << " init(" << getInitializer() << ")"; | |||
3854 | ||||
3855 | printVariableDecorations(*this, printer, elidedAttrs); | |||
3856 | printer << " : " << getType(); | |||
3857 | } | |||
3858 | ||||
3859 | LogicalResult spirv::VariableOp::verify() { | |||
3860 | // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the | |||
3861 | // object. It cannot be Generic. It must be the same as the Storage Class | |||
3862 | // operand of the Result Type." | |||
3863 | if (getStorageClass() != spirv::StorageClass::Function) { | |||
3864 | return emitOpError( | |||
3865 | "can only be used to model function-level variables. Use " | |||
3866 | "spirv.GlobalVariable for module-level variables."); | |||
3867 | } | |||
3868 | ||||
3869 | auto pointerType = getPointer().getType().cast<spirv::PointerType>(); | |||
3870 | if (getStorageClass() != pointerType.getStorageClass()) | |||
3871 | return emitOpError( | |||
3872 | "storage class must match result pointer's storage class"); | |||
3873 | ||||
3874 | if (getNumOperands() != 0) { | |||
3875 | // SPIR-V spec: "Initializer must be an <id> from a constant instruction or | |||
3876 | // a global (module scope) OpVariable instruction". | |||
3877 | auto *initOp = getOperand(0).getDefiningOp(); | |||
3878 | if (!initOp || !isa<spirv::ConstantOp, // for normal constant | |||
3879 | spirv::ReferenceOfOp, // for spec constant | |||
3880 | spirv::AddressOfOp>(initOp)) | |||
3881 | return emitOpError("initializer must be the result of a " | |||
3882 | "constant or spirv.GlobalVariable op"); | |||
3883 | } | |||
3884 | ||||
3885 | // TODO: generate these strings using ODS. | |||
3886 | auto *op = getOperation(); | |||
3887 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
3888 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
3889 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
3890 | stringifyDecoration(spirv::Decoration::Binding)); | |||
3891 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
3892 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
3893 | ||||
3894 | for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { | |||
3895 | if (op->getAttr(attr)) | |||
3896 | return emitOpError("cannot have '") | |||
3897 | << attr << "' attribute (only allowed in spirv.GlobalVariable)"; | |||
3898 | } | |||
3899 | ||||
3900 | return success(); | |||
3901 | } | |||
3902 | ||||
3903 | //===----------------------------------------------------------------------===// | |||
3904 | // spirv.VectorShuffle | |||
3905 | //===----------------------------------------------------------------------===// | |||
3906 | ||||
3907 | LogicalResult spirv::VectorShuffleOp::verify() { | |||
3908 | VectorType resultType = getType().cast<VectorType>(); | |||
3909 | ||||
3910 | size_t numResultElements = resultType.getNumElements(); | |||
3911 | if (numResultElements != getComponents().size()) | |||
3912 | return emitOpError("result type element count (") | |||
3913 | << numResultElements | |||
3914 | << ") mismatch with the number of component selectors (" | |||
3915 | << getComponents().size() << ")"; | |||
3916 | ||||
3917 | size_t totalSrcElements = | |||
3918 | getVector1().getType().cast<VectorType>().getNumElements() + | |||
3919 | getVector2().getType().cast<VectorType>().getNumElements(); | |||
3920 | ||||
3921 | for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) { | |||
3922 | uint32_t index = selector.getZExtValue(); | |||
3923 | if (index >= totalSrcElements && | |||
3924 | index != std::numeric_limits<uint32_t>().max()) | |||
3925 | return emitOpError("component selector ") | |||
3926 | << index << " out of range: expected to be in [0, " | |||
3927 | << totalSrcElements << ") or 0xffffffff"; | |||
3928 | } | |||
3929 | return success(); | |||
3930 | } | |||
3931 | ||||
3932 | //===----------------------------------------------------------------------===// | |||
3933 | // spirv.NV.CooperativeMatrixLoad | |||
3934 | //===----------------------------------------------------------------------===// | |||
3935 | ||||
3936 | ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, | |||
3937 | OperationState &result) { | |||
3938 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; | |||
3939 | Type strideType = parser.getBuilder().getIntegerType(32); | |||
3940 | Type columnMajorType = parser.getBuilder().getIntegerType(1); | |||
3941 | Type ptrType; | |||
3942 | Type elementType; | |||
3943 | if (parser.parseOperandList(operandInfo, 3) || | |||
3944 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
3945 | parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { | |||
3946 | return failure(); | |||
3947 | } | |||
3948 | if (parser.resolveOperands(operandInfo, | |||
3949 | {ptrType, strideType, columnMajorType}, | |||
3950 | parser.getNameLoc(), result.operands)) { | |||
3951 | return failure(); | |||
3952 | } | |||
3953 | ||||
3954 | result.addTypes(elementType); | |||
3955 | return success(); | |||
3956 | } | |||
3957 | ||||
3958 | void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { | |||
3959 | printer << " " << getPointer() << ", " << getStride() << ", " | |||
3960 | << getColumnmajor(); | |||
3961 | // Print optional memory access attribute. | |||
3962 | if (auto memAccess = getMemoryAccess()) | |||
3963 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; | |||
3964 | printer << " : " << getPointer().getType() << " as " << getType(); | |||
3965 | } | |||
3966 | ||||
3967 | static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, | |||
3968 | Type coopMatrix) { | |||
3969 | Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType(); | |||
3970 | if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>()) | |||
3971 | return op->emitError( | |||
3972 | "Pointer must point to a scalar or vector type but provided ") | |||
3973 | << pointeeType; | |||
3974 | spirv::StorageClass storage = | |||
3975 | pointer.cast<spirv::PointerType>().getStorageClass(); | |||
3976 | if (storage != spirv::StorageClass::Workgroup && | |||
3977 | storage != spirv::StorageClass::StorageBuffer && | |||
3978 | storage != spirv::StorageClass::PhysicalStorageBuffer) | |||
3979 | return op->emitError( | |||
3980 | "Pointer storage class must be Workgroup, StorageBuffer or " | |||
3981 | "PhysicalStorageBufferEXT but provided ") | |||
3982 | << stringifyStorageClass(storage); | |||
3983 | return success(); | |||
3984 | } | |||
3985 | ||||
3986 | LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { | |||
3987 | return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), | |||
3988 | getResult().getType()); | |||
3989 | } | |||
3990 | ||||
3991 | //===----------------------------------------------------------------------===// | |||
3992 | // spirv.NV.CooperativeMatrixStore | |||
3993 | //===----------------------------------------------------------------------===// | |||
3994 | ||||
3995 | ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, | |||
3996 | OperationState &result) { | |||
3997 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo; | |||
3998 | Type strideType = parser.getBuilder().getIntegerType(32); | |||
3999 | Type columnMajorType = parser.getBuilder().getIntegerType(1); | |||
4000 | Type ptrType; | |||
4001 | Type elementType; | |||
4002 | if (parser.parseOperandList(operandInfo, 4) || | |||
4003 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
4004 | parser.parseType(ptrType) || parser.parseComma() || | |||
4005 | parser.parseType(elementType)) { | |||
4006 | return failure(); | |||
4007 | } | |||
4008 | if (parser.resolveOperands( | |||
4009 | operandInfo, {ptrType, elementType, strideType, columnMajorType}, | |||
4010 | parser.getNameLoc(), result.operands)) { | |||
4011 | return failure(); | |||
4012 | } | |||
4013 | ||||
4014 | return success(); | |||
4015 | } | |||
4016 | ||||
4017 | void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { | |||
4018 | printer << " " << getPointer() << ", " << getObject() << ", " << getStride() | |||
4019 | << ", " << getColumnmajor(); | |||
4020 | // Print optional memory access attribute. | |||
4021 | if (auto memAccess = getMemoryAccess()) | |||
4022 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; | |||
4023 | printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); | |||
4024 | } | |||
4025 | ||||
4026 | LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { | |||
4027 | return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), | |||
4028 | getObject().getType()); | |||
4029 | } | |||
4030 | ||||
4031 | //===----------------------------------------------------------------------===// | |||
4032 | // spirv.NV.CooperativeMatrixMulAdd | |||
4033 | //===----------------------------------------------------------------------===// | |||
4034 | ||||
4035 | static LogicalResult | |||
4036 | verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { | |||
4037 | if (op.getC().getType() != op.getResult().getType()) | |||
4038 | return op.emitOpError("result and third operand must have the same type"); | |||
4039 | auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4040 | auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4041 | auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4042 | auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4043 | if (typeA.getRows() != typeR.getRows() || | |||
4044 | typeA.getColumns() != typeB.getRows() || | |||
4045 | typeB.getColumns() != typeR.getColumns()) | |||
4046 | return op.emitOpError("matrix size must match"); | |||
4047 | if (typeR.getScope() != typeA.getScope() || | |||
4048 | typeR.getScope() != typeB.getScope() || | |||
4049 | typeR.getScope() != typeC.getScope()) | |||
4050 | return op.emitOpError("matrix scope must match"); | |||
4051 | if (typeA.getElementType() != typeB.getElementType() || | |||
4052 | typeR.getElementType() != typeC.getElementType()) | |||
4053 | return op.emitOpError("matrix element type must match"); | |||
4054 | return success(); | |||
4055 | } | |||
4056 | ||||
4057 | LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() { | |||
4058 | return verifyCoopMatrixMulAdd(*this); | |||
4059 | } | |||
4060 | ||||
4061 | static LogicalResult | |||
4062 | verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { | |||
4063 | Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType(); | |||
4064 | if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>()) | |||
4065 | return op->emitError( | |||
4066 | "Pointer must point to a scalar or vector type but provided ") | |||
4067 | << pointeeType; | |||
4068 | spirv::StorageClass storage = | |||
4069 | pointer.cast<spirv::PointerType>().getStorageClass(); | |||
4070 | if (storage != spirv::StorageClass::Workgroup && | |||
4071 | storage != spirv::StorageClass::CrossWorkgroup && | |||
4072 | storage != spirv::StorageClass::UniformConstant && | |||
4073 | storage != spirv::StorageClass::Generic) | |||
4074 | return op->emitError("Pointer storage class must be Workgroup or " | |||
4075 | "CrossWorkgroup but provided ") | |||
4076 | << stringifyStorageClass(storage); | |||
4077 | return success(); | |||
4078 | } | |||
4079 | ||||
4080 | //===----------------------------------------------------------------------===// | |||
4081 | // spirv.INTEL.JointMatrixLoad | |||
4082 | //===----------------------------------------------------------------------===// | |||
4083 | ||||
4084 | LogicalResult spirv::INTELJointMatrixLoadOp::verify() { | |||
4085 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), | |||
4086 | getResult().getType()); | |||
4087 | } | |||
4088 | ||||
4089 | //===----------------------------------------------------------------------===// | |||
4090 | // spirv.INTEL.JointMatrixStore | |||
4091 | //===----------------------------------------------------------------------===// | |||
4092 | ||||
4093 | LogicalResult spirv::INTELJointMatrixStoreOp::verify() { | |||
4094 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), | |||
4095 | getObject().getType()); | |||
4096 | } | |||
4097 | ||||
4098 | //===----------------------------------------------------------------------===// | |||
4099 | // spirv.INTEL.JointMatrixMad | |||
4100 | //===----------------------------------------------------------------------===// | |||
4101 | ||||
4102 | static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { | |||
4103 | if (op.getC().getType() != op.getResult().getType()) | |||
4104 | return op.emitOpError("result and third operand must have the same type"); | |||
4105 | auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>(); | |||
4106 | auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>(); | |||
4107 | auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>(); | |||
4108 | auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>(); | |||
4109 | if (typeA.getRows() != typeR.getRows() || | |||
4110 | typeA.getColumns() != typeB.getRows() || | |||
4111 | typeB.getColumns() != typeR.getColumns()) | |||
4112 | return op.emitOpError("matrix size must match"); | |||
4113 | if (typeR.getScope() != typeA.getScope() || | |||
4114 | typeR.getScope() != typeB.getScope() || | |||
4115 | typeR.getScope() != typeC.getScope()) | |||
4116 | return op.emitOpError("matrix scope must match"); | |||
4117 | if (typeA.getElementType() != typeB.getElementType() || | |||
4118 | typeR.getElementType() != typeC.getElementType()) | |||
4119 | return op.emitOpError("matrix element type must match"); | |||
4120 | return success(); | |||
4121 | } | |||
4122 | ||||
4123 | LogicalResult spirv::INTELJointMatrixMadOp::verify() { | |||
4124 | return verifyJointMatrixMad(*this); | |||
4125 | } | |||
4126 | ||||
4127 | //===----------------------------------------------------------------------===// | |||
4128 | // spirv.MatrixTimesScalar | |||
4129 | //===----------------------------------------------------------------------===// | |||
4130 | ||||
4131 | LogicalResult spirv::MatrixTimesScalarOp::verify() { | |||
4132 | if (auto inputCoopmat = | |||
4133 | getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
4134 | if (inputCoopmat.getElementType() != getScalar().getType()) | |||
4135 | return emitError("input matrix components' type and scaling value must " | |||
4136 | "have the same type"); | |||
4137 | return success(); | |||
4138 | } | |||
4139 | ||||
4140 | // Check that the scalar type is the same as the matrix element type. | |||
4141 | auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>(); | |||
4142 | if (getScalar().getType() != inputMatrix.getElementType()) | |||
4143 | return emitError("input matrix components' type and scaling value must " | |||
4144 | "have the same type"); | |||
4145 | ||||
4146 | return success(); | |||
4147 | } | |||
4148 | ||||
4149 | //===----------------------------------------------------------------------===// | |||
4150 | // spirv.CopyMemory | |||
4151 | //===----------------------------------------------------------------------===// | |||
4152 | ||||
4153 | void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) { | |||
4154 | printer << ' '; | |||
4155 | ||||
4156 | StringRef targetStorageClass = stringifyStorageClass( | |||
4157 | getTarget().getType().cast<spirv::PointerType>().getStorageClass()); | |||
4158 | printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; | |||
4159 | ||||
4160 | StringRef sourceStorageClass = stringifyStorageClass( | |||
4161 | getSource().getType().cast<spirv::PointerType>().getStorageClass()); | |||
4162 | printer << " \"" << sourceStorageClass << "\" " << getSource(); | |||
4163 | ||||
4164 | SmallVector<StringRef, 4> elidedAttrs; | |||
4165 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
4166 | printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, | |||
4167 | getSourceMemoryAccess(), | |||
4168 | getSourceAlignment()); | |||
4169 | ||||
4170 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
4171 | ||||
4172 | Type pointeeType = | |||
4173 | getTarget().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4174 | printer << " : " << pointeeType; | |||
4175 | } | |||
4176 | ||||
4177 | ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser, | |||
4178 | OperationState &result) { | |||
4179 | spirv::StorageClass targetStorageClass; | |||
4180 | OpAsmParser::UnresolvedOperand targetPtrInfo; | |||
4181 | ||||
4182 | spirv::StorageClass sourceStorageClass; | |||
| ||||
4183 | OpAsmParser::UnresolvedOperand sourcePtrInfo; | |||
4184 | ||||
4185 | Type elementType; | |||
4186 | ||||
4187 | if (parseEnumStrAttr(targetStorageClass, parser) || | |||
4188 | parser.parseOperand(targetPtrInfo) || parser.parseComma() || | |||
4189 | parseEnumStrAttr(sourceStorageClass, parser) || | |||
4190 | parser.parseOperand(sourcePtrInfo) || | |||
4191 | parseMemoryAccessAttributes(parser, result)) { | |||
4192 | return failure(); | |||
4193 | } | |||
4194 | ||||
4195 | if (!parser.parseOptionalComma()) { | |||
4196 | // Parse 2nd memory access attributes. | |||
4197 | if (parseSourceMemoryAccessAttributes(parser, result)) { | |||
4198 | return failure(); | |||
4199 | } | |||
4200 | } | |||
4201 | ||||
4202 | if (parser.parseColon() || parser.parseType(elementType)) | |||
4203 | return failure(); | |||
4204 | ||||
4205 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
4206 | return failure(); | |||
4207 | ||||
4208 | auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); | |||
4209 | auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); | |||
| ||||
4210 | ||||
4211 | if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || | |||
4212 | parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { | |||
4213 | return failure(); | |||
4214 | } | |||
4215 | ||||
4216 | return success(); | |||
4217 | } | |||
4218 | ||||
4219 | LogicalResult spirv::CopyMemoryOp::verify() { | |||
4220 | Type targetType = | |||
4221 | getTarget().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4222 | ||||
4223 | Type sourceType = | |||
4224 | getSource().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4225 | ||||
4226 | if (targetType != sourceType) | |||
4227 | return emitOpError("both operands must be pointers to the same type"); | |||
4228 | ||||
4229 | if (failed(verifyMemoryAccessAttribute(*this))) | |||
4230 | return failure(); | |||
4231 | ||||
4232 | // TODO - According to the spec: | |||
4233 | // | |||
4234 | // If two masks are present, the first applies to Target and cannot include | |||
4235 | // MakePointerVisible, and the second applies to Source and cannot include | |||
4236 | // MakePointerAvailable. | |||
4237 | // | |||
4238 | // Add such verification here. | |||
4239 | ||||
4240 | return verifySourceMemoryAccessAttribute(*this); | |||
4241 | } | |||
4242 | ||||
4243 | //===----------------------------------------------------------------------===// | |||
4244 | // spirv.Transpose | |||
4245 | //===----------------------------------------------------------------------===// | |||
4246 | ||||
4247 | LogicalResult spirv::TransposeOp::verify() { | |||
4248 | auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>(); | |||
4249 | auto resultMatrix = getResult().getType().cast<spirv::MatrixType>(); | |||
4250 | ||||
4251 | // Verify that the input and output matrices have correct shapes. | |||
4252 | if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) | |||
4253 | return emitError("input matrix rows count must be equal to " | |||
4254 | "output matrix columns count"); | |||
4255 | ||||
4256 | if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) | |||
4257 | return emitError("input matrix columns count must be equal to " | |||
4258 | "output matrix rows count"); | |||
4259 | ||||
4260 | // Verify that the input and output matrices have the same component type | |||
4261 | if (inputMatrix.getElementType() != resultMatrix.getElementType()) | |||
4262 | return emitError("input and output matrices must have the same " | |||
4263 | "component type"); | |||
4264 | ||||
4265 | return success(); | |||
4266 | } | |||
4267 | ||||
4268 | //===----------------------------------------------------------------------===// | |||
4269 | // spirv.MatrixTimesMatrix | |||
4270 | //===----------------------------------------------------------------------===// | |||
4271 | ||||
4272 | LogicalResult spirv::MatrixTimesMatrixOp::verify() { | |||
4273 | auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>(); | |||
4274 | auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>(); | |||
4275 | auto resultMatrix = getResult().getType().cast<spirv::MatrixType>(); | |||
4276 | ||||
4277 | // left matrix columns' count and right matrix rows' count must be equal | |||
4278 | if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) | |||
4279 | return emitError("left matrix columns' count must be equal to " | |||
4280 | "the right matrix rows' count"); | |||
4281 | ||||
4282 | // right and result matrices columns' count must be the same | |||
4283 | if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) | |||
4284 | return emitError( | |||
4285 | "right and result matrices must have equal columns' count"); | |||
4286 | ||||
4287 | // right and result matrices component type must be the same | |||
4288 | if (rightMatrix.getElementType() != resultMatrix.getElementType()) | |||
4289 | return emitError("right and result matrices' component type must" | |||
4290 | " be the same"); | |||
4291 | ||||
4292 | // left and result matrices component type must be the same | |||
4293 | if (leftMatrix.getElementType() != resultMatrix.getElementType()) | |||
4294 | return emitError("left and result matrices' component type" | |||
4295 | " must be the same"); | |||
4296 | ||||
4297 | // left and result matrices rows count must be the same | |||
4298 | if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) | |||
4299 | return emitError("left and result matrices must have equal rows' count"); | |||
4300 | ||||
4301 | return success(); | |||
4302 | } | |||
4303 | ||||
4304 | //===----------------------------------------------------------------------===// | |||
4305 | // spirv.SpecConstantComposite | |||
4306 | //===----------------------------------------------------------------------===// | |||
4307 | ||||
4308 | ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, | |||
4309 | OperationState &result) { | |||
4310 | ||||
4311 | StringAttr compositeName; | |||
4312 | if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), | |||
4313 | result.attributes)) | |||
4314 | return failure(); | |||
4315 | ||||
4316 | if (parser.parseLParen()) | |||
4317 | return failure(); | |||
4318 | ||||
4319 | SmallVector<Attribute, 4> constituents; | |||
4320 | ||||
4321 | do { | |||
4322 | // The name of the constituent attribute isn't important | |||
4323 | const char *attrName = "spec_const"; | |||
4324 | FlatSymbolRefAttr specConstRef; | |||
4325 | NamedAttrList attrs; | |||
4326 | ||||
4327 | if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) | |||
4328 | return failure(); | |||
4329 | ||||
4330 | constituents.push_back(specConstRef); | |||
4331 | } while (!parser.parseOptionalComma()); | |||
4332 | ||||
4333 | if (parser.parseRParen()) | |||
4334 | return failure(); | |||
4335 | ||||
4336 | result.addAttribute(kCompositeSpecConstituentsName, | |||
4337 | parser.getBuilder().getArrayAttr(constituents)); | |||
4338 | ||||
4339 | Type type; | |||
4340 | if (parser.parseColonType(type)) | |||
4341 | return failure(); | |||
4342 | ||||
4343 | result.addAttribute(kTypeAttrName, TypeAttr::get(type)); | |||
4344 | ||||
4345 | return success(); | |||
4346 | } | |||
4347 | ||||
4348 | void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { | |||
4349 | printer << " "; | |||
4350 | printer.printSymbolName(getSymName()); | |||
4351 | printer << " ("; | |||
4352 | auto constituents = this->getConstituents().getValue(); | |||
4353 | ||||
4354 | if (!constituents.empty()) | |||
4355 | llvm::interleaveComma(constituents, printer); | |||
4356 | ||||
4357 | printer << ") : " << getType(); | |||
4358 | } | |||
4359 | ||||
4360 | LogicalResult spirv::SpecConstantCompositeOp::verify() { | |||
4361 | auto cType = getType().dyn_cast<spirv::CompositeType>(); | |||
4362 | auto constituents = this->getConstituents().getValue(); | |||
4363 | ||||
4364 | if (!cType) | |||
4365 | return emitError("result type must be a composite type, but provided ") | |||
4366 | << getType(); | |||
4367 | ||||
4368 | if (cType.isa<spirv::CooperativeMatrixNVType>()) | |||
4369 | return emitError("unsupported composite type ") << cType; | |||
4370 | if (cType.isa<spirv::JointMatrixINTELType>()) | |||
4371 | return emitError("unsupported composite type ") << cType; | |||
4372 | if (constituents.size() != cType.getNumElements()) | |||
4373 | return emitError("has incorrect number of operands: expected ") | |||
4374 | << cType.getNumElements() << ", but provided " | |||
4375 | << constituents.size(); | |||
4376 | ||||
4377 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { | |||
4378 | auto constituent = constituents[index].cast<FlatSymbolRefAttr>(); | |||
4379 | ||||
4380 | auto constituentSpecConstOp = | |||
4381 | dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom( | |||
4382 | (*this)->getParentOp(), constituent.getAttr())); | |||
4383 | ||||
4384 | if (constituentSpecConstOp.getDefaultValue().getType() != | |||
4385 | cType.getElementType(index)) | |||
4386 | return emitError("has incorrect types of operands: expected ") | |||
4387 | << cType.getElementType(index) << ", but provided " | |||
4388 | << constituentSpecConstOp.getDefaultValue().getType(); | |||
4389 | } | |||
4390 | ||||
4391 | return success(); | |||
4392 | } | |||
4393 | ||||
4394 | //===----------------------------------------------------------------------===// | |||
4395 | // spirv.SpecConstantOperation | |||
4396 | //===----------------------------------------------------------------------===// | |||
4397 | ||||
4398 | ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, | |||
4399 | OperationState &result) { | |||
4400 | Region *body = result.addRegion(); | |||
4401 | ||||
4402 | if (parser.parseKeyword("wraps")) | |||
4403 | return failure(); | |||
4404 | ||||
4405 | body->push_back(new Block); | |||
4406 | Block &block = body->back(); | |||
4407 | Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); | |||
4408 | ||||
4409 | if (!wrappedOp) | |||
4410 | return failure(); | |||
4411 | ||||
4412 | OpBuilder builder(parser.getContext()); | |||
4413 | builder.setInsertionPointToEnd(&block); | |||
4414 | builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0)); | |||
4415 | result.location = wrappedOp->getLoc(); | |||
4416 | ||||
4417 | result.addTypes(wrappedOp->getResult(0).getType()); | |||
4418 | ||||
4419 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
4420 | return failure(); | |||
4421 | ||||
4422 | return success(); | |||
4423 | } | |||
4424 | ||||
4425 | void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { | |||
4426 | printer << " wraps "; | |||
4427 | printer.printGenericOp(&getBody().front().front()); | |||
4428 | } | |||
4429 | ||||
4430 | LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { | |||
4431 | Block &block = getRegion().getBlocks().front(); | |||
4432 | ||||
4433 | if (block.getOperations().size() != 2) | |||
4434 | return emitOpError("expected exactly 2 nested ops"); | |||
4435 | ||||
4436 | Operation &enclosedOp = block.getOperations().front(); | |||
4437 | ||||
4438 | if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>()) | |||
4439 | return emitOpError("invalid enclosed op"); | |||
4440 | ||||
4441 | for (auto operand : enclosedOp.getOperands()) | |||
4442 | if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp, | |||
4443 | spirv::SpecConstantOperationOp>(operand.getDefiningOp())) | |||
4444 | return emitOpError( | |||
4445 | "invalid operand, must be defined by a constant operation"); | |||
4446 | ||||
4447 | return success(); | |||
4448 | } | |||
4449 | ||||
4450 | //===----------------------------------------------------------------------===// | |||
4451 | // spirv.GL.FrexpStruct | |||
4452 | //===----------------------------------------------------------------------===// | |||
4453 | ||||
4454 | LogicalResult spirv::GLFrexpStructOp::verify() { | |||
4455 | spirv::StructType structTy = | |||
4456 | getResult().getType().dyn_cast<spirv::StructType>(); | |||
4457 | ||||
4458 | if (structTy.getNumElements() != 2) | |||
4459 | return emitError("result type must be a struct type with two memebers"); | |||
4460 | ||||
4461 | Type significandTy = structTy.getElementType(0); | |||
4462 | Type exponentTy = structTy.getElementType(1); | |||
4463 | VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>(); | |||
4464 | IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>(); | |||
4465 | ||||
4466 | Type operandTy = getOperand().getType(); | |||
4467 | VectorType operandVecTy = operandTy.dyn_cast<VectorType>(); | |||
4468 | FloatType operandFTy = operandTy.dyn_cast<FloatType>(); | |||
4469 | ||||
4470 | if (significandTy != operandTy) | |||
4471 | return emitError("member zero of the resulting struct type must be the " | |||
4472 | "same type as the operand"); | |||
4473 | ||||
4474 | if (exponentVecTy) { | |||
4475 | IntegerType componentIntTy = | |||
4476 | exponentVecTy.getElementType().dyn_cast<IntegerType>(); | |||
4477 | if (!componentIntTy || componentIntTy.getWidth() != 32) | |||
4478 | return emitError("member one of the resulting struct type must" | |||
4479 | "be a scalar or vector of 32 bit integer type"); | |||
4480 | } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { | |||
4481 | return emitError("member one of the resulting struct type " | |||
4482 | "must be a scalar or vector of 32 bit integer type"); | |||
4483 | } | |||
4484 | ||||
4485 | // Check that the two member types have the same number of components | |||
4486 | if (operandVecTy && exponentVecTy && | |||
4487 | (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) | |||
4488 | return success(); | |||
4489 | ||||
4490 | if (operandFTy && exponentIntTy) | |||
4491 | return success(); | |||
4492 | ||||
4493 | return emitError("member one of the resulting struct type must have the same " | |||
4494 | "number of components as the operand type"); | |||
4495 | } | |||
4496 | ||||
4497 | //===----------------------------------------------------------------------===// | |||
4498 | // spirv.GL.Ldexp | |||
4499 | //===----------------------------------------------------------------------===// | |||
4500 | ||||
4501 | LogicalResult spirv::GLLdexpOp::verify() { | |||
4502 | Type significandType = getX().getType(); | |||
4503 | Type exponentType = getExp().getType(); | |||
4504 | ||||
4505 | if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>()) | |||
4506 | return emitOpError("operands must both be scalars or vectors"); | |||
4507 | ||||
4508 | auto getNumElements = [](Type type) -> unsigned { | |||
4509 | if (auto vectorType = type.dyn_cast<VectorType>()) | |||
4510 | return vectorType.getNumElements(); | |||
4511 | return 1; | |||
4512 | }; | |||
4513 | ||||
4514 | if (getNumElements(significandType) != getNumElements(exponentType)) | |||
4515 | return emitOpError("operands must have the same number of elements"); | |||
4516 | ||||
4517 | return success(); | |||
4518 | } | |||
4519 | ||||
4520 | //===----------------------------------------------------------------------===// | |||
4521 | // spirv.ImageDrefGather | |||
4522 | //===----------------------------------------------------------------------===// | |||
4523 | ||||
4524 | LogicalResult spirv::ImageDrefGatherOp::verify() { | |||
4525 | VectorType resultType = getResult().getType().cast<VectorType>(); | |||
4526 | auto sampledImageType = | |||
4527 | getSampledimage().getType().cast<spirv::SampledImageType>(); | |||
4528 | auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>(); | |||
4529 | ||||
4530 | if (resultType.getNumElements() != 4) | |||
4531 | return emitOpError("result type must be a vector of four components"); | |||
4532 | ||||
4533 | Type elementType = resultType.getElementType(); | |||
4534 | Type sampledElementType = imageType.getElementType(); | |||
4535 | if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType) | |||
4536 | return emitOpError( | |||
4537 | "the component type of result must be the same as sampled type of the " | |||
4538 | "underlying image type"); | |||
4539 | ||||
4540 | spirv::Dim imageDim = imageType.getDim(); | |||
4541 | spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); | |||
4542 | ||||
4543 | if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && | |||
4544 | imageDim != spirv::Dim::Rect) | |||
4545 | return emitOpError( | |||
4546 | "the Dim operand of the underlying image type must be 2D, Cube, or " | |||
4547 | "Rect"); | |||
4548 | ||||
4549 | if (imageMS != spirv::ImageSamplingInfo::SingleSampled) | |||
4550 | return emitOpError("the MS operand of the underlying image type must be 0"); | |||
4551 | ||||
4552 | spirv::ImageOperandsAttr attr = getImageoperandsAttr(); | |||
4553 | auto operandArguments = getOperandArguments(); | |||
4554 | ||||
4555 | return verifyImageOperands(*this, attr, operandArguments); | |||
4556 | } | |||
4557 | ||||
4558 | //===----------------------------------------------------------------------===// | |||
4559 | // spirv.ShiftLeftLogicalOp | |||
4560 | //===----------------------------------------------------------------------===// | |||
4561 | ||||
4562 | LogicalResult spirv::ShiftLeftLogicalOp::verify() { | |||
4563 | return verifyShiftOp(*this); | |||
4564 | } | |||
4565 | ||||
4566 | //===----------------------------------------------------------------------===// | |||
4567 | // spirv.ShiftRightArithmeticOp | |||
4568 | //===----------------------------------------------------------------------===// | |||
4569 | ||||
4570 | LogicalResult spirv::ShiftRightArithmeticOp::verify() { | |||
4571 | return verifyShiftOp(*this); | |||
4572 | } | |||
4573 | ||||
4574 | //===----------------------------------------------------------------------===// | |||
4575 | // spirv.ShiftRightLogicalOp | |||
4576 | //===----------------------------------------------------------------------===// | |||
4577 | ||||
4578 | LogicalResult spirv::ShiftRightLogicalOp::verify() { | |||
4579 | return verifyShiftOp(*this); | |||
4580 | } | |||
4581 | ||||
4582 | //===----------------------------------------------------------------------===// | |||
4583 | // spirv.ImageQuerySize | |||
4584 | //===----------------------------------------------------------------------===// | |||
4585 | ||||
4586 | LogicalResult spirv::ImageQuerySizeOp::verify() { | |||
4587 | spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>(); | |||
4588 | Type resultType = getResult().getType(); | |||
4589 | ||||
4590 | spirv::Dim dim = imageType.getDim(); | |||
4591 | spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); | |||
4592 | spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); | |||
4593 | switch (dim) { | |||
4594 | case spirv::Dim::Dim1D: | |||
4595 | case spirv::Dim::Dim2D: | |||
4596 | case spirv::Dim::Dim3D: | |||
4597 | case spirv::Dim::Cube: | |||
4598 | if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && | |||
4599 | samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && | |||
4600 | samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) | |||
4601 | return emitError( | |||
4602 | "if Dim is 1D, 2D, 3D, or Cube, " | |||
4603 | "it must also have either an MS of 1 or a Sampled of 0 or 2"); | |||
4604 | break; | |||
4605 | case spirv::Dim::Buffer: | |||
4606 | case spirv::Dim::Rect: | |||
4607 | break; | |||
4608 | default: | |||
4609 | return emitError("the Dim operand of the image type must " | |||
4610 | "be 1D, 2D, 3D, Buffer, Cube, or Rect"); | |||
4611 | } | |||
4612 | ||||
4613 | unsigned componentNumber = 0; | |||
4614 | switch (dim) { | |||
4615 | case spirv::Dim::Dim1D: | |||
4616 | case spirv::Dim::Buffer: | |||
4617 | componentNumber = 1; | |||
4618 | break; | |||
4619 | case spirv::Dim::Dim2D: | |||
4620 | case spirv::Dim::Cube: | |||
4621 | case spirv::Dim::Rect: | |||
4622 | componentNumber = 2; | |||
4623 | break; | |||
4624 | case spirv::Dim::Dim3D: | |||
4625 | componentNumber = 3; | |||
4626 | break; | |||
4627 | default: | |||
4628 | break; | |||
4629 | } | |||
4630 | ||||
4631 | if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) | |||
4632 | componentNumber += 1; | |||
4633 | ||||
4634 | unsigned resultComponentNumber = 1; | |||
4635 | if (auto resultVectorType = resultType.dyn_cast<VectorType>()) | |||
4636 | resultComponentNumber = resultVectorType.getNumElements(); | |||
4637 | ||||
4638 | if (componentNumber != resultComponentNumber) | |||
4639 | return emitError("expected the result to have ") | |||
4640 | << componentNumber << " component(s), but found " | |||
4641 | << resultComponentNumber << " component(s)"; | |||
4642 | ||||
4643 | return success(); | |||
4644 | } | |||
4645 | ||||
4646 | static ParseResult parsePtrAccessChainOpImpl(StringRef opName, | |||
4647 | OpAsmParser &parser, | |||
4648 | OperationState &state) { | |||
4649 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
4650 | SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo; | |||
4651 | Type type; | |||
4652 | auto loc = parser.getCurrentLocation(); | |||
4653 | SmallVector<Type, 4> indicesTypes; | |||
4654 | ||||
4655 | if (parser.parseOperand(ptrInfo) || | |||
4656 | parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || | |||
4657 | parser.parseColonType(type) || | |||
4658 | parser.resolveOperand(ptrInfo, type, state.operands)) | |||
4659 | return failure(); | |||
4660 | ||||
4661 | // Check that the provided indices list is not empty before parsing their | |||
4662 | // type list. | |||
4663 | if (indicesInfo.empty()) | |||
4664 | return emitError(state.location) << opName << " expected element"; | |||
4665 | ||||
4666 | if (parser.parseComma() || parser.parseTypeList(indicesTypes)) | |||
4667 | return failure(); | |||
4668 | ||||
4669 | // Check that the indices types list is not empty and that it has a one-to-one | |||
4670 | // mapping to the provided indices. | |||
4671 | if (indicesTypes.size() != indicesInfo.size()) | |||
4672 | return emitError(state.location) | |||
4673 | << opName | |||
4674 | << " indices types' count must be equal to indices info count"; | |||
4675 | ||||
4676 | if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) | |||
4677 | return failure(); | |||
4678 | ||||
4679 | auto resultType = getElementPtrType( | |||
4680 | type, llvm::makeArrayRef(state.operands).drop_front(2), state.location); | |||
4681 | if (!resultType) | |||
4682 | return failure(); | |||
4683 | ||||
4684 | state.addTypes(resultType); | |||
4685 | return success(); | |||
4686 | } | |||
4687 | ||||
4688 | template <typename Op> | |||
4689 | static auto concatElemAndIndices(Op op) { | |||
4690 | SmallVector<Value> ret(op.getIndices().size() + 1); | |||
4691 | ret[0] = op.getElement(); | |||
4692 | llvm::copy(op.getIndices(), ret.begin() + 1); | |||
4693 | return ret; | |||
4694 | } | |||
4695 | ||||
4696 | //===----------------------------------------------------------------------===// | |||
4697 | // spirv.InBoundsPtrAccessChainOp | |||
4698 | //===----------------------------------------------------------------------===// | |||
4699 | ||||
4700 | void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, | |||
4701 | OperationState &state, | |||
4702 | Value basePtr, Value element, | |||
4703 | ValueRange indices) { | |||
4704 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
4705 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4705, __extension__ __PRETTY_FUNCTION__)); | |||
4706 | build(builder, state, type, basePtr, element, indices); | |||
4707 | } | |||
4708 | ||||
4709 | ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, | |||
4710 | OperationState &result) { | |||
4711 | return parsePtrAccessChainOpImpl( | |||
4712 | spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result); | |||
4713 | } | |||
4714 | ||||
4715 | void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { | |||
4716 | printAccessChain(*this, concatElemAndIndices(*this), printer); | |||
4717 | } | |||
4718 | ||||
4719 | LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { | |||
4720 | return verifyAccessChain(*this, getIndices()); | |||
4721 | } | |||
4722 | ||||
4723 | //===----------------------------------------------------------------------===// | |||
4724 | // spirv.PtrAccessChainOp | |||
4725 | //===----------------------------------------------------------------------===// | |||
4726 | ||||
4727 | void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, | |||
4728 | Value basePtr, Value element, | |||
4729 | ValueRange indices) { | |||
4730 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
4731 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4731, __extension__ __PRETTY_FUNCTION__)); | |||
4732 | build(builder, state, type, basePtr, element, indices); | |||
4733 | } | |||
4734 | ||||
4735 | ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser, | |||
4736 | OperationState &result) { | |||
4737 | return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), | |||
4738 | parser, result); | |||
4739 | } | |||
4740 | ||||
4741 | void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) { | |||
4742 | printAccessChain(*this, concatElemAndIndices(*this), printer); | |||
4743 | } | |||
4744 | ||||
4745 | LogicalResult spirv::PtrAccessChainOp::verify() { | |||
4746 | return verifyAccessChain(*this, getIndices()); | |||
4747 | } | |||
4748 | ||||
4749 | //===----------------------------------------------------------------------===// | |||
4750 | // spirv.VectorTimesScalarOp | |||
4751 | //===----------------------------------------------------------------------===// | |||
4752 | ||||
4753 | LogicalResult spirv::VectorTimesScalarOp::verify() { | |||
4754 | if (getVector().getType() != getType()) | |||
4755 | return emitOpError("vector operand and result type mismatch"); | |||
4756 | auto scalarType = getType().cast<VectorType>().getElementType(); | |||
4757 | if (getScalar().getType() != scalarType) | |||
4758 | return emitOpError("scalar operand and result element type match"); | |||
4759 | return success(); | |||
4760 | } | |||
4761 | ||||
4762 | //===----------------------------------------------------------------------===// | |||
4763 | // Group ops | |||
4764 | //===----------------------------------------------------------------------===// | |||
4765 | ||||
4766 | template <typename Op> | |||
4767 | static LogicalResult verifyGroupOp(Op op) { | |||
4768 | spirv::Scope scope = op.getExecutionScope(); | |||
4769 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
4770 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
4771 | ||||
4772 | return success(); | |||
4773 | } | |||
4774 | ||||
4775 | LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); } | |||
4776 | ||||
4777 | LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); } | |||
4778 | ||||
4779 | LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); } | |||
4780 | ||||
4781 | LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); } | |||
4782 | ||||
4783 | LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); } | |||
4784 | ||||
4785 | LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); } | |||
4786 | ||||
4787 | LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); } | |||
4788 | ||||
4789 | LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); } | |||
4790 | ||||
4791 | LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); } | |||
4792 | ||||
4793 | LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } | |||
4794 | ||||
4795 | //===----------------------------------------------------------------------===// | |||
4796 | // Integer Dot Product ops | |||
4797 | //===----------------------------------------------------------------------===// | |||
4798 | ||||
4799 | static LogicalResult verifyIntegerDotProduct(Operation *op) { | |||
4800 | assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&(static_cast <bool> (llvm::is_contained({2u, 3u}, op-> getNumOperands()) && "Not an integer dot product op?" ) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4801, __extension__ __PRETTY_FUNCTION__)) | |||
4801 | "Not an integer dot product op?")(static_cast <bool> (llvm::is_contained({2u, 3u}, op-> getNumOperands()) && "Not an integer dot product op?" ) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4801, __extension__ __PRETTY_FUNCTION__)); | |||
4802 | assert(op->getNumResults() == 1 && "Expected a single result")(static_cast <bool> (op->getNumResults() == 1 && "Expected a single result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"Expected a single result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4802, __extension__ __PRETTY_FUNCTION__)); | |||
4803 | ||||
4804 | Type factorTy = op->getOperand(0).getType(); | |||
4805 | if (op->getOperand(1).getType() != factorTy) | |||
4806 | return op->emitOpError("requires the same type for both vector operands"); | |||
4807 | ||||
4808 | unsigned expectedNumAttrs = 0; | |||
4809 | if (auto intTy = factorTy.dyn_cast<IntegerType>()) { | |||
4810 | ++expectedNumAttrs; | |||
4811 | auto packedVectorFormat = | |||
4812 | op->getAttr(kPackedVectorFormatAttrName) | |||
4813 | .dyn_cast_or_null<spirv::PackedVectorFormatAttr>(); | |||
4814 | if (!packedVectorFormat) | |||
4815 | return op->emitOpError("requires Packed Vector Format attribute for " | |||
4816 | "integer vector operands"); | |||
4817 | ||||
4818 | assert(packedVectorFormat.getValue() ==(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__ __PRETTY_FUNCTION__)) | |||
4819 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__ __PRETTY_FUNCTION__)) | |||
4820 | "Unknown Packed Vector Format")(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4820, __extension__ __PRETTY_FUNCTION__)); | |||
4821 | if (intTy.getWidth() != 32) | |||
4822 | return op->emitOpError( | |||
4823 | llvm::formatv("with specified Packed Vector Format ({0}) requires " | |||
4824 | "integer vector operands to be 32-bits wide", | |||
4825 | packedVectorFormat.getValue())); | |||
4826 | } else { | |||
4827 | if (op->hasAttr(kPackedVectorFormatAttrName)) | |||
4828 | return op->emitOpError(llvm::formatv( | |||
4829 | "with invalid format attribute for vector operands of type '{0}'", | |||
4830 | factorTy)); | |||
4831 | } | |||
4832 | ||||
4833 | if (op->getAttrs().size() > expectedNumAttrs) | |||
4834 | return op->emitError( | |||
4835 | "op only supports the 'format' #spirv.packed_vector_format attribute"); | |||
4836 | ||||
4837 | Type resultTy = op->getResultTypes().front(); | |||
4838 | bool hasAccumulator = op->getNumOperands() == 3; | |||
4839 | if (hasAccumulator && op->getOperand(2).getType() != resultTy) | |||
4840 | return op->emitOpError( | |||
4841 | "requires the same accumulator operand and result types"); | |||
4842 | ||||
4843 | unsigned factorBitWidth = getBitWidth(factorTy); | |||
4844 | unsigned resultBitWidth = getBitWidth(resultTy); | |||
4845 | if (factorBitWidth > resultBitWidth) | |||
4846 | return op->emitOpError( | |||
4847 | llvm::formatv("result type has insufficient bit-width ({0} bits) " | |||
4848 | "for the specified vector operand type ({1} bits)", | |||
4849 | resultBitWidth, factorBitWidth)); | |||
4850 | ||||
4851 | return success(); | |||
4852 | } | |||
4853 | ||||
4854 | static Optional<spirv::Version> getIntegerDotProductMinVersion() { | |||
4855 | return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. | |||
4856 | } | |||
4857 | ||||
4858 | static Optional<spirv::Version> getIntegerDotProductMaxVersion() { | |||
4859 | return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. | |||
4860 | } | |||
4861 | ||||
4862 | static SmallVector<ArrayRef<spirv::Extension>, 1> | |||
4863 | getIntegerDotProductExtensions() { | |||
4864 | // Requires the SPV_KHR_integer_dot_product extension, specified either | |||
4865 | // explicitly or implied by target env's SPIR-V version >= 1.6. | |||
4866 | static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; | |||
4867 | return {extension}; | |||
4868 | } | |||
4869 | ||||
4870 | static SmallVector<ArrayRef<spirv::Capability>, 1> | |||
4871 | getIntegerDotProductCapabilities(Operation *op) { | |||
4872 | // Requires the the DotProduct capability and capabilities that depend on | |||
4873 | // exact op types. | |||
4874 | static const auto dotProductCap = spirv::Capability::DotProduct; | |||
4875 | static const auto dotProductInput4x8BitPackedCap = | |||
4876 | spirv::Capability::DotProductInput4x8BitPacked; | |||
4877 | static const auto dotProductInput4x8BitCap = | |||
4878 | spirv::Capability::DotProductInput4x8Bit; | |||
4879 | static const auto dotProductInputAllCap = | |||
4880 | spirv::Capability::DotProductInputAll; | |||
4881 | ||||
4882 | SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap}; | |||
4883 | ||||
4884 | Type factorTy = op->getOperand(0).getType(); | |||
4885 | if (auto intTy = factorTy.dyn_cast<IntegerType>()) { | |||
4886 | auto formatAttr = op->getAttr(kPackedVectorFormatAttrName) | |||
4887 | .cast<spirv::PackedVectorFormatAttr>(); | |||
4888 | if (formatAttr.getValue() == | |||
4889 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) | |||
4890 | capabilities.push_back(dotProductInput4x8BitPackedCap); | |||
4891 | ||||
4892 | return capabilities; | |||
4893 | } | |||
4894 | ||||
4895 | auto vecTy = factorTy.cast<VectorType>(); | |||
4896 | if (vecTy.getElementTypeBitWidth() == 8) { | |||
4897 | capabilities.push_back(dotProductInput4x8BitCap); | |||
4898 | return capabilities; | |||
4899 | } | |||
4900 | ||||
4901 | capabilities.push_back(dotProductInputAllCap); | |||
4902 | return capabilities; | |||
4903 | } | |||
4904 | ||||
4905 | #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ | |||
4906 | LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \ | |||
4907 | SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \ | |||
4908 | return getIntegerDotProductExtensions(); \ | |||
4909 | } \ | |||
4910 | SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \ | |||
4911 | return getIntegerDotProductCapabilities(*this); \ | |||
4912 | } \ | |||
4913 | Optional<spirv::Version> OpName::getMinVersion() { \ | |||
4914 | return getIntegerDotProductMinVersion(); \ | |||
4915 | } \ | |||
4916 | Optional<spirv::Version> OpName::getMaxVersion() { \ | |||
4917 | return getIntegerDotProductMaxVersion(); \ | |||
4918 | } | |||
4919 | ||||
4920 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp) | |||
4921 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp) | |||
4922 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp) | |||
4923 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp) | |||
4924 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp) | |||
4925 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp) | |||
4926 | ||||
4927 | #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP | |||
4928 | ||||
4929 | // TableGen'erated operation interfaces for querying versions, extensions, and | |||
4930 | // capabilities. | |||
4931 | #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" | |||
4932 | ||||
4933 | // TableGen'erated operation definitions. | |||
4934 | #define GET_OP_CLASSES | |||
4935 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" | |||
4936 | ||||
4937 | namespace mlir { | |||
4938 | namespace spirv { | |||
4939 | // TableGen'erated operation availability interface implementations. | |||
4940 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" | |||
4941 | } // namespace spirv | |||
4942 | } // namespace mlir |