File: | build/source/mlir/include/mlir/IR/Builders.h |
Warning: | line 100, column 12 2nd function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- SPIRVOps.cpp - MLIR SPIR-V operations ------------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file defines the operations in the SPIR-V dialect. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" | |||
14 | ||||
15 | #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" | |||
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" | |||
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | |||
18 | #include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h" | |||
19 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpTraits.h" | |||
20 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" | |||
21 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" | |||
22 | #include "mlir/IR/Builders.h" | |||
23 | #include "mlir/IR/BuiltinTypes.h" | |||
24 | #include "mlir/IR/FunctionImplementation.h" | |||
25 | #include "mlir/IR/OpDefinition.h" | |||
26 | #include "mlir/IR/OpImplementation.h" | |||
27 | #include "mlir/IR/Operation.h" | |||
28 | #include "mlir/IR/TypeUtilities.h" | |||
29 | #include "mlir/Interfaces/CallInterfaces.h" | |||
30 | #include "llvm/ADT/APFloat.h" | |||
31 | #include "llvm/ADT/APInt.h" | |||
32 | #include "llvm/ADT/ArrayRef.h" | |||
33 | #include "llvm/ADT/STLExtras.h" | |||
34 | #include "llvm/ADT/StringExtras.h" | |||
35 | #include "llvm/Support/FormatVariadic.h" | |||
36 | #include <cassert> | |||
37 | #include <numeric> | |||
38 | ||||
39 | using namespace mlir; | |||
40 | ||||
// TODO: generate these strings using ODS.
//
// Attribute-name keys shared by the custom parsers, printers and verifiers in
// this file. Each array holds the exact textual key of the attribute.

// Memory access / alignment keys.
constexpr char kAlignmentAttrName[] = "alignment";
constexpr char kMemoryAccessAttrName[] = "memory_access";
constexpr char kSourceAlignmentAttrName[] = "source_alignment";
constexpr char kSourceMemoryAccessAttrName[] = "source_memory_access";

// Scope / semantics keys.
constexpr char kExecutionScopeAttrName[] = "execution_scope";
constexpr char kMemoryScopeAttrName[] = "memory_scope";
constexpr char kSemanticsAttrName[] = "semantics";
constexpr char kEqualSemanticsAttrName[] = "equal_semantics";
constexpr char kUnequalSemanticsAttrName[] = "unequal_semantics";

// Remaining keys, alphabetical by identifier.
constexpr char kBranchWeightAttrName[] = "branch_weights";
constexpr char kCallee[] = "callee";
constexpr char kClusterSize[] = "cluster_size";
constexpr char kCompositeSpecConstituentsName[] = "constituents";
constexpr char kControl[] = "control";
constexpr char kDefaultValueAttrName[] = "default_value";
constexpr char kFnNameAttrName[] = "fn";
constexpr char kGroupOperationAttrName[] = "group_operation";
constexpr char kIndicesAttrName[] = "indices";
constexpr char kInitializerAttrName[] = "initializer";
constexpr char kInterfaceAttrName[] = "interface";
constexpr char kPackedVectorFormatAttrName[] = "format";
constexpr char kSpecIdAttrName[] = "spec_id";
constexpr char kTypeAttrName[] = "type";
constexpr char kValueAttrName[] = "value";
constexpr char kValuesAttrName[] = "values";
67 | ||||
68 | //===----------------------------------------------------------------------===// | |||
69 | // Common utility functions | |||
70 | //===----------------------------------------------------------------------===// | |||
71 | ||||
72 | static ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser, | |||
73 | OperationState &result) { | |||
74 | SmallVector<OpAsmParser::UnresolvedOperand, 2> ops; | |||
75 | Type type; | |||
76 | // If the operand list is in-between parentheses, then we have a generic form. | |||
77 | // (see the fallback in `printOneResultOp`). | |||
78 | SMLoc loc = parser.getCurrentLocation(); | |||
79 | if (!parser.parseOptionalLParen()) { | |||
80 | if (parser.parseOperandList(ops) || parser.parseRParen() || | |||
81 | parser.parseOptionalAttrDict(result.attributes) || | |||
82 | parser.parseColon() || parser.parseType(type)) | |||
83 | return failure(); | |||
84 | auto fnType = type.dyn_cast<FunctionType>(); | |||
85 | if (!fnType) { | |||
86 | parser.emitError(loc, "expected function type"); | |||
87 | return failure(); | |||
88 | } | |||
89 | if (parser.resolveOperands(ops, fnType.getInputs(), loc, result.operands)) | |||
90 | return failure(); | |||
91 | result.addTypes(fnType.getResults()); | |||
92 | return success(); | |||
93 | } | |||
94 | return failure(parser.parseOperandList(ops) || | |||
95 | parser.parseOptionalAttrDict(result.attributes) || | |||
96 | parser.parseColonType(type) || | |||
97 | parser.resolveOperands(ops, type, result.operands) || | |||
98 | parser.addTypeToList(type, result.types)); | |||
99 | } | |||
100 | ||||
101 | static void printOneResultOp(Operation *op, OpAsmPrinter &p) { | |||
102 | assert(op->getNumResults() == 1 && "op should have one result")(static_cast <bool> (op->getNumResults() == 1 && "op should have one result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"op should have one result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 102, __extension__ __PRETTY_FUNCTION__)); | |||
103 | ||||
104 | // If not all the operand and result types are the same, just use the | |||
105 | // generic assembly form to avoid omitting information in printing. | |||
106 | auto resultType = op->getResult(0).getType(); | |||
107 | if (llvm::any_of(op->getOperandTypes(), | |||
108 | [&](Type type) { return type != resultType; })) { | |||
109 | p.printGenericOp(op, /*printOpName=*/false); | |||
110 | return; | |||
111 | } | |||
112 | ||||
113 | p << ' '; | |||
114 | p.printOperands(op->getOperands()); | |||
115 | p.printOptionalAttrDict(op->getAttrs()); | |||
116 | // Now we can output only one type for all operands and the result. | |||
117 | p << " : " << resultType; | |||
118 | } | |||
119 | ||||
120 | /// Returns true if the given op is a function-like op or nested in a | |||
121 | /// function-like op without a module-like op in the middle. | |||
122 | static bool isNestedInFunctionOpInterface(Operation *op) { | |||
123 | if (!op) | |||
124 | return false; | |||
125 | if (op->hasTrait<OpTrait::SymbolTable>()) | |||
126 | return false; | |||
127 | if (isa<FunctionOpInterface>(op)) | |||
128 | return true; | |||
129 | return isNestedInFunctionOpInterface(op->getParentOp()); | |||
130 | } | |||
131 | ||||
132 | /// Returns true if the given op is an module-like op that maintains a symbol | |||
133 | /// table. | |||
134 | static bool isDirectInModuleLikeOp(Operation *op) { | |||
135 | return op && op->hasTrait<OpTrait::SymbolTable>(); | |||
136 | } | |||
137 | ||||
138 | static LogicalResult extractValueFromConstOp(Operation *op, int32_t &value) { | |||
139 | auto constOp = dyn_cast_or_null<spirv::ConstantOp>(op); | |||
140 | if (!constOp) { | |||
141 | return failure(); | |||
142 | } | |||
143 | auto valueAttr = constOp.getValue(); | |||
144 | auto integerValueAttr = valueAttr.dyn_cast<IntegerAttr>(); | |||
145 | if (!integerValueAttr) { | |||
146 | return failure(); | |||
147 | } | |||
148 | ||||
149 | if (integerValueAttr.getType().isSignlessInteger()) | |||
150 | value = integerValueAttr.getInt(); | |||
151 | else | |||
152 | value = integerValueAttr.getSInt(); | |||
153 | ||||
154 | return success(); | |||
155 | } | |||
156 | ||||
157 | template <typename Ty> | |||
158 | static ArrayAttr | |||
159 | getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues, | |||
160 | function_ref<StringRef(Ty)> stringifyFn) { | |||
161 | if (enumValues.empty()) { | |||
162 | return nullptr; | |||
163 | } | |||
164 | SmallVector<StringRef, 1> enumValStrs; | |||
165 | enumValStrs.reserve(enumValues.size()); | |||
166 | for (auto val : enumValues) { | |||
167 | enumValStrs.emplace_back(stringifyFn(val)); | |||
168 | } | |||
169 | return builder.getStrArrayAttr(enumValStrs); | |||
170 | } | |||
171 | ||||
172 | /// Parses the next string attribute in `parser` as an enumerant of the given | |||
173 | /// `EnumClass`. | |||
174 | template <typename EnumClass> | |||
175 | static ParseResult | |||
176 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, | |||
177 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
178 | Attribute attrVal; | |||
179 | NamedAttrList attr; | |||
180 | auto loc = parser.getCurrentLocation(); | |||
181 | if (parser.parseAttribute(attrVal, parser.getBuilder().getNoneType(), | |||
182 | attrName, attr)) | |||
183 | return failure(); | |||
184 | if (!attrVal.isa<StringAttr>()) | |||
185 | return parser.emitError(loc, "expected ") | |||
186 | << attrName << " attribute specified as string"; | |||
187 | auto attrOptional = | |||
188 | spirv::symbolizeEnum<EnumClass>(attrVal.cast<StringAttr>().getValue()); | |||
189 | if (!attrOptional) | |||
190 | return parser.emitError(loc, "invalid ") | |||
191 | << attrName << " attribute specification: " << attrVal; | |||
192 | value = *attrOptional; | |||
193 | return success(); | |||
194 | } | |||
195 | ||||
196 | /// Parses the next string attribute in `parser` as an enumerant of the given | |||
197 | /// `EnumClass` and inserts the enumerant into `state` as an 32-bit integer | |||
198 | /// attribute with the enum class's name as attribute name. | |||
199 | template <typename EnumAttrClass, | |||
200 | typename EnumClass = typename EnumAttrClass::ValueType> | |||
201 | static ParseResult | |||
202 | parseEnumStrAttr(EnumClass &value, OpAsmParser &parser, OperationState &state, | |||
203 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
204 | if (parseEnumStrAttr(value, parser)) | |||
205 | return failure(); | |||
206 | state.addAttribute(attrName, | |||
207 | parser.getBuilder().getAttr<EnumAttrClass>(value)); | |||
208 | return success(); | |||
209 | } | |||
210 | ||||
211 | /// Parses the next keyword in `parser` as an enumerant of the given `EnumClass` | |||
212 | /// and inserts the enumerant into `state` as an 32-bit integer attribute with | |||
213 | /// the enum class's name as attribute name. | |||
214 | template <typename EnumAttrClass, | |||
215 | typename EnumClass = typename EnumAttrClass::ValueType> | |||
216 | static ParseResult | |||
217 | parseEnumKeywordAttr(EnumClass &value, OpAsmParser &parser, | |||
218 | OperationState &state, | |||
219 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
220 | if (parseEnumKeywordAttr(value, parser)) | |||
221 | return failure(); | |||
222 | state.addAttribute(attrName, | |||
223 | parser.getBuilder().getAttr<EnumAttrClass>(value)); | |||
224 | return success(); | |||
225 | } | |||
226 | ||||
227 | /// Parses Function, Selection and Loop control attributes. If no control is | |||
228 | /// specified, "None" is used as a default. | |||
229 | template <typename EnumAttrClass, typename EnumClass> | |||
230 | static ParseResult | |||
231 | parseControlAttribute(OpAsmParser &parser, OperationState &state, | |||
232 | StringRef attrName = spirv::attributeName<EnumClass>()) { | |||
233 | if (succeeded(parser.parseOptionalKeyword(kControl))) { | |||
234 | EnumClass control; | |||
235 | if (parser.parseLParen() || | |||
236 | parseEnumKeywordAttr<EnumAttrClass>(control, parser, state) || | |||
237 | parser.parseRParen()) | |||
238 | return failure(); | |||
239 | return success(); | |||
240 | } | |||
241 | // Set control to "None" otherwise. | |||
242 | Builder builder = parser.getBuilder(); | |||
243 | state.addAttribute(attrName, | |||
244 | builder.getAttr<EnumAttrClass>(static_cast<EnumClass>(0))); | |||
245 | return success(); | |||
246 | } | |||
247 | ||||
248 | /// Parses optional memory access attributes attached to a memory access | |||
249 | /// operand/pointer. Specifically, parses the following syntax: | |||
250 | /// (`[` memory-access `]`)? | |||
251 | /// where: | |||
252 | /// memory-access ::= `"None"` | `"Volatile"` | `"Aligned", ` | |||
253 | /// integer-literal | `"NonTemporal"` | |||
254 | static ParseResult parseMemoryAccessAttributes(OpAsmParser &parser, | |||
255 | OperationState &state) { | |||
256 | // Parse an optional list of attributes staring with '[' | |||
257 | if (parser.parseOptionalLSquare()) { | |||
258 | // Nothing to do | |||
259 | return success(); | |||
260 | } | |||
261 | ||||
262 | spirv::MemoryAccess memoryAccessAttr; | |||
263 | if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state, | |||
264 | kMemoryAccessAttrName)) | |||
265 | return failure(); | |||
266 | ||||
267 | if (spirv::bitEnumContainsAll(memoryAccessAttr, | |||
268 | spirv::MemoryAccess::Aligned)) { | |||
269 | // Parse integer attribute for alignment. | |||
270 | Attribute alignmentAttr; | |||
271 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
272 | if (parser.parseComma() || | |||
273 | parser.parseAttribute(alignmentAttr, i32Type, kAlignmentAttrName, | |||
274 | state.attributes)) { | |||
275 | return failure(); | |||
276 | } | |||
277 | } | |||
278 | return parser.parseRSquare(); | |||
279 | } | |||
280 | ||||
281 | // TODO Make sure to merge this and the previous function into one template | |||
282 | // parameterized by memory access attribute name and alignment. Doing so now | |||
283 | // results in VS2017 in producing an internal error (at the call site) that's | |||
284 | // not detailed enough to understand what is happening. | |||
285 | static ParseResult parseSourceMemoryAccessAttributes(OpAsmParser &parser, | |||
286 | OperationState &state) { | |||
287 | // Parse an optional list of attributes staring with '[' | |||
288 | if (parser.parseOptionalLSquare()) { | |||
289 | // Nothing to do | |||
290 | return success(); | |||
291 | } | |||
292 | ||||
293 | spirv::MemoryAccess memoryAccessAttr; | |||
294 | if (parseEnumStrAttr<spirv::MemoryAccessAttr>(memoryAccessAttr, parser, state, | |||
295 | kSourceMemoryAccessAttrName)) | |||
296 | return failure(); | |||
297 | ||||
298 | if (spirv::bitEnumContainsAll(memoryAccessAttr, | |||
299 | spirv::MemoryAccess::Aligned)) { | |||
300 | // Parse integer attribute for alignment. | |||
301 | Attribute alignmentAttr; | |||
302 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
303 | if (parser.parseComma() || | |||
304 | parser.parseAttribute(alignmentAttr, i32Type, kSourceAlignmentAttrName, | |||
305 | state.attributes)) { | |||
306 | return failure(); | |||
307 | } | |||
308 | } | |||
309 | return parser.parseRSquare(); | |||
310 | } | |||
311 | ||||
312 | template <typename MemoryOpTy> | |||
313 | static void printMemoryAccessAttribute( | |||
314 | MemoryOpTy memoryOp, OpAsmPrinter &printer, | |||
315 | SmallVectorImpl<StringRef> &elidedAttrs, | |||
316 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, | |||
317 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { | |||
318 | // Print optional memory access attribute. | |||
319 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue | |||
320 | : memoryOp.getMemoryAccess())) { | |||
321 | elidedAttrs.push_back(kMemoryAccessAttrName); | |||
322 | ||||
323 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; | |||
324 | ||||
325 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { | |||
326 | // Print integer alignment attribute. | |||
327 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue | |||
328 | : memoryOp.getAlignment())) { | |||
329 | elidedAttrs.push_back(kAlignmentAttrName); | |||
330 | printer << ", " << *alignment; | |||
331 | } | |||
332 | } | |||
333 | printer << "]"; | |||
334 | } | |||
335 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); | |||
336 | } | |||
337 | ||||
338 | // TODO Make sure to merge this and the previous function into one template | |||
339 | // parameterized by memory access attribute name and alignment. Doing so now | |||
340 | // results in VS2017 in producing an internal error (at the call site) that's | |||
341 | // not detailed enough to understand what is happening. | |||
342 | template <typename MemoryOpTy> | |||
343 | static void printSourceMemoryAccessAttribute( | |||
344 | MemoryOpTy memoryOp, OpAsmPrinter &printer, | |||
345 | SmallVectorImpl<StringRef> &elidedAttrs, | |||
346 | std::optional<spirv::MemoryAccess> memoryAccessAtrrValue = std::nullopt, | |||
347 | std::optional<uint32_t> alignmentAttrValue = std::nullopt) { | |||
348 | ||||
349 | printer << ", "; | |||
350 | ||||
351 | // Print optional memory access attribute. | |||
352 | if (auto memAccess = (memoryAccessAtrrValue ? memoryAccessAtrrValue | |||
353 | : memoryOp.getMemoryAccess())) { | |||
354 | elidedAttrs.push_back(kSourceMemoryAccessAttrName); | |||
355 | ||||
356 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\""; | |||
357 | ||||
358 | if (spirv::bitEnumContainsAll(*memAccess, spirv::MemoryAccess::Aligned)) { | |||
359 | // Print integer alignment attribute. | |||
360 | if (auto alignment = (alignmentAttrValue ? alignmentAttrValue | |||
361 | : memoryOp.getAlignment())) { | |||
362 | elidedAttrs.push_back(kSourceAlignmentAttrName); | |||
363 | printer << ", " << *alignment; | |||
364 | } | |||
365 | } | |||
366 | printer << "]"; | |||
367 | } | |||
368 | elidedAttrs.push_back(spirv::attributeName<spirv::StorageClass>()); | |||
369 | } | |||
370 | ||||
371 | static ParseResult parseImageOperands(OpAsmParser &parser, | |||
372 | spirv::ImageOperandsAttr &attr) { | |||
373 | // Expect image operands | |||
374 | if (parser.parseOptionalLSquare()) | |||
375 | return success(); | |||
376 | ||||
377 | spirv::ImageOperands imageOperands; | |||
378 | if (parseEnumStrAttr(imageOperands, parser)) | |||
379 | return failure(); | |||
380 | ||||
381 | attr = spirv::ImageOperandsAttr::get(parser.getContext(), imageOperands); | |||
382 | ||||
383 | return parser.parseRSquare(); | |||
384 | } | |||
385 | ||||
386 | static void printImageOperands(OpAsmPrinter &printer, Operation *imageOp, | |||
387 | spirv::ImageOperandsAttr attr) { | |||
388 | if (attr) { | |||
389 | auto strImageOperands = stringifyImageOperands(attr.getValue()); | |||
390 | printer << "[\"" << strImageOperands << "\"]"; | |||
391 | } | |||
392 | } | |||
393 | ||||
394 | template <typename Op> | |||
395 | static LogicalResult verifyImageOperands(Op imageOp, | |||
396 | spirv::ImageOperandsAttr attr, | |||
397 | Operation::operand_range operands) { | |||
398 | if (!attr) { | |||
399 | if (operands.empty()) | |||
400 | return success(); | |||
401 | ||||
402 | return imageOp.emitError("the Image Operands should encode what operands " | |||
403 | "follow, as per Image Operands"); | |||
404 | } | |||
405 | ||||
406 | // TODO: Add the validation rules for the following Image Operands. | |||
407 | spirv::ImageOperands noSupportOperands = | |||
408 | spirv::ImageOperands::Bias | spirv::ImageOperands::Lod | | |||
409 | spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset | | |||
410 | spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets | | |||
411 | spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod | | |||
412 | spirv::ImageOperands::MakeTexelAvailable | | |||
413 | spirv::ImageOperands::MakeTexelVisible | | |||
414 | spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend; | |||
415 | ||||
416 | if (spirv::bitEnumContainsAll(attr.getValue(), noSupportOperands)) | |||
417 | llvm_unreachable("unimplemented operands of Image Operands")::llvm::llvm_unreachable_internal("unimplemented operands of Image Operands" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 417); | |||
418 | ||||
419 | return success(); | |||
420 | } | |||
421 | ||||
422 | static LogicalResult verifyCastOp(Operation *op, | |||
423 | bool requireSameBitWidth = true, | |||
424 | bool skipBitWidthCheck = false) { | |||
425 | // Some CastOps have no limit on bit widths for result and operand type. | |||
426 | if (skipBitWidthCheck) | |||
427 | return success(); | |||
428 | ||||
429 | Type operandType = op->getOperand(0).getType(); | |||
430 | Type resultType = op->getResult(0).getType(); | |||
431 | ||||
432 | // ODS checks that result type and operand type have the same shape. | |||
433 | if (auto vectorType = operandType.dyn_cast<VectorType>()) { | |||
434 | operandType = vectorType.getElementType(); | |||
435 | resultType = resultType.cast<VectorType>().getElementType(); | |||
436 | } | |||
437 | ||||
438 | if (auto coopMatrixType = | |||
439 | operandType.dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
440 | operandType = coopMatrixType.getElementType(); | |||
441 | resultType = | |||
442 | resultType.cast<spirv::CooperativeMatrixNVType>().getElementType(); | |||
443 | } | |||
444 | ||||
445 | if (auto jointMatrixType = | |||
446 | operandType.dyn_cast<spirv::JointMatrixINTELType>()) { | |||
447 | operandType = jointMatrixType.getElementType(); | |||
448 | resultType = | |||
449 | resultType.cast<spirv::JointMatrixINTELType>().getElementType(); | |||
450 | } | |||
451 | ||||
452 | auto operandTypeBitWidth = operandType.getIntOrFloatBitWidth(); | |||
453 | auto resultTypeBitWidth = resultType.getIntOrFloatBitWidth(); | |||
454 | auto isSameBitWidth = operandTypeBitWidth == resultTypeBitWidth; | |||
455 | ||||
456 | if (requireSameBitWidth) { | |||
457 | if (!isSameBitWidth) { | |||
458 | return op->emitOpError( | |||
459 | "expected the same bit widths for operand type and result " | |||
460 | "type, but provided ") | |||
461 | << operandType << " and " << resultType; | |||
462 | } | |||
463 | return success(); | |||
464 | } | |||
465 | ||||
466 | if (isSameBitWidth) { | |||
467 | return op->emitOpError( | |||
468 | "expected the different bit widths for operand type and result " | |||
469 | "type, but provided ") | |||
470 | << operandType << " and " << resultType; | |||
471 | } | |||
472 | return success(); | |||
473 | } | |||
474 | ||||
475 | template <typename MemoryOpTy> | |||
476 | static LogicalResult verifyMemoryAccessAttribute(MemoryOpTy memoryOp) { | |||
477 | // ODS checks for attributes values. Just need to verify that if the | |||
478 | // memory-access attribute is Aligned, then the alignment attribute must be | |||
479 | // present. | |||
480 | auto *op = memoryOp.getOperation(); | |||
481 | auto memAccessAttr = op->getAttr(kMemoryAccessAttrName); | |||
482 | if (!memAccessAttr) { | |||
483 | // Alignment attribute shouldn't be present if memory access attribute is | |||
484 | // not present. | |||
485 | if (op->getAttr(kAlignmentAttrName)) { | |||
486 | return memoryOp.emitOpError( | |||
487 | "invalid alignment specification without aligned memory access " | |||
488 | "specification"); | |||
489 | } | |||
490 | return success(); | |||
491 | } | |||
492 | ||||
493 | auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>(); | |||
494 | ||||
495 | if (!memAccess) { | |||
496 | return memoryOp.emitOpError("invalid memory access specifier: ") | |||
497 | << memAccessAttr; | |||
498 | } | |||
499 | ||||
500 | if (spirv::bitEnumContainsAll(memAccess.getValue(), | |||
501 | spirv::MemoryAccess::Aligned)) { | |||
502 | if (!op->getAttr(kAlignmentAttrName)) { | |||
503 | return memoryOp.emitOpError("missing alignment value"); | |||
504 | } | |||
505 | } else { | |||
506 | if (op->getAttr(kAlignmentAttrName)) { | |||
507 | return memoryOp.emitOpError( | |||
508 | "invalid alignment specification with non-aligned memory access " | |||
509 | "specification"); | |||
510 | } | |||
511 | } | |||
512 | return success(); | |||
513 | } | |||
514 | ||||
515 | // TODO Make sure to merge this and the previous function into one template | |||
516 | // parameterized by memory access attribute name and alignment. Doing so now | |||
517 | // results in VS2017 in producing an internal error (at the call site) that's | |||
518 | // not detailed enough to understand what is happening. | |||
519 | template <typename MemoryOpTy> | |||
520 | static LogicalResult verifySourceMemoryAccessAttribute(MemoryOpTy memoryOp) { | |||
521 | // ODS checks for attributes values. Just need to verify that if the | |||
522 | // memory-access attribute is Aligned, then the alignment attribute must be | |||
523 | // present. | |||
524 | auto *op = memoryOp.getOperation(); | |||
525 | auto memAccessAttr = op->getAttr(kSourceMemoryAccessAttrName); | |||
526 | if (!memAccessAttr) { | |||
527 | // Alignment attribute shouldn't be present if memory access attribute is | |||
528 | // not present. | |||
529 | if (op->getAttr(kSourceAlignmentAttrName)) { | |||
530 | return memoryOp.emitOpError( | |||
531 | "invalid alignment specification without aligned memory access " | |||
532 | "specification"); | |||
533 | } | |||
534 | return success(); | |||
535 | } | |||
536 | ||||
537 | auto memAccess = memAccessAttr.template cast<spirv::MemoryAccessAttr>(); | |||
538 | ||||
539 | if (!memAccess) { | |||
540 | return memoryOp.emitOpError("invalid memory access specifier: ") | |||
541 | << memAccess; | |||
542 | } | |||
543 | ||||
544 | if (spirv::bitEnumContainsAll(memAccess.getValue(), | |||
545 | spirv::MemoryAccess::Aligned)) { | |||
546 | if (!op->getAttr(kSourceAlignmentAttrName)) { | |||
547 | return memoryOp.emitOpError("missing alignment value"); | |||
548 | } | |||
549 | } else { | |||
550 | if (op->getAttr(kSourceAlignmentAttrName)) { | |||
551 | return memoryOp.emitOpError( | |||
552 | "invalid alignment specification with non-aligned memory access " | |||
553 | "specification"); | |||
554 | } | |||
555 | } | |||
556 | return success(); | |||
557 | } | |||
558 | ||||
559 | static LogicalResult | |||
560 | verifyMemorySemantics(Operation *op, spirv::MemorySemantics memorySemantics) { | |||
561 | // According to the SPIR-V specification: | |||
562 | // "Despite being a mask and allowing multiple bits to be combined, it is | |||
563 | // invalid for more than one of these four bits to be set: Acquire, Release, | |||
564 | // AcquireRelease, or SequentiallyConsistent. Requesting both Acquire and | |||
565 | // Release semantics is done by setting the AcquireRelease bit, not by setting | |||
566 | // two bits." | |||
567 | auto atMostOneInSet = spirv::MemorySemantics::Acquire | | |||
568 | spirv::MemorySemantics::Release | | |||
569 | spirv::MemorySemantics::AcquireRelease | | |||
570 | spirv::MemorySemantics::SequentiallyConsistent; | |||
571 | ||||
572 | auto bitCount = | |||
573 | llvm::popcount(static_cast<uint32_t>(memorySemantics & atMostOneInSet)); | |||
574 | if (bitCount > 1) { | |||
575 | return op->emitError( | |||
576 | "expected at most one of these four memory constraints " | |||
577 | "to be set: `Acquire`, `Release`," | |||
578 | "`AcquireRelease` or `SequentiallyConsistent`"); | |||
579 | } | |||
580 | return success(); | |||
581 | } | |||
582 | ||||
583 | template <typename LoadStoreOpTy> | |||
584 | static LogicalResult verifyLoadStorePtrAndValTypes(LoadStoreOpTy op, Value ptr, | |||
585 | Value val) { | |||
586 | // ODS already checks ptr is spirv::PointerType. Just check that the pointee | |||
587 | // type of the pointer and the type of the value are the same | |||
588 | // | |||
589 | // TODO: Check that the value type satisfies restrictions of | |||
590 | // SPIR-V OpLoad/OpStore operations | |||
591 | if (val.getType() != | |||
592 | ptr.getType().cast<spirv::PointerType>().getPointeeType()) { | |||
593 | return op.emitOpError("mismatch in result type and pointer type"); | |||
594 | } | |||
595 | return success(); | |||
596 | } | |||
597 | ||||
598 | template <typename BlockReadWriteOpTy> | |||
599 | static LogicalResult verifyBlockReadWritePtrAndValTypes(BlockReadWriteOpTy op, | |||
600 | Value ptr, Value val) { | |||
601 | auto valType = val.getType(); | |||
602 | if (auto valVecTy = valType.dyn_cast<VectorType>()) | |||
603 | valType = valVecTy.getElementType(); | |||
604 | ||||
605 | if (valType != ptr.getType().cast<spirv::PointerType>().getPointeeType()) { | |||
606 | return op.emitOpError("mismatch in result type and pointer type"); | |||
607 | } | |||
608 | return success(); | |||
609 | } | |||
610 | ||||
611 | static ParseResult parseVariableDecorations(OpAsmParser &parser, | |||
612 | OperationState &state) { | |||
613 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
614 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
615 | if (succeeded(parser.parseOptionalKeyword("bind"))) { | |||
616 | Attribute set, binding; | |||
617 | // Parse optional descriptor binding | |||
618 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
619 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
620 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
621 | stringifyDecoration(spirv::Decoration::Binding)); | |||
622 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
623 | if (parser.parseLParen() || | |||
624 | parser.parseAttribute(set, i32Type, descriptorSetName, | |||
625 | state.attributes) || | |||
626 | parser.parseComma() || | |||
627 | parser.parseAttribute(binding, i32Type, bindingName, | |||
628 | state.attributes) || | |||
629 | parser.parseRParen()) { | |||
630 | return failure(); | |||
631 | } | |||
632 | } else if (succeeded(parser.parseOptionalKeyword(builtInName))) { | |||
633 | StringAttr builtIn; | |||
634 | if (parser.parseLParen() || | |||
635 | parser.parseAttribute(builtIn, builtInName, state.attributes) || | |||
636 | parser.parseRParen()) { | |||
637 | return failure(); | |||
638 | } | |||
639 | } | |||
640 | ||||
641 | // Parse other attributes | |||
642 | if (parser.parseOptionalAttrDict(state.attributes)) | |||
643 | return failure(); | |||
644 | ||||
645 | return success(); | |||
646 | } | |||
647 | ||||
648 | static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, | |||
649 | SmallVectorImpl<StringRef> &elidedAttrs) { | |||
650 | // Print optional descriptor binding | |||
651 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
652 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
653 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
654 | stringifyDecoration(spirv::Decoration::Binding)); | |||
655 | auto descriptorSet = op->getAttrOfType<IntegerAttr>(descriptorSetName); | |||
656 | auto binding = op->getAttrOfType<IntegerAttr>(bindingName); | |||
657 | if (descriptorSet && binding) { | |||
658 | elidedAttrs.push_back(descriptorSetName); | |||
659 | elidedAttrs.push_back(bindingName); | |||
660 | printer << " bind(" << descriptorSet.getInt() << ", " << binding.getInt() | |||
661 | << ")"; | |||
662 | } | |||
663 | ||||
664 | // Print BuiltIn attribute if present | |||
665 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
666 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
667 | if (auto builtin = op->getAttrOfType<StringAttr>(builtInName)) { | |||
668 | printer << " " << builtInName << "(\"" << builtin.getValue() << "\")"; | |||
669 | elidedAttrs.push_back(builtInName); | |||
670 | } | |||
671 | ||||
672 | printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); | |||
673 | } | |||
674 | ||||
675 | // Get bit width of types. | |||
676 | static unsigned getBitWidth(Type type) { | |||
677 | if (type.isa<spirv::PointerType>()) { | |||
678 | // Just return 64 bits for pointer types for now. | |||
679 | // TODO: Make sure not caller relies on the actual pointer width value. | |||
680 | return 64; | |||
681 | } | |||
682 | ||||
683 | if (type.isIntOrFloat()) | |||
684 | return type.getIntOrFloatBitWidth(); | |||
685 | ||||
686 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
687 | assert(vectorType.getElementType().isIntOrFloat())(static_cast <bool> (vectorType.getElementType().isIntOrFloat ()) ? void (0) : __assert_fail ("vectorType.getElementType().isIntOrFloat()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 687, __extension__ __PRETTY_FUNCTION__)); | |||
688 | return vectorType.getNumElements() * | |||
689 | vectorType.getElementType().getIntOrFloatBitWidth(); | |||
690 | } | |||
691 | llvm_unreachable("unhandled bit width computation for type")::llvm::llvm_unreachable_internal("unhandled bit width computation for type" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 691); | |||
692 | } | |||
693 | ||||
694 | /// Walks the given type hierarchy with the given indices, potentially down | |||
695 | /// to component granularity, to select an element type. Returns null type and | |||
696 | /// emits errors with the given loc on failure. | |||
697 | static Type | |||
698 | getElementType(Type type, ArrayRef<int32_t> indices, | |||
699 | function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { | |||
700 | if (indices.empty()) { | |||
701 | emitErrorFn("expected at least one index for spirv.CompositeExtract"); | |||
702 | return nullptr; | |||
703 | } | |||
704 | ||||
705 | for (auto index : indices) { | |||
706 | if (auto cType = type.dyn_cast<spirv::CompositeType>()) { | |||
707 | if (cType.hasCompileTimeKnownNumElements() && | |||
708 | (index < 0 || | |||
709 | static_cast<uint64_t>(index) >= cType.getNumElements())) { | |||
710 | emitErrorFn("index ") << index << " out of bounds for " << type; | |||
711 | return nullptr; | |||
712 | } | |||
713 | type = cType.getElementType(index); | |||
714 | } else { | |||
715 | emitErrorFn("cannot extract from non-composite type ") | |||
716 | << type << " with index " << index; | |||
717 | return nullptr; | |||
718 | } | |||
719 | } | |||
720 | return type; | |||
721 | } | |||
722 | ||||
723 | static Type | |||
724 | getElementType(Type type, Attribute indices, | |||
725 | function_ref<InFlightDiagnostic(StringRef)> emitErrorFn) { | |||
726 | auto indicesArrayAttr = indices.dyn_cast<ArrayAttr>(); | |||
727 | if (!indicesArrayAttr) { | |||
728 | emitErrorFn("expected a 32-bit integer array attribute for 'indices'"); | |||
729 | return nullptr; | |||
730 | } | |||
731 | if (indicesArrayAttr.empty()) { | |||
732 | emitErrorFn("expected at least one index for spirv.CompositeExtract"); | |||
733 | return nullptr; | |||
734 | } | |||
735 | ||||
736 | SmallVector<int32_t, 2> indexVals; | |||
737 | for (auto indexAttr : indicesArrayAttr) { | |||
738 | auto indexIntAttr = indexAttr.dyn_cast<IntegerAttr>(); | |||
739 | if (!indexIntAttr) { | |||
740 | emitErrorFn("expected an 32-bit integer for index, but found '") | |||
741 | << indexAttr << "'"; | |||
742 | return nullptr; | |||
743 | } | |||
744 | indexVals.push_back(indexIntAttr.getInt()); | |||
745 | } | |||
746 | return getElementType(type, indexVals, emitErrorFn); | |||
747 | } | |||
748 | ||||
749 | static Type getElementType(Type type, Attribute indices, Location loc) { | |||
750 | auto errorFn = [&](StringRef err) -> InFlightDiagnostic { | |||
751 | return ::mlir::emitError(loc, err); | |||
752 | }; | |||
753 | return getElementType(type, indices, errorFn); | |||
754 | } | |||
755 | ||||
756 | static Type getElementType(Type type, Attribute indices, OpAsmParser &parser, | |||
757 | SMLoc loc) { | |||
758 | auto errorFn = [&](StringRef err) -> InFlightDiagnostic { | |||
759 | return parser.emitError(loc, err); | |||
760 | }; | |||
761 | return getElementType(type, indices, errorFn); | |||
762 | } | |||
763 | ||||
764 | /// Returns true if the given `block` only contains one `spirv.mlir.merge` op. | |||
765 | static inline bool isMergeBlock(Block &block) { | |||
766 | return !block.empty() && std::next(block.begin()) == block.end() && | |||
767 | isa<spirv::MergeOp>(block.front()); | |||
768 | } | |||
769 | ||||
770 | template <typename ExtendedBinaryOp> | |||
771 | static LogicalResult verifyArithmeticExtendedBinaryOp(ExtendedBinaryOp op) { | |||
772 | auto resultType = op.getType().template cast<spirv::StructType>(); | |||
773 | if (resultType.getNumElements() != 2) | |||
774 | return op.emitOpError("expected result struct type containing two members"); | |||
775 | ||||
776 | if (!llvm::all_equal({op.getOperand1().getType(), op.getOperand2().getType(), | |||
777 | resultType.getElementType(0), | |||
778 | resultType.getElementType(1)})) | |||
779 | return op.emitOpError( | |||
780 | "expected all operand types and struct member types are the same"); | |||
781 | ||||
782 | return success(); | |||
783 | } | |||
784 | ||||
785 | static ParseResult parseArithmeticExtendedBinaryOp(OpAsmParser &parser, | |||
786 | OperationState &result) { | |||
787 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; | |||
788 | if (parser.parseOptionalAttrDict(result.attributes) || | |||
789 | parser.parseOperandList(operands) || parser.parseColon()) | |||
790 | return failure(); | |||
791 | ||||
792 | Type resultType; | |||
793 | SMLoc loc = parser.getCurrentLocation(); | |||
794 | if (parser.parseType(resultType)) | |||
795 | return failure(); | |||
796 | ||||
797 | auto structType = resultType.dyn_cast<spirv::StructType>(); | |||
798 | if (!structType || structType.getNumElements() != 2) | |||
799 | return parser.emitError(loc, "expected spirv.struct type with two members"); | |||
800 | ||||
801 | SmallVector<Type, 2> operandTypes(2, structType.getElementType(0)); | |||
802 | if (parser.resolveOperands(operands, operandTypes, loc, result.operands)) | |||
803 | return failure(); | |||
804 | ||||
805 | result.addTypes(resultType); | |||
806 | return success(); | |||
807 | } | |||
808 | ||||
809 | static void printArithmeticExtendedBinaryOp(Operation *op, | |||
810 | OpAsmPrinter &printer) { | |||
811 | printer << ' '; | |||
812 | printer.printOptionalAttrDict(op->getAttrs()); | |||
813 | printer.printOperands(op->getOperands()); | |||
814 | printer << " : " << op->getResultTypes().front(); | |||
815 | } | |||
816 | ||||
817 | //===----------------------------------------------------------------------===// | |||
818 | // Common parsers and printers | |||
819 | //===----------------------------------------------------------------------===// | |||
820 | ||||
821 | // Parses an atomic update op. If the update op does not take a value (like | |||
822 | // AtomicIIncrement) `hasValue` must be false. | |||
823 | static ParseResult parseAtomicUpdateOp(OpAsmParser &parser, | |||
824 | OperationState &state, bool hasValue) { | |||
825 | spirv::Scope scope; | |||
826 | spirv::MemorySemantics memoryScope; | |||
827 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
828 | OpAsmParser::UnresolvedOperand ptrInfo, valueInfo; | |||
829 | Type type; | |||
830 | SMLoc loc; | |||
831 | if (parseEnumStrAttr<spirv::ScopeAttr>(scope, parser, state, | |||
832 | kMemoryScopeAttrName) || | |||
833 | parseEnumStrAttr<spirv::MemorySemanticsAttr>(memoryScope, parser, state, | |||
834 | kSemanticsAttrName) || | |||
835 | parser.parseOperandList(operandInfo, (hasValue ? 2 : 1)) || | |||
836 | parser.getCurrentLocation(&loc) || parser.parseColonType(type)) | |||
837 | return failure(); | |||
838 | ||||
839 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
840 | if (!ptrType) | |||
841 | return parser.emitError(loc, "expected pointer type"); | |||
842 | ||||
843 | SmallVector<Type, 2> operandTypes; | |||
844 | operandTypes.push_back(ptrType); | |||
845 | if (hasValue) | |||
846 | operandTypes.push_back(ptrType.getPointeeType()); | |||
847 | if (parser.resolveOperands(operandInfo, operandTypes, parser.getNameLoc(), | |||
848 | state.operands)) | |||
849 | return failure(); | |||
850 | return parser.addTypeToList(ptrType.getPointeeType(), state.types); | |||
851 | } | |||
852 | ||||
853 | // Prints an atomic update op. | |||
854 | static void printAtomicUpdateOp(Operation *op, OpAsmPrinter &printer) { | |||
855 | printer << " \""; | |||
856 | auto scopeAttr = op->getAttrOfType<spirv::ScopeAttr>(kMemoryScopeAttrName); | |||
857 | printer << spirv::stringifyScope(scopeAttr.getValue()) << "\" \""; | |||
858 | auto memorySemanticsAttr = | |||
859 | op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName); | |||
860 | printer << spirv::stringifyMemorySemantics(memorySemanticsAttr.getValue()) | |||
861 | << "\" " << op->getOperands() << " : " << op->getOperand(0).getType(); | |||
862 | } | |||
863 | ||||
864 | template <typename T> | |||
865 | static StringRef stringifyTypeName(); | |||
866 | ||||
867 | template <> | |||
868 | StringRef stringifyTypeName<IntegerType>() { | |||
869 | return "integer"; | |||
870 | } | |||
871 | ||||
872 | template <> | |||
873 | StringRef stringifyTypeName<FloatType>() { | |||
874 | return "float"; | |||
875 | } | |||
876 | ||||
877 | // Verifies an atomic update op. | |||
878 | template <typename ExpectedElementType> | |||
879 | static LogicalResult verifyAtomicUpdateOp(Operation *op) { | |||
880 | auto ptrType = op->getOperand(0).getType().cast<spirv::PointerType>(); | |||
881 | auto elementType = ptrType.getPointeeType(); | |||
882 | if (!elementType.isa<ExpectedElementType>()) | |||
883 | return op->emitOpError() << "pointer operand must point to an " | |||
884 | << stringifyTypeName<ExpectedElementType>() | |||
885 | << " value, found " << elementType; | |||
886 | ||||
887 | if (op->getNumOperands() > 1) { | |||
888 | auto valueType = op->getOperand(1).getType(); | |||
889 | if (valueType != elementType) | |||
890 | return op->emitOpError("expected value to have the same type as the " | |||
891 | "pointer operand's pointee type ") | |||
892 | << elementType << ", but found " << valueType; | |||
893 | } | |||
894 | auto memorySemantics = | |||
895 | op->getAttrOfType<spirv::MemorySemanticsAttr>(kSemanticsAttrName) | |||
896 | .getValue(); | |||
897 | if (failed(verifyMemorySemantics(op, memorySemantics))) { | |||
898 | return failure(); | |||
899 | } | |||
900 | return success(); | |||
901 | } | |||
902 | ||||
903 | static ParseResult parseGroupNonUniformArithmeticOp(OpAsmParser &parser, | |||
904 | OperationState &state) { | |||
905 | spirv::Scope executionScope; | |||
906 | spirv::GroupOperation groupOperation; | |||
907 | OpAsmParser::UnresolvedOperand valueInfo; | |||
908 | if (parseEnumStrAttr<spirv::ScopeAttr>(executionScope, parser, state, | |||
909 | kExecutionScopeAttrName) || | |||
910 | parseEnumStrAttr<spirv::GroupOperationAttr>(groupOperation, parser, state, | |||
911 | kGroupOperationAttrName) || | |||
912 | parser.parseOperand(valueInfo)) | |||
913 | return failure(); | |||
914 | ||||
915 | std::optional<OpAsmParser::UnresolvedOperand> clusterSizeInfo; | |||
916 | if (succeeded(parser.parseOptionalKeyword(kClusterSize))) { | |||
917 | clusterSizeInfo = OpAsmParser::UnresolvedOperand(); | |||
918 | if (parser.parseLParen() || parser.parseOperand(*clusterSizeInfo) || | |||
919 | parser.parseRParen()) | |||
920 | return failure(); | |||
921 | } | |||
922 | ||||
923 | Type resultType; | |||
924 | if (parser.parseColonType(resultType)) | |||
925 | return failure(); | |||
926 | ||||
927 | if (parser.resolveOperand(valueInfo, resultType, state.operands)) | |||
928 | return failure(); | |||
929 | ||||
930 | if (clusterSizeInfo) { | |||
931 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
932 | if (parser.resolveOperand(*clusterSizeInfo, i32Type, state.operands)) | |||
933 | return failure(); | |||
934 | } | |||
935 | ||||
936 | return parser.addTypeToList(resultType, state.types); | |||
937 | } | |||
938 | ||||
939 | static void printGroupNonUniformArithmeticOp(Operation *groupOp, | |||
940 | OpAsmPrinter &printer) { | |||
941 | printer | |||
942 | << " \"" | |||
943 | << stringifyScope( | |||
944 | groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName) | |||
945 | .getValue()) | |||
946 | << "\" \"" | |||
947 | << stringifyGroupOperation(groupOp | |||
948 | ->getAttrOfType<spirv::GroupOperationAttr>( | |||
949 | kGroupOperationAttrName) | |||
950 | .getValue()) | |||
951 | << "\" " << groupOp->getOperand(0); | |||
952 | ||||
953 | if (groupOp->getNumOperands() > 1) | |||
954 | printer << " " << kClusterSize << '(' << groupOp->getOperand(1) << ')'; | |||
955 | printer << " : " << groupOp->getResult(0).getType(); | |||
956 | } | |||
957 | ||||
958 | static LogicalResult verifyGroupNonUniformArithmeticOp(Operation *groupOp) { | |||
959 | spirv::Scope scope = | |||
960 | groupOp->getAttrOfType<spirv::ScopeAttr>(kExecutionScopeAttrName) | |||
961 | .getValue(); | |||
962 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
963 | return groupOp->emitOpError( | |||
964 | "execution scope must be 'Workgroup' or 'Subgroup'"); | |||
965 | ||||
966 | spirv::GroupOperation operation = | |||
967 | groupOp->getAttrOfType<spirv::GroupOperationAttr>(kGroupOperationAttrName) | |||
968 | .getValue(); | |||
969 | if (operation == spirv::GroupOperation::ClusteredReduce && | |||
970 | groupOp->getNumOperands() == 1) | |||
971 | return groupOp->emitOpError("cluster size operand must be provided for " | |||
972 | "'ClusteredReduce' group operation"); | |||
973 | if (groupOp->getNumOperands() > 1) { | |||
974 | Operation *sizeOp = groupOp->getOperand(1).getDefiningOp(); | |||
975 | int32_t clusterSize = 0; | |||
976 | ||||
977 | // TODO: support specialization constant here. | |||
978 | if (failed(extractValueFromConstOp(sizeOp, clusterSize))) | |||
979 | return groupOp->emitOpError( | |||
980 | "cluster size operand must come from a constant op"); | |||
981 | ||||
982 | if (!llvm::isPowerOf2_32(clusterSize)) | |||
983 | return groupOp->emitOpError( | |||
984 | "cluster size operand must be a power of two"); | |||
985 | } | |||
986 | return success(); | |||
987 | } | |||
988 | ||||
989 | /// Result of a logical op must be a scalar or vector of boolean type. | |||
990 | static Type getUnaryOpResultType(Type operandType) { | |||
991 | Builder builder(operandType.getContext()); | |||
992 | Type resultType = builder.getIntegerType(1); | |||
993 | if (auto vecType = operandType.dyn_cast<VectorType>()) | |||
994 | return VectorType::get(vecType.getNumElements(), resultType); | |||
995 | return resultType; | |||
996 | } | |||
997 | ||||
998 | static LogicalResult verifyShiftOp(Operation *op) { | |||
999 | if (op->getOperand(0).getType() != op->getResult(0).getType()) { | |||
1000 | return op->emitError("expected the same type for the first operand and " | |||
1001 | "result, but provided ") | |||
1002 | << op->getOperand(0).getType() << " and " | |||
1003 | << op->getResult(0).getType(); | |||
1004 | } | |||
1005 | return success(); | |||
1006 | } | |||
1007 | ||||
1008 | //===----------------------------------------------------------------------===// | |||
1009 | // spirv.AccessChainOp | |||
1010 | //===----------------------------------------------------------------------===// | |||
1011 | ||||
1012 | static Type getElementPtrType(Type type, ValueRange indices, Location baseLoc) { | |||
1013 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1014 | if (!ptrType) { | |||
1015 | emitError(baseLoc, "'spirv.AccessChain' op expected a pointer " | |||
1016 | "to composite type, but provided ") | |||
1017 | << type; | |||
1018 | return nullptr; | |||
1019 | } | |||
1020 | ||||
1021 | auto resultType = ptrType.getPointeeType(); | |||
1022 | auto resultStorageClass = ptrType.getStorageClass(); | |||
1023 | int32_t index = 0; | |||
1024 | ||||
1025 | for (auto indexSSA : indices) { | |||
1026 | auto cType = resultType.dyn_cast<spirv::CompositeType>(); | |||
1027 | if (!cType) { | |||
1028 | emitError( | |||
1029 | baseLoc, | |||
1030 | "'spirv.AccessChain' op cannot extract from non-composite type ") | |||
1031 | << resultType << " with index " << index; | |||
1032 | return nullptr; | |||
1033 | } | |||
1034 | index = 0; | |||
1035 | if (resultType.isa<spirv::StructType>()) { | |||
1036 | Operation *op = indexSSA.getDefiningOp(); | |||
1037 | if (!op) { | |||
1038 | emitError(baseLoc, "'spirv.AccessChain' op index must be an " | |||
1039 | "integer spirv.Constant to access " | |||
1040 | "element of spirv.struct"); | |||
1041 | return nullptr; | |||
1042 | } | |||
1043 | ||||
1044 | // TODO: this should be relaxed to allow | |||
1045 | // integer literals of other bitwidths. | |||
1046 | if (failed(extractValueFromConstOp(op, index))) { | |||
1047 | emitError( | |||
1048 | baseLoc, | |||
1049 | "'spirv.AccessChain' index must be an integer spirv.Constant to " | |||
1050 | "access element of spirv.struct, but provided ") | |||
1051 | << op->getName(); | |||
1052 | return nullptr; | |||
1053 | } | |||
1054 | if (index < 0 || static_cast<uint64_t>(index) >= cType.getNumElements()) { | |||
1055 | emitError(baseLoc, "'spirv.AccessChain' op index ") | |||
1056 | << index << " out of bounds for " << resultType; | |||
1057 | return nullptr; | |||
1058 | } | |||
1059 | } | |||
1060 | resultType = cType.getElementType(index); | |||
1061 | } | |||
1062 | return spirv::PointerType::get(resultType, resultStorageClass); | |||
1063 | } | |||
1064 | ||||
1065 | void spirv::AccessChainOp::build(OpBuilder &builder, OperationState &state, | |||
1066 | Value basePtr, ValueRange indices) { | |||
1067 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
1068 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1068, __extension__ __PRETTY_FUNCTION__)); | |||
1069 | build(builder, state, type, basePtr, indices); | |||
1070 | } | |||
1071 | ||||
1072 | ParseResult spirv::AccessChainOp::parse(OpAsmParser &parser, | |||
1073 | OperationState &result) { | |||
1074 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
1075 | SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo; | |||
1076 | Type type; | |||
1077 | auto loc = parser.getCurrentLocation(); | |||
1078 | SmallVector<Type, 4> indicesTypes; | |||
1079 | ||||
1080 | if (parser.parseOperand(ptrInfo) || | |||
1081 | parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || | |||
1082 | parser.parseColonType(type) || | |||
1083 | parser.resolveOperand(ptrInfo, type, result.operands)) { | |||
1084 | return failure(); | |||
1085 | } | |||
1086 | ||||
1087 | // Check that the provided indices list is not empty before parsing their | |||
1088 | // type list. | |||
1089 | if (indicesInfo.empty()) { | |||
1090 | return mlir::emitError(result.location, | |||
1091 | "'spirv.AccessChain' op expected at " | |||
1092 | "least one index "); | |||
1093 | } | |||
1094 | ||||
1095 | if (parser.parseComma() || parser.parseTypeList(indicesTypes)) | |||
1096 | return failure(); | |||
1097 | ||||
1098 | // Check that the indices types list is not empty and that it has a one-to-one | |||
1099 | // mapping to the provided indices. | |||
1100 | if (indicesTypes.size() != indicesInfo.size()) { | |||
1101 | return mlir::emitError( | |||
1102 | result.location, "'spirv.AccessChain' op indices types' count must be " | |||
1103 | "equal to indices info count"); | |||
1104 | } | |||
1105 | ||||
1106 | if (parser.resolveOperands(indicesInfo, indicesTypes, loc, result.operands)) | |||
1107 | return failure(); | |||
1108 | ||||
1109 | auto resultType = getElementPtrType( | |||
1110 | type, llvm::ArrayRef(result.operands).drop_front(), result.location); | |||
1111 | if (!resultType) { | |||
1112 | return failure(); | |||
1113 | } | |||
1114 | ||||
1115 | result.addTypes(resultType); | |||
1116 | return success(); | |||
1117 | } | |||
1118 | ||||
1119 | template <typename Op> | |||
1120 | static void printAccessChain(Op op, ValueRange indices, OpAsmPrinter &printer) { | |||
1121 | printer << ' ' << op.getBasePtr() << '[' << indices | |||
1122 | << "] : " << op.getBasePtr().getType() << ", " << indices.getTypes(); | |||
1123 | } | |||
1124 | ||||
1125 | void spirv::AccessChainOp::print(OpAsmPrinter &printer) { | |||
1126 | printAccessChain(*this, getIndices(), printer); | |||
1127 | } | |||
1128 | ||||
1129 | template <typename Op> | |||
1130 | static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { | |||
1131 | auto resultType = getElementPtrType(accessChainOp.getBasePtr().getType(), | |||
1132 | indices, accessChainOp.getLoc()); | |||
1133 | if (!resultType) | |||
1134 | return failure(); | |||
1135 | ||||
1136 | auto providedResultType = | |||
1137 | accessChainOp.getType().template dyn_cast<spirv::PointerType>(); | |||
1138 | if (!providedResultType) | |||
1139 | return accessChainOp.emitOpError( | |||
1140 | "result type must be a pointer, but provided") | |||
1141 | << providedResultType; | |||
1142 | ||||
1143 | if (resultType != providedResultType) | |||
1144 | return accessChainOp.emitOpError("invalid result type: expected ") | |||
1145 | << resultType << ", but provided " << providedResultType; | |||
1146 | ||||
1147 | return success(); | |||
1148 | } | |||
1149 | ||||
1150 | LogicalResult spirv::AccessChainOp::verify() { | |||
1151 | return verifyAccessChain(*this, getIndices()); | |||
1152 | } | |||
1153 | ||||
1154 | //===----------------------------------------------------------------------===// | |||
1155 | // spirv.mlir.addressof | |||
1156 | //===----------------------------------------------------------------------===// | |||
1157 | ||||
1158 | void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, | |||
1159 | spirv::GlobalVariableOp var) { | |||
1160 | build(builder, state, var.getType(), SymbolRefAttr::get(var)); | |||
1161 | } | |||
1162 | ||||
1163 | LogicalResult spirv::AddressOfOp::verify() { | |||
1164 | auto varOp = dyn_cast_or_null<spirv::GlobalVariableOp>( | |||
1165 | SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), | |||
1166 | getVariableAttr())); | |||
1167 | if (!varOp) { | |||
1168 | return emitOpError("expected spirv.GlobalVariable symbol"); | |||
1169 | } | |||
1170 | if (getPointer().getType() != varOp.getType()) { | |||
1171 | return emitOpError( | |||
1172 | "result type mismatch with the referenced global variable's type"); | |||
1173 | } | |||
1174 | return success(); | |||
1175 | } | |||
1176 | ||||
1177 | template <typename T> | |||
1178 | static void printAtomicCompareExchangeImpl(T atomOp, OpAsmPrinter &printer) { | |||
1179 | printer << " \"" << stringifyScope(atomOp.getMemoryScope()) << "\" \"" | |||
1180 | << stringifyMemorySemantics(atomOp.getEqualSemantics()) << "\" \"" | |||
1181 | << stringifyMemorySemantics(atomOp.getUnequalSemantics()) << "\" " | |||
1182 | << atomOp.getOperands() << " : " << atomOp.getPointer().getType(); | |||
1183 | } | |||
1184 | ||||
1185 | static ParseResult parseAtomicCompareExchangeImpl(OpAsmParser &parser, | |||
1186 | OperationState &state) { | |||
1187 | spirv::Scope memoryScope; | |||
1188 | spirv::MemorySemantics equalSemantics, unequalSemantics; | |||
1189 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; | |||
1190 | Type type; | |||
1191 | if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, state, | |||
1192 | kMemoryScopeAttrName) || | |||
1193 | parseEnumStrAttr<spirv::MemorySemanticsAttr>( | |||
1194 | equalSemantics, parser, state, kEqualSemanticsAttrName) || | |||
1195 | parseEnumStrAttr<spirv::MemorySemanticsAttr>( | |||
1196 | unequalSemantics, parser, state, kUnequalSemanticsAttrName) || | |||
1197 | parser.parseOperandList(operandInfo, 3)) | |||
1198 | return failure(); | |||
1199 | ||||
1200 | auto loc = parser.getCurrentLocation(); | |||
1201 | if (parser.parseColonType(type)) | |||
1202 | return failure(); | |||
1203 | ||||
1204 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1205 | if (!ptrType) | |||
1206 | return parser.emitError(loc, "expected pointer type"); | |||
1207 | ||||
1208 | if (parser.resolveOperands( | |||
1209 | operandInfo, | |||
1210 | {ptrType, ptrType.getPointeeType(), ptrType.getPointeeType()}, | |||
1211 | parser.getNameLoc(), state.operands)) | |||
1212 | return failure(); | |||
1213 | ||||
1214 | return parser.addTypeToList(ptrType.getPointeeType(), state.types); | |||
1215 | } | |||
1216 | ||||
1217 | template <typename T> | |||
1218 | static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { | |||
1219 | // According to the spec: | |||
1220 | // "The type of Value must be the same as Result Type. The type of the value | |||
1221 | // pointed to by Pointer must be the same as Result Type. This type must also | |||
1222 | // match the type of Comparator." | |||
1223 | if (atomOp.getType() != atomOp.getValue().getType()) | |||
1224 | return atomOp.emitOpError("value operand must have the same type as the op " | |||
1225 | "result, but found ") | |||
1226 | << atomOp.getValue().getType() << " vs " << atomOp.getType(); | |||
1227 | ||||
1228 | if (atomOp.getType() != atomOp.getComparator().getType()) | |||
1229 | return atomOp.emitOpError( | |||
1230 | "comparator operand must have the same type as the op " | |||
1231 | "result, but found ") | |||
1232 | << atomOp.getComparator().getType() << " vs " << atomOp.getType(); | |||
1233 | ||||
1234 | Type pointeeType = atomOp.getPointer() | |||
1235 | .getType() | |||
1236 | .template cast<spirv::PointerType>() | |||
1237 | .getPointeeType(); | |||
1238 | if (atomOp.getType() != pointeeType) | |||
1239 | return atomOp.emitOpError( | |||
1240 | "pointer operand's pointee type must have the same " | |||
1241 | "as the op result type, but found ") | |||
1242 | << pointeeType << " vs " << atomOp.getType(); | |||
1243 | ||||
1244 | // TODO: Unequal cannot be set to Release or Acquire and Release. | |||
1245 | // In addition, Unequal cannot be set to a stronger memory-order then Equal. | |||
1246 | ||||
1247 | return success(); | |||
1248 | } | |||
1249 | ||||
1250 | //===----------------------------------------------------------------------===// | |||
1251 | // spirv.AtomicAndOp | |||
1252 | //===----------------------------------------------------------------------===// | |||
1253 | ||||
1254 | LogicalResult spirv::AtomicAndOp::verify() { | |||
1255 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1256 | } | |||
1257 | ||||
1258 | ParseResult spirv::AtomicAndOp::parse(OpAsmParser &parser, | |||
1259 | OperationState &result) { | |||
1260 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1261 | } | |||
1262 | void spirv::AtomicAndOp::print(OpAsmPrinter &p) { | |||
1263 | ::printAtomicUpdateOp(*this, p); | |||
1264 | } | |||
1265 | ||||
1266 | //===----------------------------------------------------------------------===// | |||
1267 | // spirv.AtomicCompareExchangeOp | |||
1268 | //===----------------------------------------------------------------------===// | |||
1269 | ||||
1270 | LogicalResult spirv::AtomicCompareExchangeOp::verify() { | |||
1271 | return ::verifyAtomicCompareExchangeImpl(*this); | |||
1272 | } | |||
1273 | ||||
1274 | ParseResult spirv::AtomicCompareExchangeOp::parse(OpAsmParser &parser, | |||
1275 | OperationState &result) { | |||
1276 | return ::parseAtomicCompareExchangeImpl(parser, result); | |||
1277 | } | |||
1278 | void spirv::AtomicCompareExchangeOp::print(OpAsmPrinter &p) { | |||
1279 | ::printAtomicCompareExchangeImpl(*this, p); | |||
1280 | } | |||
1281 | ||||
1282 | //===----------------------------------------------------------------------===// | |||
1283 | // spirv.AtomicCompareExchangeWeakOp | |||
1284 | //===----------------------------------------------------------------------===// | |||
1285 | ||||
1286 | LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { | |||
1287 | return ::verifyAtomicCompareExchangeImpl(*this); | |||
1288 | } | |||
1289 | ||||
1290 | ParseResult spirv::AtomicCompareExchangeWeakOp::parse(OpAsmParser &parser, | |||
1291 | OperationState &result) { | |||
1292 | return ::parseAtomicCompareExchangeImpl(parser, result); | |||
1293 | } | |||
1294 | void spirv::AtomicCompareExchangeWeakOp::print(OpAsmPrinter &p) { | |||
1295 | ::printAtomicCompareExchangeImpl(*this, p); | |||
1296 | } | |||
1297 | ||||
1298 | //===----------------------------------------------------------------------===// | |||
1299 | // spirv.AtomicExchange | |||
1300 | //===----------------------------------------------------------------------===// | |||
1301 | ||||
1302 | void spirv::AtomicExchangeOp::print(OpAsmPrinter &printer) { | |||
1303 | printer << " \"" << stringifyScope(getMemoryScope()) << "\" \"" | |||
1304 | << stringifyMemorySemantics(getSemantics()) << "\" " << getOperands() | |||
1305 | << " : " << getPointer().getType(); | |||
1306 | } | |||
1307 | ||||
1308 | ParseResult spirv::AtomicExchangeOp::parse(OpAsmParser &parser, | |||
1309 | OperationState &result) { | |||
1310 | spirv::Scope memoryScope; | |||
1311 | spirv::MemorySemantics semantics; | |||
1312 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
1313 | Type type; | |||
1314 | if (parseEnumStrAttr<spirv::ScopeAttr>(memoryScope, parser, result, | |||
1315 | kMemoryScopeAttrName) || | |||
1316 | parseEnumStrAttr<spirv::MemorySemanticsAttr>(semantics, parser, result, | |||
1317 | kSemanticsAttrName) || | |||
1318 | parser.parseOperandList(operandInfo, 2)) | |||
1319 | return failure(); | |||
1320 | ||||
1321 | auto loc = parser.getCurrentLocation(); | |||
1322 | if (parser.parseColonType(type)) | |||
1323 | return failure(); | |||
1324 | ||||
1325 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
1326 | if (!ptrType) | |||
1327 | return parser.emitError(loc, "expected pointer type"); | |||
1328 | ||||
1329 | if (parser.resolveOperands(operandInfo, {ptrType, ptrType.getPointeeType()}, | |||
1330 | parser.getNameLoc(), result.operands)) | |||
1331 | return failure(); | |||
1332 | ||||
1333 | return parser.addTypeToList(ptrType.getPointeeType(), result.types); | |||
1334 | } | |||
1335 | ||||
1336 | LogicalResult spirv::AtomicExchangeOp::verify() { | |||
1337 | if (getType() != getValue().getType()) | |||
1338 | return emitOpError("value operand must have the same type as the op " | |||
1339 | "result, but found ") | |||
1340 | << getValue().getType() << " vs " << getType(); | |||
1341 | ||||
1342 | Type pointeeType = | |||
1343 | getPointer().getType().cast<spirv::PointerType>().getPointeeType(); | |||
1344 | if (getType() != pointeeType) | |||
1345 | return emitOpError("pointer operand's pointee type must have the same " | |||
1346 | "as the op result type, but found ") | |||
1347 | << pointeeType << " vs " << getType(); | |||
1348 | ||||
1349 | return success(); | |||
1350 | } | |||
1351 | ||||
1352 | //===----------------------------------------------------------------------===// | |||
1353 | // spirv.AtomicIAddOp | |||
1354 | //===----------------------------------------------------------------------===// | |||
1355 | ||||
1356 | LogicalResult spirv::AtomicIAddOp::verify() { | |||
1357 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1358 | } | |||
1359 | ||||
1360 | ParseResult spirv::AtomicIAddOp::parse(OpAsmParser &parser, | |||
1361 | OperationState &result) { | |||
1362 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1363 | } | |||
1364 | void spirv::AtomicIAddOp::print(OpAsmPrinter &p) { | |||
1365 | ::printAtomicUpdateOp(*this, p); | |||
1366 | } | |||
1367 | ||||
1368 | //===----------------------------------------------------------------------===// | |||
1369 | // spirv.EXT.AtomicFAddOp | |||
1370 | //===----------------------------------------------------------------------===// | |||
1371 | ||||
1372 | LogicalResult spirv::EXTAtomicFAddOp::verify() { | |||
1373 | return ::verifyAtomicUpdateOp<FloatType>(getOperation()); | |||
1374 | } | |||
1375 | ||||
1376 | ParseResult spirv::EXTAtomicFAddOp::parse(OpAsmParser &parser, | |||
1377 | OperationState &result) { | |||
1378 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1379 | } | |||
1380 | void spirv::EXTAtomicFAddOp::print(OpAsmPrinter &p) { | |||
1381 | ::printAtomicUpdateOp(*this, p); | |||
1382 | } | |||
1383 | ||||
1384 | //===----------------------------------------------------------------------===// | |||
1385 | // spirv.AtomicIDecrementOp | |||
1386 | //===----------------------------------------------------------------------===// | |||
1387 | ||||
1388 | LogicalResult spirv::AtomicIDecrementOp::verify() { | |||
1389 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1390 | } | |||
1391 | ||||
1392 | ParseResult spirv::AtomicIDecrementOp::parse(OpAsmParser &parser, | |||
1393 | OperationState &result) { | |||
1394 | return ::parseAtomicUpdateOp(parser, result, false); | |||
1395 | } | |||
1396 | void spirv::AtomicIDecrementOp::print(OpAsmPrinter &p) { | |||
1397 | ::printAtomicUpdateOp(*this, p); | |||
1398 | } | |||
1399 | ||||
1400 | //===----------------------------------------------------------------------===// | |||
1401 | // spirv.AtomicIIncrementOp | |||
1402 | //===----------------------------------------------------------------------===// | |||
1403 | ||||
1404 | LogicalResult spirv::AtomicIIncrementOp::verify() { | |||
1405 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1406 | } | |||
1407 | ||||
1408 | ParseResult spirv::AtomicIIncrementOp::parse(OpAsmParser &parser, | |||
1409 | OperationState &result) { | |||
1410 | return ::parseAtomicUpdateOp(parser, result, false); | |||
1411 | } | |||
1412 | void spirv::AtomicIIncrementOp::print(OpAsmPrinter &p) { | |||
1413 | ::printAtomicUpdateOp(*this, p); | |||
1414 | } | |||
1415 | ||||
1416 | //===----------------------------------------------------------------------===// | |||
1417 | // spirv.AtomicISubOp | |||
1418 | //===----------------------------------------------------------------------===// | |||
1419 | ||||
1420 | LogicalResult spirv::AtomicISubOp::verify() { | |||
1421 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1422 | } | |||
1423 | ||||
1424 | ParseResult spirv::AtomicISubOp::parse(OpAsmParser &parser, | |||
1425 | OperationState &result) { | |||
1426 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1427 | } | |||
1428 | void spirv::AtomicISubOp::print(OpAsmPrinter &p) { | |||
1429 | ::printAtomicUpdateOp(*this, p); | |||
1430 | } | |||
1431 | ||||
1432 | //===----------------------------------------------------------------------===// | |||
1433 | // spirv.AtomicOrOp | |||
1434 | //===----------------------------------------------------------------------===// | |||
1435 | ||||
1436 | LogicalResult spirv::AtomicOrOp::verify() { | |||
1437 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1438 | } | |||
1439 | ||||
1440 | ParseResult spirv::AtomicOrOp::parse(OpAsmParser &parser, | |||
1441 | OperationState &result) { | |||
1442 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1443 | } | |||
1444 | void spirv::AtomicOrOp::print(OpAsmPrinter &p) { | |||
1445 | ::printAtomicUpdateOp(*this, p); | |||
1446 | } | |||
1447 | ||||
1448 | //===----------------------------------------------------------------------===// | |||
1449 | // spirv.AtomicSMaxOp | |||
1450 | //===----------------------------------------------------------------------===// | |||
1451 | ||||
1452 | LogicalResult spirv::AtomicSMaxOp::verify() { | |||
1453 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1454 | } | |||
1455 | ||||
1456 | ParseResult spirv::AtomicSMaxOp::parse(OpAsmParser &parser, | |||
1457 | OperationState &result) { | |||
1458 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1459 | } | |||
1460 | void spirv::AtomicSMaxOp::print(OpAsmPrinter &p) { | |||
1461 | ::printAtomicUpdateOp(*this, p); | |||
1462 | } | |||
1463 | ||||
1464 | //===----------------------------------------------------------------------===// | |||
1465 | // spirv.AtomicSMinOp | |||
1466 | //===----------------------------------------------------------------------===// | |||
1467 | ||||
1468 | LogicalResult spirv::AtomicSMinOp::verify() { | |||
1469 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1470 | } | |||
1471 | ||||
1472 | ParseResult spirv::AtomicSMinOp::parse(OpAsmParser &parser, | |||
1473 | OperationState &result) { | |||
1474 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1475 | } | |||
1476 | void spirv::AtomicSMinOp::print(OpAsmPrinter &p) { | |||
1477 | ::printAtomicUpdateOp(*this, p); | |||
1478 | } | |||
1479 | ||||
1480 | //===----------------------------------------------------------------------===// | |||
1481 | // spirv.AtomicUMaxOp | |||
1482 | //===----------------------------------------------------------------------===// | |||
1483 | ||||
1484 | LogicalResult spirv::AtomicUMaxOp::verify() { | |||
1485 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1486 | } | |||
1487 | ||||
1488 | ParseResult spirv::AtomicUMaxOp::parse(OpAsmParser &parser, | |||
1489 | OperationState &result) { | |||
1490 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1491 | } | |||
1492 | void spirv::AtomicUMaxOp::print(OpAsmPrinter &p) { | |||
1493 | ::printAtomicUpdateOp(*this, p); | |||
1494 | } | |||
1495 | ||||
1496 | //===----------------------------------------------------------------------===// | |||
1497 | // spirv.AtomicUMinOp | |||
1498 | //===----------------------------------------------------------------------===// | |||
1499 | ||||
1500 | LogicalResult spirv::AtomicUMinOp::verify() { | |||
1501 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1502 | } | |||
1503 | ||||
1504 | ParseResult spirv::AtomicUMinOp::parse(OpAsmParser &parser, | |||
1505 | OperationState &result) { | |||
1506 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1507 | } | |||
1508 | void spirv::AtomicUMinOp::print(OpAsmPrinter &p) { | |||
1509 | ::printAtomicUpdateOp(*this, p); | |||
1510 | } | |||
1511 | ||||
1512 | //===----------------------------------------------------------------------===// | |||
1513 | // spirv.AtomicXorOp | |||
1514 | //===----------------------------------------------------------------------===// | |||
1515 | ||||
1516 | LogicalResult spirv::AtomicXorOp::verify() { | |||
1517 | return ::verifyAtomicUpdateOp<IntegerType>(getOperation()); | |||
1518 | } | |||
1519 | ||||
1520 | ParseResult spirv::AtomicXorOp::parse(OpAsmParser &parser, | |||
1521 | OperationState &result) { | |||
1522 | return ::parseAtomicUpdateOp(parser, result, true); | |||
1523 | } | |||
1524 | void spirv::AtomicXorOp::print(OpAsmPrinter &p) { | |||
1525 | ::printAtomicUpdateOp(*this, p); | |||
1526 | } | |||
1527 | ||||
1528 | //===----------------------------------------------------------------------===// | |||
1529 | // spirv.BitcastOp | |||
1530 | //===----------------------------------------------------------------------===// | |||
1531 | ||||
1532 | LogicalResult spirv::BitcastOp::verify() { | |||
1533 | // TODO: The SPIR-V spec validation rules are different for different | |||
1534 | // versions. | |||
1535 | auto operandType = getOperand().getType(); | |||
1536 | auto resultType = getResult().getType(); | |||
1537 | if (operandType == resultType) { | |||
1538 | return emitError("result type must be different from operand type"); | |||
1539 | } | |||
1540 | if (operandType.isa<spirv::PointerType>() && | |||
1541 | !resultType.isa<spirv::PointerType>()) { | |||
1542 | return emitError( | |||
1543 | "unhandled bit cast conversion from pointer type to non-pointer type"); | |||
1544 | } | |||
1545 | if (!operandType.isa<spirv::PointerType>() && | |||
1546 | resultType.isa<spirv::PointerType>()) { | |||
1547 | return emitError( | |||
1548 | "unhandled bit cast conversion from non-pointer type to pointer type"); | |||
1549 | } | |||
1550 | auto operandBitWidth = getBitWidth(operandType); | |||
1551 | auto resultBitWidth = getBitWidth(resultType); | |||
1552 | if (operandBitWidth != resultBitWidth) { | |||
1553 | return emitOpError("mismatch in result type bitwidth ") | |||
1554 | << resultBitWidth << " and operand type bitwidth " | |||
1555 | << operandBitWidth; | |||
1556 | } | |||
1557 | return success(); | |||
1558 | } | |||
1559 | ||||
1560 | //===----------------------------------------------------------------------===// | |||
1561 | // spirv.PtrCastToGenericOp | |||
1562 | //===----------------------------------------------------------------------===// | |||
1563 | ||||
1564 | LogicalResult spirv::PtrCastToGenericOp::verify() { | |||
1565 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1566 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1567 | ||||
1568 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1569 | if (operandStorage != spirv::StorageClass::Workgroup && | |||
1570 | operandStorage != spirv::StorageClass::CrossWorkgroup && | |||
1571 | operandStorage != spirv::StorageClass::Function) | |||
1572 | return emitError("pointer must point to the Workgroup, CrossWorkgroup" | |||
1573 | ", or Function Storage Class"); | |||
1574 | ||||
1575 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1576 | if (resultStorage != spirv::StorageClass::Generic) | |||
1577 | return emitError("result type must be of storage class Generic"); | |||
1578 | ||||
1579 | Type operandPointeeType = operandType.getPointeeType(); | |||
1580 | Type resultPointeeType = resultType.getPointeeType(); | |||
1581 | if (operandPointeeType != resultPointeeType) | |||
1582 | return emitOpError("pointer operand's pointee type must have the same " | |||
1583 | "as the op result type, but found ") | |||
1584 | << operandPointeeType << " vs " << resultPointeeType; | |||
1585 | return success(); | |||
1586 | } | |||
1587 | ||||
1588 | //===----------------------------------------------------------------------===// | |||
1589 | // spirv.GenericCastToPtrOp | |||
1590 | //===----------------------------------------------------------------------===// | |||
1591 | ||||
1592 | LogicalResult spirv::GenericCastToPtrOp::verify() { | |||
1593 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1594 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1595 | ||||
1596 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1597 | if (operandStorage != spirv::StorageClass::Generic) | |||
1598 | return emitError("pointer type must be of storage class Generic"); | |||
1599 | ||||
1600 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1601 | if (resultStorage != spirv::StorageClass::Workgroup && | |||
1602 | resultStorage != spirv::StorageClass::CrossWorkgroup && | |||
1603 | resultStorage != spirv::StorageClass::Function) | |||
1604 | return emitError("result must point to the Workgroup, CrossWorkgroup, " | |||
1605 | "or Function Storage Class"); | |||
1606 | ||||
1607 | Type operandPointeeType = operandType.getPointeeType(); | |||
1608 | Type resultPointeeType = resultType.getPointeeType(); | |||
1609 | if (operandPointeeType != resultPointeeType) | |||
1610 | return emitOpError("pointer operand's pointee type must have the same " | |||
1611 | "as the op result type, but found ") | |||
1612 | << operandPointeeType << " vs " << resultPointeeType; | |||
1613 | return success(); | |||
1614 | } | |||
1615 | ||||
1616 | //===----------------------------------------------------------------------===// | |||
1617 | // spirv.GenericCastToPtrExplicitOp | |||
1618 | //===----------------------------------------------------------------------===// | |||
1619 | ||||
1620 | LogicalResult spirv::GenericCastToPtrExplicitOp::verify() { | |||
1621 | auto operandType = getPointer().getType().cast<spirv::PointerType>(); | |||
1622 | auto resultType = getResult().getType().cast<spirv::PointerType>(); | |||
1623 | ||||
1624 | spirv::StorageClass operandStorage = operandType.getStorageClass(); | |||
1625 | if (operandStorage != spirv::StorageClass::Generic) | |||
1626 | return emitError("pointer type must be of storage class Generic"); | |||
1627 | ||||
1628 | spirv::StorageClass resultStorage = resultType.getStorageClass(); | |||
1629 | if (resultStorage != spirv::StorageClass::Workgroup && | |||
1630 | resultStorage != spirv::StorageClass::CrossWorkgroup && | |||
1631 | resultStorage != spirv::StorageClass::Function) | |||
1632 | return emitError("result must point to the Workgroup, CrossWorkgroup, " | |||
1633 | "or Function Storage Class"); | |||
1634 | ||||
1635 | Type operandPointeeType = operandType.getPointeeType(); | |||
1636 | Type resultPointeeType = resultType.getPointeeType(); | |||
1637 | if (operandPointeeType != resultPointeeType) | |||
1638 | return emitOpError("pointer operand's pointee type must have the same " | |||
1639 | "as the op result type, but found ") | |||
1640 | << operandPointeeType << " vs " << resultPointeeType; | |||
1641 | return success(); | |||
1642 | } | |||
1643 | ||||
1644 | //===----------------------------------------------------------------------===// | |||
1645 | // spirv.BranchOp | |||
1646 | //===----------------------------------------------------------------------===// | |||
1647 | ||||
1648 | SuccessorOperands spirv::BranchOp::getSuccessorOperands(unsigned index) { | |||
1649 | assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index" ) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1649, __extension__ __PRETTY_FUNCTION__)); | |||
1650 | return SuccessorOperands(0, getTargetOperandsMutable()); | |||
1651 | } | |||
1652 | ||||
1653 | //===----------------------------------------------------------------------===// | |||
1654 | // spirv.BranchConditionalOp | |||
1655 | //===----------------------------------------------------------------------===// | |||
1656 | ||||
1657 | SuccessorOperands | |||
1658 | spirv::BranchConditionalOp::getSuccessorOperands(unsigned index) { | |||
1659 | assert(index < 2 && "invalid successor index")(static_cast <bool> (index < 2 && "invalid successor index" ) ? void (0) : __assert_fail ("index < 2 && \"invalid successor index\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 1659, __extension__ __PRETTY_FUNCTION__)); | |||
1660 | return SuccessorOperands(index == kTrueIndex | |||
1661 | ? getTrueTargetOperandsMutable() | |||
1662 | : getFalseTargetOperandsMutable()); | |||
1663 | } | |||
1664 | ||||
1665 | ParseResult spirv::BranchConditionalOp::parse(OpAsmParser &parser, | |||
1666 | OperationState &result) { | |||
1667 | auto &builder = parser.getBuilder(); | |||
1668 | OpAsmParser::UnresolvedOperand condInfo; | |||
1669 | Block *dest; | |||
1670 | ||||
1671 | // Parse the condition. | |||
1672 | Type boolTy = builder.getI1Type(); | |||
1673 | if (parser.parseOperand(condInfo) || | |||
1674 | parser.resolveOperand(condInfo, boolTy, result.operands)) | |||
1675 | return failure(); | |||
1676 | ||||
1677 | // Parse the optional branch weights. | |||
1678 | if (succeeded(parser.parseOptionalLSquare())) { | |||
1679 | IntegerAttr trueWeight, falseWeight; | |||
1680 | NamedAttrList weights; | |||
1681 | ||||
1682 | auto i32Type = builder.getIntegerType(32); | |||
1683 | if (parser.parseAttribute(trueWeight, i32Type, "weight", weights) || | |||
1684 | parser.parseComma() || | |||
1685 | parser.parseAttribute(falseWeight, i32Type, "weight", weights) || | |||
1686 | parser.parseRSquare()) | |||
1687 | return failure(); | |||
1688 | ||||
1689 | result.addAttribute(kBranchWeightAttrName, | |||
1690 | builder.getArrayAttr({trueWeight, falseWeight})); | |||
1691 | } | |||
1692 | ||||
1693 | // Parse the true branch. | |||
1694 | SmallVector<Value, 4> trueOperands; | |||
1695 | if (parser.parseComma() || | |||
1696 | parser.parseSuccessorAndUseList(dest, trueOperands)) | |||
1697 | return failure(); | |||
1698 | result.addSuccessors(dest); | |||
1699 | result.addOperands(trueOperands); | |||
1700 | ||||
1701 | // Parse the false branch. | |||
1702 | SmallVector<Value, 4> falseOperands; | |||
1703 | if (parser.parseComma() || | |||
1704 | parser.parseSuccessorAndUseList(dest, falseOperands)) | |||
1705 | return failure(); | |||
1706 | result.addSuccessors(dest); | |||
1707 | result.addOperands(falseOperands); | |||
1708 | result.addAttribute(spirv::BranchConditionalOp::getOperandSegmentSizeAttr(), | |||
1709 | builder.getDenseI32ArrayAttr( | |||
1710 | {1, static_cast<int32_t>(trueOperands.size()), | |||
1711 | static_cast<int32_t>(falseOperands.size())})); | |||
1712 | ||||
1713 | return success(); | |||
1714 | } | |||
1715 | ||||
1716 | void spirv::BranchConditionalOp::print(OpAsmPrinter &printer) { | |||
1717 | printer << ' ' << getCondition(); | |||
1718 | ||||
1719 | if (auto weights = getBranchWeights()) { | |||
1720 | printer << " ["; | |||
1721 | llvm::interleaveComma(weights->getValue(), printer, [&](Attribute a) { | |||
1722 | printer << a.cast<IntegerAttr>().getInt(); | |||
1723 | }); | |||
1724 | printer << "]"; | |||
1725 | } | |||
1726 | ||||
1727 | printer << ", "; | |||
1728 | printer.printSuccessorAndUseList(getTrueBlock(), getTrueBlockArguments()); | |||
1729 | printer << ", "; | |||
1730 | printer.printSuccessorAndUseList(getFalseBlock(), getFalseBlockArguments()); | |||
1731 | } | |||
1732 | ||||
1733 | LogicalResult spirv::BranchConditionalOp::verify() { | |||
1734 | if (auto weights = getBranchWeights()) { | |||
1735 | if (weights->getValue().size() != 2) { | |||
1736 | return emitOpError("must have exactly two branch weights"); | |||
1737 | } | |||
1738 | if (llvm::all_of(*weights, [](Attribute attr) { | |||
1739 | return attr.cast<IntegerAttr>().getValue().isZero(); | |||
1740 | })) | |||
1741 | return emitOpError("branch weights cannot both be zero"); | |||
1742 | } | |||
1743 | ||||
1744 | return success(); | |||
1745 | } | |||
1746 | ||||
1747 | //===----------------------------------------------------------------------===// | |||
1748 | // spirv.CompositeConstruct | |||
1749 | //===----------------------------------------------------------------------===// | |||
1750 | ||||
1751 | LogicalResult spirv::CompositeConstructOp::verify() { | |||
1752 | auto cType = getType().cast<spirv::CompositeType>(); | |||
1753 | operand_range constituents = this->getConstituents(); | |||
1754 | ||||
1755 | if (auto coopType = cType.dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
1756 | if (constituents.size() != 1) | |||
1757 | return emitOpError("has incorrect number of operands: expected ") | |||
1758 | << "1, but provided " << constituents.size(); | |||
1759 | if (coopType.getElementType() != constituents.front().getType()) | |||
1760 | return emitOpError("operand type mismatch: expected operand type ") | |||
1761 | << coopType.getElementType() << ", but provided " | |||
1762 | << constituents.front().getType(); | |||
1763 | return success(); | |||
1764 | } | |||
1765 | ||||
1766 | if (auto jointType = cType.dyn_cast<spirv::JointMatrixINTELType>()) { | |||
1767 | if (constituents.size() != 1) | |||
1768 | return emitOpError("has incorrect number of operands: expected ") | |||
1769 | << "1, but provided " << constituents.size(); | |||
1770 | if (jointType.getElementType() != constituents.front().getType()) | |||
1771 | return emitOpError("operand type mismatch: expected operand type ") | |||
1772 | << jointType.getElementType() << ", but provided " | |||
1773 | << constituents.front().getType(); | |||
1774 | return success(); | |||
1775 | } | |||
1776 | ||||
1777 | if (constituents.size() == cType.getNumElements()) { | |||
1778 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { | |||
1779 | if (constituents[index].getType() != cType.getElementType(index)) { | |||
1780 | return emitOpError("operand type mismatch: expected operand type ") | |||
1781 | << cType.getElementType(index) << ", but provided " | |||
1782 | << constituents[index].getType(); | |||
1783 | } | |||
1784 | } | |||
1785 | return success(); | |||
1786 | } | |||
1787 | ||||
1788 | // If not constructing a cooperative matrix type, then we must be constructing | |||
1789 | // a vector type. | |||
1790 | auto resultType = cType.dyn_cast<VectorType>(); | |||
1791 | if (!resultType) | |||
1792 | return emitOpError( | |||
1793 | "expected to return a vector or cooperative matrix when the number of " | |||
1794 | "constituents is less than what the result needs"); | |||
1795 | ||||
1796 | SmallVector<unsigned> sizes; | |||
1797 | for (Value component : constituents) { | |||
1798 | if (!component.getType().isa<VectorType>() && | |||
1799 | !component.getType().isIntOrFloat()) | |||
1800 | return emitOpError("operand type mismatch: expected operand to have " | |||
1801 | "a scalar or vector type, but provided ") | |||
1802 | << component.getType(); | |||
1803 | ||||
1804 | Type elementType = component.getType(); | |||
1805 | if (auto vectorType = component.getType().dyn_cast<VectorType>()) { | |||
1806 | sizes.push_back(vectorType.getNumElements()); | |||
1807 | elementType = vectorType.getElementType(); | |||
1808 | } else { | |||
1809 | sizes.push_back(1); | |||
1810 | } | |||
1811 | ||||
1812 | if (elementType != resultType.getElementType()) | |||
1813 | return emitOpError("operand element type mismatch: expected to be ") | |||
1814 | << resultType.getElementType() << ", but provided " << elementType; | |||
1815 | } | |||
1816 | unsigned totalCount = std::accumulate(sizes.begin(), sizes.end(), 0); | |||
1817 | if (totalCount != cType.getNumElements()) | |||
1818 | return emitOpError("has incorrect number of operands: expected ") | |||
1819 | << cType.getNumElements() << ", but provided " << totalCount; | |||
1820 | return success(); | |||
1821 | } | |||
1822 | ||||
1823 | //===----------------------------------------------------------------------===// | |||
1824 | // spirv.CompositeExtractOp | |||
1825 | //===----------------------------------------------------------------------===// | |||
1826 | ||||
1827 | void spirv::CompositeExtractOp::build(OpBuilder &builder, OperationState &state, | |||
1828 | Value composite, | |||
1829 | ArrayRef<int32_t> indices) { | |||
1830 | auto indexAttr = builder.getI32ArrayAttr(indices); | |||
1831 | auto elementType = | |||
1832 | getElementType(composite.getType(), indexAttr, state.location); | |||
1833 | if (!elementType) { | |||
1834 | return; | |||
1835 | } | |||
1836 | build(builder, state, elementType, composite, indexAttr); | |||
1837 | } | |||
1838 | ||||
1839 | ParseResult spirv::CompositeExtractOp::parse(OpAsmParser &parser, | |||
1840 | OperationState &result) { | |||
1841 | OpAsmParser::UnresolvedOperand compositeInfo; | |||
1842 | Attribute indicesAttr; | |||
1843 | Type compositeType; | |||
1844 | SMLoc attrLocation; | |||
1845 | ||||
1846 | if (parser.parseOperand(compositeInfo) || | |||
1847 | parser.getCurrentLocation(&attrLocation) || | |||
1848 | parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || | |||
1849 | parser.parseColonType(compositeType) || | |||
1850 | parser.resolveOperand(compositeInfo, compositeType, result.operands)) { | |||
1851 | return failure(); | |||
1852 | } | |||
1853 | ||||
1854 | Type resultType = | |||
1855 | getElementType(compositeType, indicesAttr, parser, attrLocation); | |||
1856 | if (!resultType) { | |||
1857 | return failure(); | |||
1858 | } | |||
1859 | result.addTypes(resultType); | |||
1860 | return success(); | |||
1861 | } | |||
1862 | ||||
1863 | void spirv::CompositeExtractOp::print(OpAsmPrinter &printer) { | |||
1864 | printer << ' ' << getComposite() << getIndices() << " : " | |||
1865 | << getComposite().getType(); | |||
1866 | } | |||
1867 | ||||
1868 | LogicalResult spirv::CompositeExtractOp::verify() { | |||
1869 | auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>(); | |||
1870 | auto resultType = | |||
1871 | getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); | |||
1872 | if (!resultType) | |||
1873 | return failure(); | |||
1874 | ||||
1875 | if (resultType != getType()) { | |||
1876 | return emitOpError("invalid result type: expected ") | |||
1877 | << resultType << " but provided " << getType(); | |||
1878 | } | |||
1879 | ||||
1880 | return success(); | |||
1881 | } | |||
1882 | ||||
1883 | //===----------------------------------------------------------------------===// | |||
1884 | // spirv.CompositeInsert | |||
1885 | //===----------------------------------------------------------------------===// | |||
1886 | ||||
1887 | void spirv::CompositeInsertOp::build(OpBuilder &builder, OperationState &state, | |||
1888 | Value object, Value composite, | |||
1889 | ArrayRef<int32_t> indices) { | |||
1890 | auto indexAttr = builder.getI32ArrayAttr(indices); | |||
1891 | build(builder, state, composite.getType(), object, composite, indexAttr); | |||
1892 | } | |||
1893 | ||||
1894 | ParseResult spirv::CompositeInsertOp::parse(OpAsmParser &parser, | |||
1895 | OperationState &result) { | |||
1896 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operands; | |||
1897 | Type objectType, compositeType; | |||
1898 | Attribute indicesAttr; | |||
1899 | auto loc = parser.getCurrentLocation(); | |||
1900 | ||||
1901 | return failure( | |||
1902 | parser.parseOperandList(operands, 2) || | |||
1903 | parser.parseAttribute(indicesAttr, kIndicesAttrName, result.attributes) || | |||
1904 | parser.parseColonType(objectType) || | |||
1905 | parser.parseKeywordType("into", compositeType) || | |||
1906 | parser.resolveOperands(operands, {objectType, compositeType}, loc, | |||
1907 | result.operands) || | |||
1908 | parser.addTypesToList(compositeType, result.types)); | |||
1909 | } | |||
1910 | ||||
1911 | LogicalResult spirv::CompositeInsertOp::verify() { | |||
1912 | auto indicesArrayAttr = getIndices().dyn_cast<ArrayAttr>(); | |||
1913 | auto objectType = | |||
1914 | getElementType(getComposite().getType(), indicesArrayAttr, getLoc()); | |||
1915 | if (!objectType) | |||
1916 | return failure(); | |||
1917 | ||||
1918 | if (objectType != getObject().getType()) { | |||
1919 | return emitOpError("object operand type should be ") | |||
1920 | << objectType << ", but found " << getObject().getType(); | |||
1921 | } | |||
1922 | ||||
1923 | if (getComposite().getType() != getType()) { | |||
1924 | return emitOpError("result type should be the same as " | |||
1925 | "the composite type, but found ") | |||
1926 | << getComposite().getType() << " vs " << getType(); | |||
1927 | } | |||
1928 | ||||
1929 | return success(); | |||
1930 | } | |||
1931 | ||||
1932 | void spirv::CompositeInsertOp::print(OpAsmPrinter &printer) { | |||
1933 | printer << " " << getObject() << ", " << getComposite() << getIndices() | |||
1934 | << " : " << getObject().getType() << " into " | |||
1935 | << getComposite().getType(); | |||
1936 | } | |||
1937 | ||||
1938 | //===----------------------------------------------------------------------===// | |||
1939 | // spirv.Constant | |||
1940 | //===----------------------------------------------------------------------===// | |||
1941 | ||||
1942 | ParseResult spirv::ConstantOp::parse(OpAsmParser &parser, | |||
1943 | OperationState &result) { | |||
1944 | Attribute value; | |||
1945 | if (parser.parseAttribute(value, kValueAttrName, result.attributes)) | |||
1946 | return failure(); | |||
1947 | ||||
1948 | Type type = NoneType::get(parser.getContext()); | |||
1949 | if (auto typedAttr = value.dyn_cast<TypedAttr>()) | |||
1950 | type = typedAttr.getType(); | |||
1951 | if (type.isa<NoneType, TensorType>()) { | |||
1952 | if (parser.parseColonType(type)) | |||
1953 | return failure(); | |||
1954 | } | |||
1955 | ||||
1956 | return parser.addTypeToList(type, result.types); | |||
1957 | } | |||
1958 | ||||
1959 | void spirv::ConstantOp::print(OpAsmPrinter &printer) { | |||
1960 | printer << ' ' << getValue(); | |||
1961 | if (getType().isa<spirv::ArrayType>()) | |||
1962 | printer << " : " << getType(); | |||
1963 | } | |||
1964 | ||||
1965 | static LogicalResult verifyConstantType(spirv::ConstantOp op, Attribute value, | |||
1966 | Type opType) { | |||
1967 | if (value.isa<IntegerAttr, FloatAttr>()) { | |||
1968 | auto valueType = value.cast<TypedAttr>().getType(); | |||
1969 | if (valueType != opType) | |||
1970 | return op.emitOpError("result type (") | |||
1971 | << opType << ") does not match value type (" << valueType << ")"; | |||
1972 | return success(); | |||
1973 | } | |||
1974 | if (value.isa<DenseIntOrFPElementsAttr, SparseElementsAttr>()) { | |||
1975 | auto valueType = value.cast<TypedAttr>().getType(); | |||
1976 | if (valueType == opType) | |||
1977 | return success(); | |||
1978 | auto arrayType = opType.dyn_cast<spirv::ArrayType>(); | |||
1979 | auto shapedType = valueType.dyn_cast<ShapedType>(); | |||
1980 | if (!arrayType) | |||
1981 | return op.emitOpError("result or element type (") | |||
1982 | << opType << ") does not match value type (" << valueType | |||
1983 | << "), must be the same or spirv.array"; | |||
1984 | ||||
1985 | int numElements = arrayType.getNumElements(); | |||
1986 | auto opElemType = arrayType.getElementType(); | |||
1987 | while (auto t = opElemType.dyn_cast<spirv::ArrayType>()) { | |||
1988 | numElements *= t.getNumElements(); | |||
1989 | opElemType = t.getElementType(); | |||
1990 | } | |||
1991 | if (!opElemType.isIntOrFloat()) | |||
1992 | return op.emitOpError("only support nested array result type"); | |||
1993 | ||||
1994 | auto valueElemType = shapedType.getElementType(); | |||
1995 | if (valueElemType != opElemType) { | |||
1996 | return op.emitOpError("result element type (") | |||
1997 | << opElemType << ") does not match value element type (" | |||
1998 | << valueElemType << ")"; | |||
1999 | } | |||
2000 | ||||
2001 | if (numElements != shapedType.getNumElements()) { | |||
2002 | return op.emitOpError("result number of elements (") | |||
2003 | << numElements << ") does not match value number of elements (" | |||
2004 | << shapedType.getNumElements() << ")"; | |||
2005 | } | |||
2006 | return success(); | |||
2007 | } | |||
2008 | if (auto arrayAttr = value.dyn_cast<ArrayAttr>()) { | |||
2009 | auto arrayType = opType.dyn_cast<spirv::ArrayType>(); | |||
2010 | if (!arrayType) | |||
2011 | return op.emitOpError( | |||
2012 | "must have spirv.array result type for array value"); | |||
2013 | Type elemType = arrayType.getElementType(); | |||
2014 | for (Attribute element : arrayAttr.getValue()) { | |||
2015 | // Verify array elements recursively. | |||
2016 | if (failed(verifyConstantType(op, element, elemType))) | |||
2017 | return failure(); | |||
2018 | } | |||
2019 | return success(); | |||
2020 | } | |||
2021 | return op.emitOpError("cannot have attribute: ") << value; | |||
2022 | } | |||
2023 | ||||
2024 | LogicalResult spirv::ConstantOp::verify() { | |||
2025 | // ODS already generates checks to make sure the result type is valid. We just | |||
2026 | // need to additionally check that the value's attribute type is consistent | |||
2027 | // with the result type. | |||
2028 | return verifyConstantType(*this, getValueAttr(), getType()); | |||
2029 | } | |||
2030 | ||||
2031 | bool spirv::ConstantOp::isBuildableWith(Type type) { | |||
2032 | // Must be valid SPIR-V type first. | |||
2033 | if (!type.isa<spirv::SPIRVType>()) | |||
2034 | return false; | |||
2035 | ||||
2036 | if (isa<SPIRVDialect>(type.getDialect())) { | |||
2037 | // TODO: support constant struct | |||
2038 | return type.isa<spirv::ArrayType>(); | |||
2039 | } | |||
2040 | ||||
2041 | return true; | |||
2042 | } | |||
2043 | ||||
2044 | spirv::ConstantOp spirv::ConstantOp::getZero(Type type, Location loc, | |||
2045 | OpBuilder &builder) { | |||
2046 | if (auto intType = type.dyn_cast<IntegerType>()) { | |||
2047 | unsigned width = intType.getWidth(); | |||
2048 | if (width == 1) | |||
2049 | return builder.create<spirv::ConstantOp>(loc, type, | |||
2050 | builder.getBoolAttr(false)); | |||
2051 | return builder.create<spirv::ConstantOp>( | |||
2052 | loc, type, builder.getIntegerAttr(type, APInt(width, 0))); | |||
2053 | } | |||
2054 | if (auto floatType = type.dyn_cast<FloatType>()) { | |||
2055 | return builder.create<spirv::ConstantOp>( | |||
2056 | loc, type, builder.getFloatAttr(floatType, 0.0)); | |||
2057 | } | |||
2058 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
2059 | Type elemType = vectorType.getElementType(); | |||
2060 | if (elemType.isa<IntegerType>()) { | |||
2061 | return builder.create<spirv::ConstantOp>( | |||
2062 | loc, type, | |||
2063 | DenseElementsAttr::get(vectorType, | |||
2064 | IntegerAttr::get(elemType, 0).getValue())); | |||
2065 | } | |||
2066 | if (elemType.isa<FloatType>()) { | |||
2067 | return builder.create<spirv::ConstantOp>( | |||
2068 | loc, type, | |||
2069 | DenseFPElementsAttr::get(vectorType, | |||
2070 | FloatAttr::get(elemType, 0.0).getValue())); | |||
2071 | } | |||
2072 | } | |||
2073 | ||||
2074 | llvm_unreachable("unimplemented types for ConstantOp::getZero()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getZero()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2074); | |||
2075 | } | |||
2076 | ||||
2077 | spirv::ConstantOp spirv::ConstantOp::getOne(Type type, Location loc, | |||
2078 | OpBuilder &builder) { | |||
2079 | if (auto intType = type.dyn_cast<IntegerType>()) { | |||
2080 | unsigned width = intType.getWidth(); | |||
2081 | if (width == 1) | |||
2082 | return builder.create<spirv::ConstantOp>(loc, type, | |||
2083 | builder.getBoolAttr(true)); | |||
2084 | return builder.create<spirv::ConstantOp>( | |||
2085 | loc, type, builder.getIntegerAttr(type, APInt(width, 1))); | |||
2086 | } | |||
2087 | if (auto floatType = type.dyn_cast<FloatType>()) { | |||
2088 | return builder.create<spirv::ConstantOp>( | |||
2089 | loc, type, builder.getFloatAttr(floatType, 1.0)); | |||
2090 | } | |||
2091 | if (auto vectorType = type.dyn_cast<VectorType>()) { | |||
2092 | Type elemType = vectorType.getElementType(); | |||
2093 | if (elemType.isa<IntegerType>()) { | |||
2094 | return builder.create<spirv::ConstantOp>( | |||
2095 | loc, type, | |||
2096 | DenseElementsAttr::get(vectorType, | |||
2097 | IntegerAttr::get(elemType, 1).getValue())); | |||
2098 | } | |||
2099 | if (elemType.isa<FloatType>()) { | |||
2100 | return builder.create<spirv::ConstantOp>( | |||
2101 | loc, type, | |||
2102 | DenseFPElementsAttr::get(vectorType, | |||
2103 | FloatAttr::get(elemType, 1.0).getValue())); | |||
2104 | } | |||
2105 | } | |||
2106 | ||||
2107 | llvm_unreachable("unimplemented types for ConstantOp::getOne()")::llvm::llvm_unreachable_internal("unimplemented types for ConstantOp::getOne()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2107); | |||
2108 | } | |||
2109 | ||||
2110 | void mlir::spirv::ConstantOp::getAsmResultNames( | |||
2111 | llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { | |||
2112 | Type type = getType(); | |||
2113 | ||||
2114 | SmallString<32> specialNameBuffer; | |||
2115 | llvm::raw_svector_ostream specialName(specialNameBuffer); | |||
2116 | specialName << "cst"; | |||
2117 | ||||
2118 | IntegerType intTy = type.dyn_cast<IntegerType>(); | |||
2119 | ||||
2120 | if (IntegerAttr intCst = getValue().dyn_cast<IntegerAttr>()) { | |||
2121 | if (intTy && intTy.getWidth() == 1) { | |||
2122 | return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); | |||
2123 | } | |||
2124 | ||||
2125 | if (intTy.isSignless()) { | |||
2126 | specialName << intCst.getInt(); | |||
2127 | } else if (intTy.isUnsigned()) { | |||
2128 | specialName << intCst.getUInt(); | |||
2129 | } else { | |||
2130 | specialName << intCst.getSInt(); | |||
2131 | } | |||
2132 | } | |||
2133 | ||||
2134 | if (intTy || type.isa<FloatType>()) { | |||
2135 | specialName << '_' << type; | |||
2136 | } | |||
2137 | ||||
2138 | if (auto vecType = type.dyn_cast<VectorType>()) { | |||
2139 | specialName << "_vec_"; | |||
2140 | specialName << vecType.getDimSize(0); | |||
2141 | ||||
2142 | Type elementType = vecType.getElementType(); | |||
2143 | ||||
2144 | if (elementType.isa<IntegerType>() || elementType.isa<FloatType>()) { | |||
2145 | specialName << "x" << elementType; | |||
2146 | } | |||
2147 | } | |||
2148 | ||||
2149 | setNameFn(getResult(), specialName.str()); | |||
2150 | } | |||
2151 | ||||
2152 | void mlir::spirv::AddressOfOp::getAsmResultNames( | |||
2153 | llvm::function_ref<void(mlir::Value, llvm::StringRef)> setNameFn) { | |||
2154 | SmallString<32> specialNameBuffer; | |||
2155 | llvm::raw_svector_ostream specialName(specialNameBuffer); | |||
2156 | specialName << getVariable() << "_addr"; | |||
2157 | setNameFn(getResult(), specialName.str()); | |||
2158 | } | |||
2159 | ||||
2160 | //===----------------------------------------------------------------------===// | |||
2161 | // spirv.ControlBarrierOp | |||
2162 | //===----------------------------------------------------------------------===// | |||
2163 | ||||
2164 | LogicalResult spirv::ControlBarrierOp::verify() { | |||
2165 | return verifyMemorySemantics(getOperation(), getMemorySemantics()); | |||
2166 | } | |||
2167 | ||||
2168 | //===----------------------------------------------------------------------===// | |||
2169 | // spirv.ConvertFToSOp | |||
2170 | //===----------------------------------------------------------------------===// | |||
2171 | ||||
2172 | LogicalResult spirv::ConvertFToSOp::verify() { | |||
2173 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2174 | /*skipBitWidthCheck=*/true); | |||
2175 | } | |||
2176 | ||||
2177 | //===----------------------------------------------------------------------===// | |||
2178 | // spirv.ConvertFToUOp | |||
2179 | //===----------------------------------------------------------------------===// | |||
2180 | ||||
2181 | LogicalResult spirv::ConvertFToUOp::verify() { | |||
2182 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2183 | /*skipBitWidthCheck=*/true); | |||
2184 | } | |||
2185 | ||||
2186 | //===----------------------------------------------------------------------===// | |||
2187 | // spirv.ConvertSToFOp | |||
2188 | //===----------------------------------------------------------------------===// | |||
2189 | ||||
2190 | LogicalResult spirv::ConvertSToFOp::verify() { | |||
2191 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2192 | /*skipBitWidthCheck=*/true); | |||
2193 | } | |||
2194 | ||||
2195 | //===----------------------------------------------------------------------===// | |||
2196 | // spirv.ConvertUToFOp | |||
2197 | //===----------------------------------------------------------------------===// | |||
2198 | ||||
2199 | LogicalResult spirv::ConvertUToFOp::verify() { | |||
2200 | return verifyCastOp(*this, /*requireSameBitWidth=*/false, | |||
2201 | /*skipBitWidthCheck=*/true); | |||
2202 | } | |||
2203 | ||||
2204 | //===----------------------------------------------------------------------===// | |||
2205 | // spirv.INTELConvertBF16ToFOp | |||
2206 | //===----------------------------------------------------------------------===// | |||
2207 | ||||
2208 | LogicalResult spirv::INTELConvertBF16ToFOp::verify() { | |||
2209 | auto operandType = getOperand().getType(); | |||
2210 | auto resultType = getResult().getType(); | |||
2211 | // ODS checks that vector result type and vector operand type have the same | |||
2212 | // shape. | |||
2213 | if (auto vectorType = operandType.dyn_cast<VectorType>()) { | |||
2214 | unsigned operandNumElements = vectorType.getNumElements(); | |||
2215 | unsigned resultNumElements = resultType.cast<VectorType>().getNumElements(); | |||
2216 | if (operandNumElements != resultNumElements) { | |||
2217 | return emitOpError( | |||
2218 | "operand and result must have same number of elements"); | |||
2219 | } | |||
2220 | } | |||
2221 | return success(); | |||
2222 | } | |||
2223 | ||||
2224 | //===----------------------------------------------------------------------===// | |||
2225 | // spirv.INTELConvertFToBF16Op | |||
2226 | //===----------------------------------------------------------------------===// | |||
2227 | ||||
2228 | LogicalResult spirv::INTELConvertFToBF16Op::verify() { | |||
2229 | auto operandType = getOperand().getType(); | |||
2230 | auto resultType = getResult().getType(); | |||
2231 | // ODS checks that vector result type and vector operand type have the same | |||
2232 | // shape. | |||
2233 | if (auto vectorType = operandType.dyn_cast<VectorType>()) { | |||
2234 | unsigned operandNumElements = vectorType.getNumElements(); | |||
2235 | unsigned resultNumElements = resultType.cast<VectorType>().getNumElements(); | |||
2236 | if (operandNumElements != resultNumElements) { | |||
2237 | return emitOpError( | |||
2238 | "operand and result must have same number of elements"); | |||
2239 | } | |||
2240 | } | |||
2241 | return success(); | |||
2242 | } | |||
2243 | ||||
2244 | //===----------------------------------------------------------------------===// | |||
2245 | // spirv.EntryPoint | |||
2246 | //===----------------------------------------------------------------------===// | |||
2247 | ||||
2248 | void spirv::EntryPointOp::build(OpBuilder &builder, OperationState &state, | |||
2249 | spirv::ExecutionModel executionModel, | |||
2250 | spirv::FuncOp function, | |||
2251 | ArrayRef<Attribute> interfaceVars) { | |||
2252 | build(builder, state, | |||
2253 | spirv::ExecutionModelAttr::get(builder.getContext(), executionModel), | |||
2254 | SymbolRefAttr::get(function), builder.getArrayAttr(interfaceVars)); | |||
2255 | } | |||
2256 | ||||
2257 | ParseResult spirv::EntryPointOp::parse(OpAsmParser &parser, | |||
2258 | OperationState &result) { | |||
2259 | spirv::ExecutionModel execModel; | |||
| ||||
2260 | SmallVector<OpAsmParser::UnresolvedOperand, 0> identifiers; | |||
2261 | SmallVector<Type, 0> idTypes; | |||
2262 | SmallVector<Attribute, 4> interfaceVars; | |||
2263 | ||||
2264 | FlatSymbolRefAttr fn; | |||
2265 | if (parseEnumStrAttr<spirv::ExecutionModelAttr>(execModel, parser, result) || | |||
2266 | parser.parseAttribute(fn, Type(), kFnNameAttrName, result.attributes)) { | |||
2267 | return failure(); | |||
2268 | } | |||
2269 | ||||
2270 | if (!parser.parseOptionalComma()) { | |||
2271 | // Parse the interface variables | |||
2272 | if (parser.parseCommaSeparatedList([&]() -> ParseResult { | |||
2273 | // The name of the interface variable attribute isnt important | |||
2274 | FlatSymbolRefAttr var; | |||
2275 | NamedAttrList attrs; | |||
2276 | if (parser.parseAttribute(var, Type(), "var_symbol", attrs)) | |||
2277 | return failure(); | |||
2278 | interfaceVars.push_back(var); | |||
2279 | return success(); | |||
2280 | })) | |||
2281 | return failure(); | |||
2282 | } | |||
2283 | result.addAttribute(kInterfaceAttrName, | |||
2284 | parser.getBuilder().getArrayAttr(interfaceVars)); | |||
2285 | return success(); | |||
2286 | } | |||
2287 | ||||
2288 | void spirv::EntryPointOp::print(OpAsmPrinter &printer) { | |||
2289 | printer << " \"" << stringifyExecutionModel(getExecutionModel()) << "\" "; | |||
2290 | printer.printSymbolName(getFn()); | |||
2291 | auto interfaceVars = getInterface().getValue(); | |||
2292 | if (!interfaceVars.empty()) { | |||
2293 | printer << ", "; | |||
2294 | llvm::interleaveComma(interfaceVars, printer); | |||
2295 | } | |||
2296 | } | |||
2297 | ||||
2298 | LogicalResult spirv::EntryPointOp::verify() { | |||
2299 | // Checks for fn and interface symbol reference are done in spirv::ModuleOp | |||
2300 | // verification. | |||
2301 | return success(); | |||
2302 | } | |||
2303 | ||||
2304 | //===----------------------------------------------------------------------===// | |||
2305 | // spirv.ExecutionMode | |||
2306 | //===----------------------------------------------------------------------===// | |||
2307 | ||||
2308 | void spirv::ExecutionModeOp::build(OpBuilder &builder, OperationState &state, | |||
2309 | spirv::FuncOp function, | |||
2310 | spirv::ExecutionMode executionMode, | |||
2311 | ArrayRef<int32_t> params) { | |||
2312 | build(builder, state, SymbolRefAttr::get(function), | |||
2313 | spirv::ExecutionModeAttr::get(builder.getContext(), executionMode), | |||
2314 | builder.getI32ArrayAttr(params)); | |||
2315 | } | |||
2316 | ||||
2317 | ParseResult spirv::ExecutionModeOp::parse(OpAsmParser &parser, | |||
2318 | OperationState &result) { | |||
2319 | spirv::ExecutionMode execMode; | |||
2320 | Attribute fn; | |||
2321 | if (parser.parseAttribute(fn, kFnNameAttrName, result.attributes) || | |||
2322 | parseEnumStrAttr<spirv::ExecutionModeAttr>(execMode, parser, result)) { | |||
2323 | return failure(); | |||
2324 | } | |||
2325 | ||||
2326 | SmallVector<int32_t, 4> values; | |||
2327 | Type i32Type = parser.getBuilder().getIntegerType(32); | |||
2328 | while (!parser.parseOptionalComma()) { | |||
2329 | NamedAttrList attr; | |||
2330 | Attribute value; | |||
2331 | if (parser.parseAttribute(value, i32Type, "value", attr)) { | |||
2332 | return failure(); | |||
2333 | } | |||
2334 | values.push_back(value.cast<IntegerAttr>().getInt()); | |||
2335 | } | |||
2336 | result.addAttribute(kValuesAttrName, | |||
2337 | parser.getBuilder().getI32ArrayAttr(values)); | |||
2338 | return success(); | |||
2339 | } | |||
2340 | ||||
2341 | void spirv::ExecutionModeOp::print(OpAsmPrinter &printer) { | |||
2342 | printer << " "; | |||
2343 | printer.printSymbolName(getFn()); | |||
2344 | printer << " \"" << stringifyExecutionMode(getExecutionMode()) << "\""; | |||
2345 | auto values = this->getValues(); | |||
2346 | if (values.empty()) | |||
2347 | return; | |||
2348 | printer << ", "; | |||
2349 | llvm::interleaveComma(values, printer, [&](Attribute a) { | |||
2350 | printer << a.cast<IntegerAttr>().getInt(); | |||
2351 | }); | |||
2352 | } | |||
2353 | ||||
2354 | //===----------------------------------------------------------------------===// | |||
2355 | // spirv.FConvertOp | |||
2356 | //===----------------------------------------------------------------------===// | |||
2357 | ||||
2358 | LogicalResult spirv::FConvertOp::verify() { | |||
2359 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2360 | } | |||
2361 | ||||
2362 | //===----------------------------------------------------------------------===// | |||
2363 | // spirv.SConvertOp | |||
2364 | //===----------------------------------------------------------------------===// | |||
2365 | ||||
2366 | LogicalResult spirv::SConvertOp::verify() { | |||
2367 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2368 | } | |||
2369 | ||||
2370 | //===----------------------------------------------------------------------===// | |||
2371 | // spirv.UConvertOp | |||
2372 | //===----------------------------------------------------------------------===// | |||
2373 | ||||
2374 | LogicalResult spirv::UConvertOp::verify() { | |||
2375 | return verifyCastOp(*this, /*requireSameBitWidth=*/false); | |||
2376 | } | |||
2377 | ||||
2378 | //===----------------------------------------------------------------------===// | |||
2379 | // spirv.func | |||
2380 | //===----------------------------------------------------------------------===// | |||
2381 | ||||
2382 | ParseResult spirv::FuncOp::parse(OpAsmParser &parser, OperationState &result) { | |||
2383 | SmallVector<OpAsmParser::Argument> entryArgs; | |||
2384 | SmallVector<DictionaryAttr> resultAttrs; | |||
2385 | SmallVector<Type> resultTypes; | |||
2386 | auto &builder = parser.getBuilder(); | |||
2387 | ||||
2388 | // Parse the name as a symbol. | |||
2389 | StringAttr nameAttr; | |||
2390 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
2391 | result.attributes)) | |||
2392 | return failure(); | |||
2393 | ||||
2394 | // Parse the function signature. | |||
2395 | bool isVariadic = false; | |||
2396 | if (function_interface_impl::parseFunctionSignature( | |||
2397 | parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes, | |||
2398 | resultAttrs)) | |||
2399 | return failure(); | |||
2400 | ||||
2401 | SmallVector<Type> argTypes; | |||
2402 | for (auto &arg : entryArgs) | |||
2403 | argTypes.push_back(arg.type); | |||
2404 | auto fnType = builder.getFunctionType(argTypes, resultTypes); | |||
2405 | result.addAttribute(getFunctionTypeAttrName(result.name), | |||
2406 | TypeAttr::get(fnType)); | |||
2407 | ||||
2408 | // Parse the optional function control keyword. | |||
2409 | spirv::FunctionControl fnControl; | |||
2410 | if (parseEnumStrAttr<spirv::FunctionControlAttr>(fnControl, parser, result)) | |||
2411 | return failure(); | |||
2412 | ||||
2413 | // If additional attributes are present, parse them. | |||
2414 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes)) | |||
2415 | return failure(); | |||
2416 | ||||
2417 | // Add the attributes to the function arguments. | |||
2418 | assert(resultAttrs.size() == resultTypes.size())(static_cast <bool> (resultAttrs.size() == resultTypes. size()) ? void (0) : __assert_fail ("resultAttrs.size() == resultTypes.size()" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 2418, __extension__ __PRETTY_FUNCTION__)); | |||
2419 | function_interface_impl::addArgAndResultAttrs( | |||
2420 | builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name), | |||
2421 | getResAttrsAttrName(result.name)); | |||
2422 | ||||
2423 | // Parse the optional function body. | |||
2424 | auto *body = result.addRegion(); | |||
2425 | OptionalParseResult parseResult = | |||
2426 | parser.parseOptionalRegion(*body, entryArgs); | |||
2427 | return failure(parseResult.has_value() && failed(*parseResult)); | |||
2428 | } | |||
2429 | ||||
2430 | void spirv::FuncOp::print(OpAsmPrinter &printer) { | |||
2431 | // Print function name, signature, and control. | |||
2432 | printer << " "; | |||
2433 | printer.printSymbolName(getSymName()); | |||
2434 | auto fnType = getFunctionType(); | |||
2435 | function_interface_impl::printFunctionSignature( | |||
2436 | printer, *this, fnType.getInputs(), | |||
2437 | /*isVariadic=*/false, fnType.getResults()); | |||
2438 | printer << " \"" << spirv::stringifyFunctionControl(getFunctionControl()) | |||
2439 | << "\""; | |||
2440 | function_interface_impl::printFunctionAttributes( | |||
2441 | printer, *this, | |||
2442 | {spirv::attributeName<spirv::FunctionControl>(), | |||
2443 | getFunctionTypeAttrName(), getArgAttrsAttrName(), getResAttrsAttrName(), | |||
2444 | getFunctionControlAttrName()}); | |||
2445 | ||||
2446 | // Print the body if this is not an external function. | |||
2447 | Region &body = this->getBody(); | |||
2448 | if (!body.empty()) { | |||
2449 | printer << ' '; | |||
2450 | printer.printRegion(body, /*printEntryBlockArgs=*/false, | |||
2451 | /*printBlockTerminators=*/true); | |||
2452 | } | |||
2453 | } | |||
2454 | ||||
2455 | LogicalResult spirv::FuncOp::verifyType() { | |||
2456 | if (getFunctionType().getNumResults() > 1) | |||
2457 | return emitOpError("cannot have more than one result"); | |||
2458 | return success(); | |||
2459 | } | |||
2460 | ||||
2461 | LogicalResult spirv::FuncOp::verifyBody() { | |||
2462 | FunctionType fnType = getFunctionType(); | |||
2463 | ||||
2464 | auto walkResult = walk([fnType](Operation *op) -> WalkResult { | |||
2465 | if (auto retOp = dyn_cast<spirv::ReturnOp>(op)) { | |||
2466 | if (fnType.getNumResults() != 0) | |||
2467 | return retOp.emitOpError("cannot be used in functions returning value"); | |||
2468 | } else if (auto retOp = dyn_cast<spirv::ReturnValueOp>(op)) { | |||
2469 | if (fnType.getNumResults() != 1) | |||
2470 | return retOp.emitOpError( | |||
2471 | "returns 1 value but enclosing function requires ") | |||
2472 | << fnType.getNumResults() << " results"; | |||
2473 | ||||
2474 | auto retOperandType = retOp.getValue().getType(); | |||
2475 | auto fnResultType = fnType.getResult(0); | |||
2476 | if (retOperandType != fnResultType) | |||
2477 | return retOp.emitOpError(" return value's type (") | |||
2478 | << retOperandType << ") mismatch with function's result type (" | |||
2479 | << fnResultType << ")"; | |||
2480 | } | |||
2481 | return WalkResult::advance(); | |||
2482 | }); | |||
2483 | ||||
2484 | // TODO: verify other bits like linkage type. | |||
2485 | ||||
2486 | return failure(walkResult.wasInterrupted()); | |||
2487 | } | |||
2488 | ||||
2489 | void spirv::FuncOp::build(OpBuilder &builder, OperationState &state, | |||
2490 | StringRef name, FunctionType type, | |||
2491 | spirv::FunctionControl control, | |||
2492 | ArrayRef<NamedAttribute> attrs) { | |||
2493 | state.addAttribute(SymbolTable::getSymbolAttrName(), | |||
2494 | builder.getStringAttr(name)); | |||
2495 | state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type)); | |||
2496 | state.addAttribute(spirv::attributeName<spirv::FunctionControl>(), | |||
2497 | builder.getAttr<spirv::FunctionControlAttr>(control)); | |||
2498 | state.attributes.append(attrs.begin(), attrs.end()); | |||
2499 | state.addRegion(); | |||
2500 | } | |||
2501 | ||||
2502 | // CallableOpInterface | |||
2503 | Region *spirv::FuncOp::getCallableRegion() { | |||
2504 | return isExternal() ? nullptr : &getBody(); | |||
2505 | } | |||
2506 | ||||
2507 | // CallableOpInterface | |||
2508 | ArrayRef<Type> spirv::FuncOp::getCallableResults() { | |||
2509 | return getFunctionType().getResults(); | |||
2510 | } | |||
2511 | ||||
2512 | // CallableOpInterface | |||
2513 | ::mlir::ArrayAttr spirv::FuncOp::getCallableArgAttrs() { | |||
2514 | return getArgAttrs().value_or(nullptr); | |||
2515 | } | |||
2516 | ||||
2517 | // CallableOpInterface | |||
2518 | ::mlir::ArrayAttr spirv::FuncOp::getCallableResAttrs() { | |||
2519 | return getResAttrs().value_or(nullptr); | |||
2520 | } | |||
2521 | ||||
2522 | //===----------------------------------------------------------------------===// | |||
2523 | // spirv.FunctionCall | |||
2524 | //===----------------------------------------------------------------------===// | |||
2525 | ||||
2526 | LogicalResult spirv::FunctionCallOp::verify() { | |||
2527 | auto fnName = getCalleeAttr(); | |||
2528 | ||||
2529 | auto funcOp = dyn_cast_or_null<spirv::FuncOp>( | |||
2530 | SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); | |||
2531 | if (!funcOp) { | |||
2532 | return emitOpError("callee function '") | |||
2533 | << fnName.getValue() << "' not found in nearest symbol table"; | |||
2534 | } | |||
2535 | ||||
2536 | auto functionType = funcOp.getFunctionType(); | |||
2537 | ||||
2538 | if (getNumResults() > 1) { | |||
2539 | return emitOpError( | |||
2540 | "expected callee function to have 0 or 1 result, but provided ") | |||
2541 | << getNumResults(); | |||
2542 | } | |||
2543 | ||||
2544 | if (functionType.getNumInputs() != getNumOperands()) { | |||
2545 | return emitOpError("has incorrect number of operands for callee: expected ") | |||
2546 | << functionType.getNumInputs() << ", but provided " | |||
2547 | << getNumOperands(); | |||
2548 | } | |||
2549 | ||||
2550 | for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { | |||
2551 | if (getOperand(i).getType() != functionType.getInput(i)) { | |||
2552 | return emitOpError("operand type mismatch: expected operand type ") | |||
2553 | << functionType.getInput(i) << ", but provided " | |||
2554 | << getOperand(i).getType() << " for operand number " << i; | |||
2555 | } | |||
2556 | } | |||
2557 | ||||
2558 | if (functionType.getNumResults() != getNumResults()) { | |||
2559 | return emitOpError( | |||
2560 | "has incorrect number of results has for callee: expected ") | |||
2561 | << functionType.getNumResults() << ", but provided " | |||
2562 | << getNumResults(); | |||
2563 | } | |||
2564 | ||||
2565 | if (getNumResults() && | |||
2566 | (getResult(0).getType() != functionType.getResult(0))) { | |||
2567 | return emitOpError("result type mismatch: expected ") | |||
2568 | << functionType.getResult(0) << ", but provided " | |||
2569 | << getResult(0).getType(); | |||
2570 | } | |||
2571 | ||||
2572 | return success(); | |||
2573 | } | |||
2574 | ||||
2575 | CallInterfaceCallable spirv::FunctionCallOp::getCallableForCallee() { | |||
2576 | return (*this)->getAttrOfType<SymbolRefAttr>(kCallee); | |||
2577 | } | |||
2578 | ||||
2579 | void spirv::FunctionCallOp::setCalleeFromCallable( | |||
2580 | CallInterfaceCallable callee) { | |||
2581 | (*this)->setAttr(kCallee, callee.get<SymbolRefAttr>()); | |||
2582 | } | |||
2583 | ||||
2584 | Operation::operand_range spirv::FunctionCallOp::getArgOperands() { | |||
2585 | return getArguments(); | |||
2586 | } | |||
2587 | ||||
2588 | //===----------------------------------------------------------------------===// | |||
2589 | // spirv.GLFClampOp | |||
2590 | //===----------------------------------------------------------------------===// | |||
2591 | ||||
2592 | ParseResult spirv::GLFClampOp::parse(OpAsmParser &parser, | |||
2593 | OperationState &result) { | |||
2594 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2595 | } | |||
2596 | void spirv::GLFClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2597 | ||||
2598 | //===----------------------------------------------------------------------===// | |||
2599 | // spirv.GLUClampOp | |||
2600 | //===----------------------------------------------------------------------===// | |||
2601 | ||||
2602 | ParseResult spirv::GLUClampOp::parse(OpAsmParser &parser, | |||
2603 | OperationState &result) { | |||
2604 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2605 | } | |||
2606 | void spirv::GLUClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2607 | ||||
2608 | //===----------------------------------------------------------------------===// | |||
2609 | // spirv.GLSClampOp | |||
2610 | //===----------------------------------------------------------------------===// | |||
2611 | ||||
2612 | ParseResult spirv::GLSClampOp::parse(OpAsmParser &parser, | |||
2613 | OperationState &result) { | |||
2614 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2615 | } | |||
2616 | void spirv::GLSClampOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2617 | ||||
2618 | //===----------------------------------------------------------------------===// | |||
2619 | // spirv.GLFmaOp | |||
2620 | //===----------------------------------------------------------------------===// | |||
2621 | ||||
2622 | ParseResult spirv::GLFmaOp::parse(OpAsmParser &parser, OperationState &result) { | |||
2623 | return parseOneResultSameOperandTypeOp(parser, result); | |||
2624 | } | |||
2625 | void spirv::GLFmaOp::print(OpAsmPrinter &p) { printOneResultOp(*this, p); } | |||
2626 | ||||
2627 | //===----------------------------------------------------------------------===// | |||
2628 | // spirv.GlobalVariable | |||
2629 | //===----------------------------------------------------------------------===// | |||
2630 | ||||
2631 | void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, | |||
2632 | Type type, StringRef name, | |||
2633 | unsigned descriptorSet, unsigned binding) { | |||
2634 | build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); | |||
2635 | state.addAttribute( | |||
2636 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::DescriptorSet), | |||
2637 | builder.getI32IntegerAttr(descriptorSet)); | |||
2638 | state.addAttribute( | |||
2639 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::Binding), | |||
2640 | builder.getI32IntegerAttr(binding)); | |||
2641 | } | |||
2642 | ||||
2643 | void spirv::GlobalVariableOp::build(OpBuilder &builder, OperationState &state, | |||
2644 | Type type, StringRef name, | |||
2645 | spirv::BuiltIn builtin) { | |||
2646 | build(builder, state, TypeAttr::get(type), builder.getStringAttr(name)); | |||
2647 | state.addAttribute( | |||
2648 | spirv::SPIRVDialect::getAttributeName(spirv::Decoration::BuiltIn), | |||
2649 | builder.getStringAttr(spirv::stringifyBuiltIn(builtin))); | |||
2650 | } | |||
2651 | ||||
2652 | ParseResult spirv::GlobalVariableOp::parse(OpAsmParser &parser, | |||
2653 | OperationState &result) { | |||
2654 | // Parse variable name. | |||
2655 | StringAttr nameAttr; | |||
2656 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
2657 | result.attributes)) { | |||
2658 | return failure(); | |||
2659 | } | |||
2660 | ||||
2661 | // Parse optional initializer | |||
2662 | if (succeeded(parser.parseOptionalKeyword(kInitializerAttrName))) { | |||
2663 | FlatSymbolRefAttr initSymbol; | |||
2664 | if (parser.parseLParen() || | |||
2665 | parser.parseAttribute(initSymbol, Type(), kInitializerAttrName, | |||
2666 | result.attributes) || | |||
2667 | parser.parseRParen()) | |||
2668 | return failure(); | |||
2669 | } | |||
2670 | ||||
2671 | if (parseVariableDecorations(parser, result)) { | |||
2672 | return failure(); | |||
2673 | } | |||
2674 | ||||
2675 | Type type; | |||
2676 | auto loc = parser.getCurrentLocation(); | |||
2677 | if (parser.parseColonType(type)) { | |||
2678 | return failure(); | |||
2679 | } | |||
2680 | if (!type.isa<spirv::PointerType>()) { | |||
2681 | return parser.emitError(loc, "expected spirv.ptr type"); | |||
2682 | } | |||
2683 | result.addAttribute(kTypeAttrName, TypeAttr::get(type)); | |||
2684 | ||||
2685 | return success(); | |||
2686 | } | |||
2687 | ||||
2688 | void spirv::GlobalVariableOp::print(OpAsmPrinter &printer) { | |||
2689 | SmallVector<StringRef, 4> elidedAttrs{ | |||
2690 | spirv::attributeName<spirv::StorageClass>()}; | |||
2691 | ||||
2692 | // Print variable name. | |||
2693 | printer << ' '; | |||
2694 | printer.printSymbolName(getSymName()); | |||
2695 | elidedAttrs.push_back(SymbolTable::getSymbolAttrName()); | |||
2696 | ||||
2697 | // Print optional initializer | |||
2698 | if (auto initializer = this->getInitializer()) { | |||
2699 | printer << " " << kInitializerAttrName << '('; | |||
2700 | printer.printSymbolName(*initializer); | |||
2701 | printer << ')'; | |||
2702 | elidedAttrs.push_back(kInitializerAttrName); | |||
2703 | } | |||
2704 | ||||
2705 | elidedAttrs.push_back(kTypeAttrName); | |||
2706 | printVariableDecorations(*this, printer, elidedAttrs); | |||
2707 | printer << " : " << getType(); | |||
2708 | } | |||
2709 | ||||
2710 | LogicalResult spirv::GlobalVariableOp::verify() { | |||
2711 | if (!getType().isa<spirv::PointerType>()) | |||
2712 | return emitOpError("result must be of a !spv.ptr type"); | |||
2713 | ||||
2714 | // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the | |||
2715 | // object. It cannot be Generic. It must be the same as the Storage Class | |||
2716 | // operand of the Result Type." | |||
2717 | // Also, Function storage class is reserved by spirv.Variable. | |||
2718 | auto storageClass = this->storageClass(); | |||
2719 | if (storageClass == spirv::StorageClass::Generic || | |||
2720 | storageClass == spirv::StorageClass::Function) { | |||
2721 | return emitOpError("storage class cannot be '") | |||
2722 | << stringifyStorageClass(storageClass) << "'"; | |||
2723 | } | |||
2724 | ||||
2725 | if (auto init = | |||
2726 | (*this)->getAttrOfType<FlatSymbolRefAttr>(kInitializerAttrName)) { | |||
2727 | Operation *initOp = SymbolTable::lookupNearestSymbolFrom( | |||
2728 | (*this)->getParentOp(), init.getAttr()); | |||
2729 | // TODO: Currently only variable initialization with specialization | |||
2730 | // constants and other variables is supported. They could be normal | |||
2731 | // constants in the module scope as well. | |||
2732 | if (!initOp || | |||
2733 | !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) { | |||
2734 | return emitOpError("initializer must be result of a " | |||
2735 | "spirv.SpecConstant or spirv.GlobalVariable op"); | |||
2736 | } | |||
2737 | } | |||
2738 | ||||
2739 | return success(); | |||
2740 | } | |||
2741 | ||||
2742 | //===----------------------------------------------------------------------===// | |||
2743 | // spirv.GroupBroadcast | |||
2744 | //===----------------------------------------------------------------------===// | |||
2745 | ||||
2746 | LogicalResult spirv::GroupBroadcastOp::verify() { | |||
2747 | spirv::Scope scope = getExecutionScope(); | |||
2748 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2749 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2750 | ||||
2751 | if (auto localIdTy = getLocalid().getType().dyn_cast<VectorType>()) | |||
2752 | if (localIdTy.getNumElements() != 2 && localIdTy.getNumElements() != 3) | |||
2753 | return emitOpError("localid is a vector and can be with only " | |||
2754 | " 2 or 3 components, actual number is ") | |||
2755 | << localIdTy.getNumElements(); | |||
2756 | ||||
2757 | return success(); | |||
2758 | } | |||
2759 | ||||
2760 | //===----------------------------------------------------------------------===// | |||
2761 | // spirv.GroupNonUniformBallotOp | |||
2762 | //===----------------------------------------------------------------------===// | |||
2763 | ||||
2764 | LogicalResult spirv::GroupNonUniformBallotOp::verify() { | |||
2765 | spirv::Scope scope = getExecutionScope(); | |||
2766 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2767 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2768 | ||||
2769 | return success(); | |||
2770 | } | |||
2771 | ||||
2772 | //===----------------------------------------------------------------------===// | |||
2773 | // spirv.GroupNonUniformBroadcast | |||
2774 | //===----------------------------------------------------------------------===// | |||
2775 | ||||
2776 | LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { | |||
2777 | spirv::Scope scope = getExecutionScope(); | |||
2778 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2779 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2780 | ||||
2781 | // SPIR-V spec: "Before version 1.5, Id must come from a | |||
2782 | // constant instruction. | |||
2783 | auto targetEnv = spirv::getDefaultTargetEnv(getContext()); | |||
2784 | if (auto spirvModule = (*this)->getParentOfType<spirv::ModuleOp>()) | |||
2785 | targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); | |||
2786 | ||||
2787 | if (targetEnv.getVersion() < spirv::Version::V_1_5) { | |||
2788 | auto *idOp = getId().getDefiningOp(); | |||
2789 | if (!idOp || !isa<spirv::ConstantOp, // for normal constant | |||
2790 | spirv::ReferenceOfOp>(idOp)) // for spec constant | |||
2791 | return emitOpError("id must be the result of a constant op"); | |||
2792 | } | |||
2793 | ||||
2794 | return success(); | |||
2795 | } | |||
2796 | ||||
2797 | //===----------------------------------------------------------------------===// | |||
2798 | // spirv.GroupNonUniformShuffle* | |||
2799 | //===----------------------------------------------------------------------===// | |||
2800 | ||||
2801 | template <typename OpTy> | |||
2802 | static LogicalResult verifyGroupNonUniformShuffleOp(OpTy op) { | |||
2803 | spirv::Scope scope = op.getExecutionScope(); | |||
2804 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2805 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2806 | ||||
2807 | if (op.getOperands().back().getType().isSignedInteger()) | |||
2808 | return op.emitOpError("second operand must be a singless/unsigned integer"); | |||
2809 | ||||
2810 | return success(); | |||
2811 | } | |||
2812 | ||||
2813 | LogicalResult spirv::GroupNonUniformShuffleOp::verify() { | |||
2814 | return verifyGroupNonUniformShuffleOp(*this); | |||
2815 | } | |||
2816 | LogicalResult spirv::GroupNonUniformShuffleDownOp::verify() { | |||
2817 | return verifyGroupNonUniformShuffleOp(*this); | |||
2818 | } | |||
2819 | LogicalResult spirv::GroupNonUniformShuffleUpOp::verify() { | |||
2820 | return verifyGroupNonUniformShuffleOp(*this); | |||
2821 | } | |||
2822 | LogicalResult spirv::GroupNonUniformShuffleXorOp::verify() { | |||
2823 | return verifyGroupNonUniformShuffleOp(*this); | |||
2824 | } | |||
2825 | ||||
2826 | //===----------------------------------------------------------------------===// | |||
2827 | // spirv.INTEL.SubgroupBlockRead | |||
2828 | //===----------------------------------------------------------------------===// | |||
2829 | ||||
2830 | ParseResult spirv::INTELSubgroupBlockReadOp::parse(OpAsmParser &parser, | |||
2831 | OperationState &result) { | |||
2832 | // Parse the storage class specification | |||
2833 | spirv::StorageClass storageClass; | |||
2834 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
2835 | Type elementType; | |||
2836 | if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || | |||
2837 | parser.parseColon() || parser.parseType(elementType)) { | |||
2838 | return failure(); | |||
2839 | } | |||
2840 | ||||
2841 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
2842 | if (auto valVecTy = elementType.dyn_cast<VectorType>()) | |||
2843 | ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); | |||
2844 | ||||
2845 | if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { | |||
2846 | return failure(); | |||
2847 | } | |||
2848 | ||||
2849 | result.addTypes(elementType); | |||
2850 | return success(); | |||
2851 | } | |||
2852 | ||||
2853 | void spirv::INTELSubgroupBlockReadOp::print(OpAsmPrinter &printer) { | |||
2854 | printer << " " << getPtr() << " : " << getType(); | |||
2855 | } | |||
2856 | ||||
2857 | LogicalResult spirv::INTELSubgroupBlockReadOp::verify() { | |||
2858 | if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) | |||
2859 | return failure(); | |||
2860 | ||||
2861 | return success(); | |||
2862 | } | |||
2863 | ||||
2864 | //===----------------------------------------------------------------------===// | |||
2865 | // spirv.INTEL.SubgroupBlockWrite | |||
2866 | //===----------------------------------------------------------------------===// | |||
2867 | ||||
2868 | ParseResult spirv::INTELSubgroupBlockWriteOp::parse(OpAsmParser &parser, | |||
2869 | OperationState &result) { | |||
2870 | // Parse the storage class specification | |||
2871 | spirv::StorageClass storageClass; | |||
2872 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
2873 | auto loc = parser.getCurrentLocation(); | |||
2874 | Type elementType; | |||
2875 | if (parseEnumStrAttr(storageClass, parser) || | |||
2876 | parser.parseOperandList(operandInfo, 2) || parser.parseColon() || | |||
2877 | parser.parseType(elementType)) { | |||
2878 | return failure(); | |||
2879 | } | |||
2880 | ||||
2881 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
2882 | if (auto valVecTy = elementType.dyn_cast<VectorType>()) | |||
2883 | ptrType = spirv::PointerType::get(valVecTy.getElementType(), storageClass); | |||
2884 | ||||
2885 | if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, | |||
2886 | result.operands)) { | |||
2887 | return failure(); | |||
2888 | } | |||
2889 | return success(); | |||
2890 | } | |||
2891 | ||||
2892 | void spirv::INTELSubgroupBlockWriteOp::print(OpAsmPrinter &printer) { | |||
2893 | printer << " " << getPtr() << ", " << getValue() << " : " | |||
2894 | << getValue().getType(); | |||
2895 | } | |||
2896 | ||||
2897 | LogicalResult spirv::INTELSubgroupBlockWriteOp::verify() { | |||
2898 | if (failed(verifyBlockReadWritePtrAndValTypes(*this, getPtr(), getValue()))) | |||
2899 | return failure(); | |||
2900 | ||||
2901 | return success(); | |||
2902 | } | |||
2903 | ||||
2904 | //===----------------------------------------------------------------------===// | |||
2905 | // spirv.GroupNonUniformElectOp | |||
2906 | //===----------------------------------------------------------------------===// | |||
2907 | ||||
2908 | LogicalResult spirv::GroupNonUniformElectOp::verify() { | |||
2909 | spirv::Scope scope = getExecutionScope(); | |||
2910 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
2911 | return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
2912 | ||||
2913 | return success(); | |||
2914 | } | |||
2915 | ||||
2916 | //===----------------------------------------------------------------------===// | |||
2917 | // spirv.GroupNonUniformFAddOp | |||
2918 | //===----------------------------------------------------------------------===// | |||
2919 | ||||
2920 | LogicalResult spirv::GroupNonUniformFAddOp::verify() { | |||
2921 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2922 | } | |||
2923 | ||||
2924 | ParseResult spirv::GroupNonUniformFAddOp::parse(OpAsmParser &parser, | |||
2925 | OperationState &result) { | |||
2926 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2927 | } | |||
2928 | void spirv::GroupNonUniformFAddOp::print(OpAsmPrinter &p) { | |||
2929 | printGroupNonUniformArithmeticOp(*this, p); | |||
2930 | } | |||
2931 | ||||
2932 | //===----------------------------------------------------------------------===// | |||
2933 | // spirv.GroupNonUniformFMaxOp | |||
2934 | //===----------------------------------------------------------------------===// | |||
2935 | ||||
2936 | LogicalResult spirv::GroupNonUniformFMaxOp::verify() { | |||
2937 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2938 | } | |||
2939 | ||||
2940 | ParseResult spirv::GroupNonUniformFMaxOp::parse(OpAsmParser &parser, | |||
2941 | OperationState &result) { | |||
2942 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2943 | } | |||
2944 | void spirv::GroupNonUniformFMaxOp::print(OpAsmPrinter &p) { | |||
2945 | printGroupNonUniformArithmeticOp(*this, p); | |||
2946 | } | |||
2947 | ||||
2948 | //===----------------------------------------------------------------------===// | |||
2949 | // spirv.GroupNonUniformFMinOp | |||
2950 | //===----------------------------------------------------------------------===// | |||
2951 | ||||
2952 | LogicalResult spirv::GroupNonUniformFMinOp::verify() { | |||
2953 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2954 | } | |||
2955 | ||||
2956 | ParseResult spirv::GroupNonUniformFMinOp::parse(OpAsmParser &parser, | |||
2957 | OperationState &result) { | |||
2958 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2959 | } | |||
2960 | void spirv::GroupNonUniformFMinOp::print(OpAsmPrinter &p) { | |||
2961 | printGroupNonUniformArithmeticOp(*this, p); | |||
2962 | } | |||
2963 | ||||
2964 | //===----------------------------------------------------------------------===// | |||
2965 | // spirv.GroupNonUniformFMulOp | |||
2966 | //===----------------------------------------------------------------------===// | |||
2967 | ||||
2968 | LogicalResult spirv::GroupNonUniformFMulOp::verify() { | |||
2969 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2970 | } | |||
2971 | ||||
2972 | ParseResult spirv::GroupNonUniformFMulOp::parse(OpAsmParser &parser, | |||
2973 | OperationState &result) { | |||
2974 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2975 | } | |||
2976 | void spirv::GroupNonUniformFMulOp::print(OpAsmPrinter &p) { | |||
2977 | printGroupNonUniformArithmeticOp(*this, p); | |||
2978 | } | |||
2979 | ||||
2980 | //===----------------------------------------------------------------------===// | |||
2981 | // spirv.GroupNonUniformIAddOp | |||
2982 | //===----------------------------------------------------------------------===// | |||
2983 | ||||
2984 | LogicalResult spirv::GroupNonUniformIAddOp::verify() { | |||
2985 | return verifyGroupNonUniformArithmeticOp(*this); | |||
2986 | } | |||
2987 | ||||
2988 | ParseResult spirv::GroupNonUniformIAddOp::parse(OpAsmParser &parser, | |||
2989 | OperationState &result) { | |||
2990 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
2991 | } | |||
2992 | void spirv::GroupNonUniformIAddOp::print(OpAsmPrinter &p) { | |||
2993 | printGroupNonUniformArithmeticOp(*this, p); | |||
2994 | } | |||
2995 | ||||
2996 | //===----------------------------------------------------------------------===// | |||
2997 | // spirv.GroupNonUniformIMulOp | |||
2998 | //===----------------------------------------------------------------------===// | |||
2999 | ||||
3000 | LogicalResult spirv::GroupNonUniformIMulOp::verify() { | |||
3001 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3002 | } | |||
3003 | ||||
3004 | ParseResult spirv::GroupNonUniformIMulOp::parse(OpAsmParser &parser, | |||
3005 | OperationState &result) { | |||
3006 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3007 | } | |||
3008 | void spirv::GroupNonUniformIMulOp::print(OpAsmPrinter &p) { | |||
3009 | printGroupNonUniformArithmeticOp(*this, p); | |||
3010 | } | |||
3011 | ||||
3012 | //===----------------------------------------------------------------------===// | |||
3013 | // spirv.GroupNonUniformSMaxOp | |||
3014 | //===----------------------------------------------------------------------===// | |||
3015 | ||||
3016 | LogicalResult spirv::GroupNonUniformSMaxOp::verify() { | |||
3017 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3018 | } | |||
3019 | ||||
3020 | ParseResult spirv::GroupNonUniformSMaxOp::parse(OpAsmParser &parser, | |||
3021 | OperationState &result) { | |||
3022 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3023 | } | |||
3024 | void spirv::GroupNonUniformSMaxOp::print(OpAsmPrinter &p) { | |||
3025 | printGroupNonUniformArithmeticOp(*this, p); | |||
3026 | } | |||
3027 | ||||
3028 | //===----------------------------------------------------------------------===// | |||
3029 | // spirv.GroupNonUniformSMinOp | |||
3030 | //===----------------------------------------------------------------------===// | |||
3031 | ||||
3032 | LogicalResult spirv::GroupNonUniformSMinOp::verify() { | |||
3033 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3034 | } | |||
3035 | ||||
3036 | ParseResult spirv::GroupNonUniformSMinOp::parse(OpAsmParser &parser, | |||
3037 | OperationState &result) { | |||
3038 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3039 | } | |||
3040 | void spirv::GroupNonUniformSMinOp::print(OpAsmPrinter &p) { | |||
3041 | printGroupNonUniformArithmeticOp(*this, p); | |||
3042 | } | |||
3043 | ||||
3044 | //===----------------------------------------------------------------------===// | |||
3045 | // spirv.GroupNonUniformUMaxOp | |||
3046 | //===----------------------------------------------------------------------===// | |||
3047 | ||||
3048 | LogicalResult spirv::GroupNonUniformUMaxOp::verify() { | |||
3049 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3050 | } | |||
3051 | ||||
3052 | ParseResult spirv::GroupNonUniformUMaxOp::parse(OpAsmParser &parser, | |||
3053 | OperationState &result) { | |||
3054 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3055 | } | |||
3056 | void spirv::GroupNonUniformUMaxOp::print(OpAsmPrinter &p) { | |||
3057 | printGroupNonUniformArithmeticOp(*this, p); | |||
3058 | } | |||
3059 | ||||
3060 | //===----------------------------------------------------------------------===// | |||
3061 | // spirv.GroupNonUniformUMinOp | |||
3062 | //===----------------------------------------------------------------------===// | |||
3063 | ||||
3064 | LogicalResult spirv::GroupNonUniformUMinOp::verify() { | |||
3065 | return verifyGroupNonUniformArithmeticOp(*this); | |||
3066 | } | |||
3067 | ||||
3068 | ParseResult spirv::GroupNonUniformUMinOp::parse(OpAsmParser &parser, | |||
3069 | OperationState &result) { | |||
3070 | return parseGroupNonUniformArithmeticOp(parser, result); | |||
3071 | } | |||
3072 | void spirv::GroupNonUniformUMinOp::print(OpAsmPrinter &p) { | |||
3073 | printGroupNonUniformArithmeticOp(*this, p); | |||
3074 | } | |||
3075 | ||||
3076 | //===----------------------------------------------------------------------===// | |||
3077 | // spirv.IAddCarryOp | |||
3078 | //===----------------------------------------------------------------------===// | |||
3079 | ||||
3080 | LogicalResult spirv::IAddCarryOp::verify() { | |||
3081 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3082 | } | |||
3083 | ||||
3084 | ParseResult spirv::IAddCarryOp::parse(OpAsmParser &parser, | |||
3085 | OperationState &result) { | |||
3086 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3087 | } | |||
3088 | ||||
3089 | void spirv::IAddCarryOp::print(OpAsmPrinter &printer) { | |||
3090 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3091 | } | |||
3092 | ||||
3093 | //===----------------------------------------------------------------------===// | |||
3094 | // spirv.ISubBorrowOp | |||
3095 | //===----------------------------------------------------------------------===// | |||
3096 | ||||
3097 | LogicalResult spirv::ISubBorrowOp::verify() { | |||
3098 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3099 | } | |||
3100 | ||||
3101 | ParseResult spirv::ISubBorrowOp::parse(OpAsmParser &parser, | |||
3102 | OperationState &result) { | |||
3103 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3104 | } | |||
3105 | ||||
3106 | void spirv::ISubBorrowOp::print(OpAsmPrinter &printer) { | |||
3107 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3108 | } | |||
3109 | ||||
3110 | //===----------------------------------------------------------------------===// | |||
3111 | // spirv.SMulExtended | |||
3112 | //===----------------------------------------------------------------------===// | |||
3113 | ||||
3114 | LogicalResult spirv::SMulExtendedOp::verify() { | |||
3115 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3116 | } | |||
3117 | ||||
3118 | ParseResult spirv::SMulExtendedOp::parse(OpAsmParser &parser, | |||
3119 | OperationState &result) { | |||
3120 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3121 | } | |||
3122 | ||||
3123 | void spirv::SMulExtendedOp::print(OpAsmPrinter &printer) { | |||
3124 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3125 | } | |||
3126 | ||||
3127 | //===----------------------------------------------------------------------===// | |||
3128 | // spirv.UMulExtended | |||
3129 | //===----------------------------------------------------------------------===// | |||
3130 | ||||
3131 | LogicalResult spirv::UMulExtendedOp::verify() { | |||
3132 | return ::verifyArithmeticExtendedBinaryOp(*this); | |||
3133 | } | |||
3134 | ||||
3135 | ParseResult spirv::UMulExtendedOp::parse(OpAsmParser &parser, | |||
3136 | OperationState &result) { | |||
3137 | return ::parseArithmeticExtendedBinaryOp(parser, result); | |||
3138 | } | |||
3139 | ||||
3140 | void spirv::UMulExtendedOp::print(OpAsmPrinter &printer) { | |||
3141 | ::printArithmeticExtendedBinaryOp(*this, printer); | |||
3142 | } | |||
3143 | ||||
3144 | //===----------------------------------------------------------------------===// | |||
3145 | // spirv.LoadOp | |||
3146 | //===----------------------------------------------------------------------===// | |||
3147 | ||||
3148 | void spirv::LoadOp::build(OpBuilder &builder, OperationState &state, | |||
3149 | Value basePtr, MemoryAccessAttr memoryAccess, | |||
3150 | IntegerAttr alignment) { | |||
3151 | auto ptrType = basePtr.getType().cast<spirv::PointerType>(); | |||
3152 | build(builder, state, ptrType.getPointeeType(), basePtr, memoryAccess, | |||
3153 | alignment); | |||
3154 | } | |||
3155 | ||||
3156 | ParseResult spirv::LoadOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3157 | // Parse the storage class specification | |||
3158 | spirv::StorageClass storageClass; | |||
3159 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
3160 | Type elementType; | |||
3161 | if (parseEnumStrAttr(storageClass, parser) || parser.parseOperand(ptrInfo) || | |||
3162 | parseMemoryAccessAttributes(parser, result) || | |||
3163 | parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() || | |||
3164 | parser.parseType(elementType)) { | |||
3165 | return failure(); | |||
3166 | } | |||
3167 | ||||
3168 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
3169 | if (parser.resolveOperand(ptrInfo, ptrType, result.operands)) { | |||
3170 | return failure(); | |||
3171 | } | |||
3172 | ||||
3173 | result.addTypes(elementType); | |||
3174 | return success(); | |||
3175 | } | |||
3176 | ||||
3177 | void spirv::LoadOp::print(OpAsmPrinter &printer) { | |||
3178 | SmallVector<StringRef, 4> elidedAttrs; | |||
3179 | StringRef sc = stringifyStorageClass( | |||
3180 | getPtr().getType().cast<spirv::PointerType>().getStorageClass()); | |||
3181 | printer << " \"" << sc << "\" " << getPtr(); | |||
3182 | ||||
3183 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
3184 | ||||
3185 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
3186 | printer << " : " << getType(); | |||
3187 | } | |||
3188 | ||||
3189 | LogicalResult spirv::LoadOp::verify() { | |||
3190 | // SPIR-V spec : "Result Type is the type of the loaded object. It must be a | |||
3191 | // type with fixed size; i.e., it cannot be, nor include, any | |||
3192 | // OpTypeRuntimeArray types." | |||
3193 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) { | |||
3194 | return failure(); | |||
3195 | } | |||
3196 | return verifyMemoryAccessAttribute(*this); | |||
3197 | } | |||
3198 | ||||
3199 | //===----------------------------------------------------------------------===// | |||
3200 | // spirv.mlir.loop | |||
3201 | //===----------------------------------------------------------------------===// | |||
3202 | ||||
3203 | void spirv::LoopOp::build(OpBuilder &builder, OperationState &state) { | |||
3204 | state.addAttribute("loop_control", builder.getAttr<spirv::LoopControlAttr>( | |||
3205 | spirv::LoopControl::None)); | |||
3206 | state.addRegion(); | |||
3207 | } | |||
3208 | ||||
3209 | ParseResult spirv::LoopOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3210 | if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser, | |||
3211 | result)) | |||
3212 | return failure(); | |||
3213 | return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); | |||
3214 | } | |||
3215 | ||||
3216 | void spirv::LoopOp::print(OpAsmPrinter &printer) { | |||
3217 | auto control = getLoopControl(); | |||
3218 | if (control != spirv::LoopControl::None) | |||
3219 | printer << " control(" << spirv::stringifyLoopControl(control) << ")"; | |||
3220 | printer << ' '; | |||
3221 | printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, | |||
3222 | /*printBlockTerminators=*/true); | |||
3223 | } | |||
3224 | ||||
3225 | /// Returns true if the given `srcBlock` contains only one `spirv.Branch` to the | |||
3226 | /// given `dstBlock`. | |||
3227 | static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { | |||
3228 | // Check that there is only one op in the `srcBlock`. | |||
3229 | if (!llvm::hasSingleElement(srcBlock)) | |||
3230 | return false; | |||
3231 | ||||
3232 | auto branchOp = dyn_cast<spirv::BranchOp>(srcBlock.back()); | |||
3233 | return branchOp && branchOp.getSuccessor() == &dstBlock; | |||
3234 | } | |||
3235 | ||||
3236 | LogicalResult spirv::LoopOp::verifyRegions() { | |||
3237 | auto *op = getOperation(); | |||
3238 | ||||
3239 | // We need to verify that the blocks follow the following layout: | |||
3240 | // | |||
3241 | // +-------------+ | |||
3242 | // | entry block | | |||
3243 | // +-------------+ | |||
3244 | // | | |||
3245 | // v | |||
3246 | // +-------------+ | |||
3247 | // | loop header | <-----+ | |||
3248 | // +-------------+ | | |||
3249 | // | | |||
3250 | // ... | | |||
3251 | // \ | / | | |||
3252 | // v | | |||
3253 | // +---------------+ | | |||
3254 | // | loop continue | -----+ | |||
3255 | // +---------------+ | |||
3256 | // | |||
3257 | // ... | |||
3258 | // \ | / | |||
3259 | // v | |||
3260 | // +-------------+ | |||
3261 | // | merge block | | |||
3262 | // +-------------+ | |||
3263 | ||||
3264 | auto ®ion = op->getRegion(0); | |||
3265 | // Allow empty region as a degenerated case, which can come from | |||
3266 | // optimizations. | |||
3267 | if (region.empty()) | |||
3268 | return success(); | |||
3269 | ||||
3270 | // The last block is the merge block. | |||
3271 | Block &merge = region.back(); | |||
3272 | if (!isMergeBlock(merge)) | |||
3273 | return emitOpError("last block must be the merge block with only one " | |||
3274 | "'spirv.mlir.merge' op"); | |||
3275 | ||||
3276 | if (std::next(region.begin()) == region.end()) | |||
3277 | return emitOpError( | |||
3278 | "must have an entry block branching to the loop header block"); | |||
3279 | // The first block is the entry block. | |||
3280 | Block &entry = region.front(); | |||
3281 | ||||
3282 | if (std::next(region.begin(), 2) == region.end()) | |||
3283 | return emitOpError( | |||
3284 | "must have a loop header block branched from the entry block"); | |||
3285 | // The second block is the loop header block. | |||
3286 | Block &header = *std::next(region.begin(), 1); | |||
3287 | ||||
3288 | if (!hasOneBranchOpTo(entry, header)) | |||
3289 | return emitOpError( | |||
3290 | "entry block must only have one 'spirv.Branch' op to the second block"); | |||
3291 | ||||
3292 | if (std::next(region.begin(), 3) == region.end()) | |||
3293 | return emitOpError( | |||
3294 | "requires a loop continue block branching to the loop header block"); | |||
3295 | // The second to last block is the loop continue block. | |||
3296 | Block &cont = *std::prev(region.end(), 2); | |||
3297 | ||||
3298 | // Make sure that we have a branch from the loop continue block to the loop | |||
3299 | // header block. | |||
3300 | if (llvm::none_of( | |||
3301 | llvm::seq<unsigned>(0, cont.getNumSuccessors()), | |||
3302 | [&](unsigned index) { return cont.getSuccessor(index) == &header; })) | |||
3303 | return emitOpError("second to last block must be the loop continue " | |||
3304 | "block that branches to the loop header block"); | |||
3305 | ||||
3306 | // Make sure that no other blocks (except the entry and loop continue block) | |||
3307 | // branches to the loop header block. | |||
3308 | for (auto &block : llvm::make_range(std::next(region.begin(), 2), | |||
3309 | std::prev(region.end(), 2))) { | |||
3310 | for (auto i : llvm::seq<unsigned>(0, block.getNumSuccessors())) { | |||
3311 | if (block.getSuccessor(i) == &header) { | |||
3312 | return emitOpError("can only have the entry and loop continue " | |||
3313 | "block branching to the loop header block"); | |||
3314 | } | |||
3315 | } | |||
3316 | } | |||
3317 | ||||
3318 | return success(); | |||
3319 | } | |||
3320 | ||||
3321 | Block *spirv::LoopOp::getEntryBlock() { | |||
3322 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3322, __extension__ __PRETTY_FUNCTION__)); | |||
3323 | return &getBody().front(); | |||
3324 | } | |||
3325 | ||||
3326 | Block *spirv::LoopOp::getHeaderBlock() { | |||
3327 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3327, __extension__ __PRETTY_FUNCTION__)); | |||
3328 | // The second block is the loop header block. | |||
3329 | return &*std::next(getBody().begin()); | |||
3330 | } | |||
3331 | ||||
3332 | Block *spirv::LoopOp::getContinueBlock() { | |||
3333 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3333, __extension__ __PRETTY_FUNCTION__)); | |||
3334 | // The second to last block is the loop continue block. | |||
3335 | return &*std::prev(getBody().end(), 2); | |||
3336 | } | |||
3337 | ||||
3338 | Block *spirv::LoopOp::getMergeBlock() { | |||
3339 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3339, __extension__ __PRETTY_FUNCTION__)); | |||
3340 | // The last block is the loop merge block. | |||
3341 | return &getBody().back(); | |||
3342 | } | |||
3343 | ||||
3344 | void spirv::LoopOp::addEntryAndMergeBlock() { | |||
3345 | assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist" ) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3345, __extension__ __PRETTY_FUNCTION__)); | |||
3346 | getBody().push_back(new Block()); | |||
3347 | auto *mergeBlock = new Block(); | |||
3348 | getBody().push_back(mergeBlock); | |||
3349 | OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); | |||
3350 | ||||
3351 | // Add a spirv.mlir.merge op into the merge block. | |||
3352 | builder.create<spirv::MergeOp>(getLoc()); | |||
3353 | } | |||
3354 | ||||
3355 | //===----------------------------------------------------------------------===// | |||
3356 | // spirv.MemoryBarrierOp | |||
3357 | //===----------------------------------------------------------------------===// | |||
3358 | ||||
3359 | LogicalResult spirv::MemoryBarrierOp::verify() { | |||
3360 | return verifyMemorySemantics(getOperation(), getMemorySemantics()); | |||
3361 | } | |||
3362 | ||||
3363 | //===----------------------------------------------------------------------===// | |||
3364 | // spirv.mlir.merge | |||
3365 | //===----------------------------------------------------------------------===// | |||
3366 | ||||
3367 | LogicalResult spirv::MergeOp::verify() { | |||
3368 | auto *parentOp = (*this)->getParentOp(); | |||
3369 | if (!parentOp || !isa<spirv::SelectionOp, spirv::LoopOp>(parentOp)) | |||
3370 | return emitOpError( | |||
3371 | "expected parent op to be 'spirv.mlir.selection' or 'spirv.mlir.loop'"); | |||
3372 | ||||
3373 | // TODO: This check should be done in `verifyRegions` of parent op. | |||
3374 | Block &parentLastBlock = (*this)->getParentRegion()->back(); | |||
3375 | if (getOperation() != parentLastBlock.getTerminator()) | |||
3376 | return emitOpError("can only be used in the last block of " | |||
3377 | "'spirv.mlir.selection' or 'spirv.mlir.loop'"); | |||
3378 | return success(); | |||
3379 | } | |||
3380 | ||||
3381 | //===----------------------------------------------------------------------===// | |||
3382 | // spirv.module | |||
3383 | //===----------------------------------------------------------------------===// | |||
3384 | ||||
3385 | void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, | |||
3386 | std::optional<StringRef> name) { | |||
3387 | OpBuilder::InsertionGuard guard(builder); | |||
3388 | builder.createBlock(state.addRegion()); | |||
3389 | if (name) { | |||
3390 | state.attributes.append(mlir::SymbolTable::getSymbolAttrName(), | |||
3391 | builder.getStringAttr(*name)); | |||
3392 | } | |||
3393 | } | |||
3394 | ||||
3395 | void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state, | |||
3396 | spirv::AddressingModel addressingModel, | |||
3397 | spirv::MemoryModel memoryModel, | |||
3398 | std::optional<VerCapExtAttr> vceTriple, | |||
3399 | std::optional<StringRef> name) { | |||
3400 | state.addAttribute( | |||
3401 | "addressing_model", | |||
3402 | builder.getAttr<spirv::AddressingModelAttr>(addressingModel)); | |||
3403 | state.addAttribute("memory_model", | |||
3404 | builder.getAttr<spirv::MemoryModelAttr>(memoryModel)); | |||
3405 | OpBuilder::InsertionGuard guard(builder); | |||
3406 | builder.createBlock(state.addRegion()); | |||
3407 | if (vceTriple) | |||
3408 | state.addAttribute(getVCETripleAttrName(), *vceTriple); | |||
3409 | if (name) | |||
3410 | state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), | |||
3411 | builder.getStringAttr(*name)); | |||
3412 | } | |||
3413 | ||||
3414 | ParseResult spirv::ModuleOp::parse(OpAsmParser &parser, | |||
3415 | OperationState &result) { | |||
3416 | Region *body = result.addRegion(); | |||
3417 | ||||
3418 | // If the name is present, parse it. | |||
3419 | StringAttr nameAttr; | |||
3420 | (void)parser.parseOptionalSymbolName( | |||
3421 | nameAttr, mlir::SymbolTable::getSymbolAttrName(), result.attributes); | |||
3422 | ||||
3423 | // Parse attributes | |||
3424 | spirv::AddressingModel addrModel; | |||
3425 | spirv::MemoryModel memoryModel; | |||
3426 | if (::parseEnumKeywordAttr<spirv::AddressingModelAttr>(addrModel, parser, | |||
3427 | result) || | |||
3428 | ::parseEnumKeywordAttr<spirv::MemoryModelAttr>(memoryModel, parser, | |||
3429 | result)) | |||
3430 | return failure(); | |||
3431 | ||||
3432 | if (succeeded(parser.parseOptionalKeyword("requires"))) { | |||
3433 | spirv::VerCapExtAttr vceTriple; | |||
3434 | if (parser.parseAttribute(vceTriple, | |||
3435 | spirv::ModuleOp::getVCETripleAttrName(), | |||
3436 | result.attributes)) | |||
3437 | return failure(); | |||
3438 | } | |||
3439 | ||||
3440 | if (parser.parseOptionalAttrDictWithKeyword(result.attributes) || | |||
3441 | parser.parseRegion(*body, /*arguments=*/{})) | |||
3442 | return failure(); | |||
3443 | ||||
3444 | // Make sure we have at least one block. | |||
3445 | if (body->empty()) | |||
3446 | body->push_back(new Block()); | |||
3447 | ||||
3448 | return success(); | |||
3449 | } | |||
3450 | ||||
3451 | void spirv::ModuleOp::print(OpAsmPrinter &printer) { | |||
3452 | if (std::optional<StringRef> name = getName()) { | |||
3453 | printer << ' '; | |||
3454 | printer.printSymbolName(*name); | |||
3455 | } | |||
3456 | ||||
3457 | SmallVector<StringRef, 2> elidedAttrs; | |||
3458 | ||||
3459 | printer << " " << spirv::stringifyAddressingModel(getAddressingModel()) << " " | |||
3460 | << spirv::stringifyMemoryModel(getMemoryModel()); | |||
3461 | auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>(); | |||
3462 | auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>(); | |||
3463 | elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName, | |||
3464 | mlir::SymbolTable::getSymbolAttrName()}); | |||
3465 | ||||
3466 | if (std::optional<spirv::VerCapExtAttr> triple = getVceTriple()) { | |||
3467 | printer << " requires " << *triple; | |||
3468 | elidedAttrs.push_back(spirv::ModuleOp::getVCETripleAttrName()); | |||
3469 | } | |||
3470 | ||||
3471 | printer.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs); | |||
3472 | printer << ' '; | |||
3473 | printer.printRegion(getRegion()); | |||
3474 | } | |||
3475 | ||||
3476 | LogicalResult spirv::ModuleOp::verifyRegions() { | |||
3477 | Dialect *dialect = (*this)->getDialect(); | |||
3478 | DenseMap<std::pair<spirv::FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp> | |||
3479 | entryPoints; | |||
3480 | mlir::SymbolTable table(*this); | |||
3481 | ||||
3482 | for (auto &op : *getBody()) { | |||
3483 | if (op.getDialect() != dialect) | |||
3484 | return op.emitError("'spirv.module' can only contain spirv.* ops"); | |||
3485 | ||||
3486 | // For EntryPoint op, check that the function and execution model is not | |||
3487 | // duplicated in EntryPointOps. Also verify that the interface specified | |||
3488 | // comes from globalVariables here to make this check cheaper. | |||
3489 | if (auto entryPointOp = dyn_cast<spirv::EntryPointOp>(op)) { | |||
3490 | auto funcOp = table.lookup<spirv::FuncOp>(entryPointOp.getFn()); | |||
3491 | if (!funcOp) { | |||
3492 | return entryPointOp.emitError("function '") | |||
3493 | << entryPointOp.getFn() << "' not found in 'spirv.module'"; | |||
3494 | } | |||
3495 | if (auto interface = entryPointOp.getInterface()) { | |||
3496 | for (Attribute varRef : interface) { | |||
3497 | auto varSymRef = varRef.dyn_cast<FlatSymbolRefAttr>(); | |||
3498 | if (!varSymRef) { | |||
3499 | return entryPointOp.emitError( | |||
3500 | "expected symbol reference for interface " | |||
3501 | "specification instead of '") | |||
3502 | << varRef; | |||
3503 | } | |||
3504 | auto variableOp = | |||
3505 | table.lookup<spirv::GlobalVariableOp>(varSymRef.getValue()); | |||
3506 | if (!variableOp) { | |||
3507 | return entryPointOp.emitError("expected spirv.GlobalVariable " | |||
3508 | "symbol reference instead of'") | |||
3509 | << varSymRef << "'"; | |||
3510 | } | |||
3511 | } | |||
3512 | } | |||
3513 | ||||
3514 | auto key = std::pair<spirv::FuncOp, spirv::ExecutionModel>( | |||
3515 | funcOp, entryPointOp.getExecutionModel()); | |||
3516 | auto entryPtIt = entryPoints.find(key); | |||
3517 | if (entryPtIt != entryPoints.end()) { | |||
3518 | return entryPointOp.emitError("duplicate of a previous EntryPointOp"); | |||
3519 | } | |||
3520 | entryPoints[key] = entryPointOp; | |||
3521 | } else if (auto funcOp = dyn_cast<spirv::FuncOp>(op)) { | |||
3522 | if (funcOp.isExternal()) | |||
3523 | return op.emitError("'spirv.module' cannot contain external functions"); | |||
3524 | ||||
3525 | // TODO: move this check to spirv.func. | |||
3526 | for (auto &block : funcOp) | |||
3527 | for (auto &op : block) { | |||
3528 | if (op.getDialect() != dialect) | |||
3529 | return op.emitError( | |||
3530 | "functions in 'spirv.module' can only contain spirv.* ops"); | |||
3531 | } | |||
3532 | } | |||
3533 | } | |||
3534 | ||||
3535 | return success(); | |||
3536 | } | |||
3537 | ||||
3538 | //===----------------------------------------------------------------------===// | |||
3539 | // spirv.mlir.referenceof | |||
3540 | //===----------------------------------------------------------------------===// | |||
3541 | ||||
3542 | LogicalResult spirv::ReferenceOfOp::verify() { | |||
3543 | auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( | |||
3544 | (*this)->getParentOp(), getSpecConstAttr()); | |||
3545 | Type constType; | |||
3546 | ||||
3547 | auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym); | |||
3548 | if (specConstOp) | |||
3549 | constType = specConstOp.getDefaultValue().getType(); | |||
3550 | ||||
3551 | auto specConstCompositeOp = | |||
3552 | dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym); | |||
3553 | if (specConstCompositeOp) | |||
3554 | constType = specConstCompositeOp.getType(); | |||
3555 | ||||
3556 | if (!specConstOp && !specConstCompositeOp) | |||
3557 | return emitOpError( | |||
3558 | "expected spirv.SpecConstant or spirv.SpecConstantComposite symbol"); | |||
3559 | ||||
3560 | if (getReference().getType() != constType) | |||
3561 | return emitOpError("result type mismatch with the referenced " | |||
3562 | "specialization constant's type"); | |||
3563 | ||||
3564 | return success(); | |||
3565 | } | |||
3566 | ||||
3567 | //===----------------------------------------------------------------------===// | |||
3568 | // spirv.Return | |||
3569 | //===----------------------------------------------------------------------===// | |||
3570 | ||||
3571 | LogicalResult spirv::ReturnOp::verify() { | |||
3572 | // Verification is performed in spirv.func op. | |||
3573 | return success(); | |||
3574 | } | |||
3575 | ||||
3576 | //===----------------------------------------------------------------------===// | |||
3577 | // spirv.ReturnValue | |||
3578 | //===----------------------------------------------------------------------===// | |||
3579 | ||||
3580 | LogicalResult spirv::ReturnValueOp::verify() { | |||
3581 | // Verification is performed in spirv.func op. | |||
3582 | return success(); | |||
3583 | } | |||
3584 | ||||
3585 | //===----------------------------------------------------------------------===// | |||
3586 | // spirv.Select | |||
3587 | //===----------------------------------------------------------------------===// | |||
3588 | ||||
3589 | LogicalResult spirv::SelectOp::verify() { | |||
3590 | if (auto conditionTy = getCondition().getType().dyn_cast<VectorType>()) { | |||
3591 | auto resultVectorTy = getResult().getType().dyn_cast<VectorType>(); | |||
3592 | if (!resultVectorTy) { | |||
3593 | return emitOpError("result expected to be of vector type when " | |||
3594 | "condition is of vector type"); | |||
3595 | } | |||
3596 | if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { | |||
3597 | return emitOpError("result should have the same number of elements as " | |||
3598 | "the condition when condition is of vector type"); | |||
3599 | } | |||
3600 | } | |||
3601 | return success(); | |||
3602 | } | |||
3603 | ||||
3604 | //===----------------------------------------------------------------------===// | |||
3605 | // spirv.mlir.selection | |||
3606 | //===----------------------------------------------------------------------===// | |||
3607 | ||||
3608 | ParseResult spirv::SelectionOp::parse(OpAsmParser &parser, | |||
3609 | OperationState &result) { | |||
3610 | if (parseControlAttribute<spirv::SelectionControlAttr, | |||
3611 | spirv::SelectionControl>(parser, result)) | |||
3612 | return failure(); | |||
3613 | return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); | |||
3614 | } | |||
3615 | ||||
3616 | void spirv::SelectionOp::print(OpAsmPrinter &printer) { | |||
3617 | auto control = getSelectionControl(); | |||
3618 | if (control != spirv::SelectionControl::None) | |||
3619 | printer << " control(" << spirv::stringifySelectionControl(control) << ")"; | |||
3620 | printer << ' '; | |||
3621 | printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, | |||
3622 | /*printBlockTerminators=*/true); | |||
3623 | } | |||
3624 | ||||
3625 | LogicalResult spirv::SelectionOp::verifyRegions() { | |||
3626 | auto *op = getOperation(); | |||
3627 | ||||
3628 | // We need to verify that the blocks follow the following layout: | |||
3629 | // | |||
3630 | // +--------------+ | |||
3631 | // | header block | | |||
3632 | // +--------------+ | |||
3633 | // / | \ | |||
3634 | // ... | |||
3635 | // | |||
3636 | // | |||
3637 | // +---------+ +---------+ +---------+ | |||
3638 | // | case #0 | | case #1 | | case #2 | ... | |||
3639 | // +---------+ +---------+ +---------+ | |||
3640 | // | |||
3641 | // | |||
3642 | // ... | |||
3643 | // \ | / | |||
3644 | // v | |||
3645 | // +-------------+ | |||
3646 | // | merge block | | |||
3647 | // +-------------+ | |||
3648 | ||||
3649 | auto ®ion = op->getRegion(0); | |||
3650 | // Allow empty region as a degenerated case, which can come from | |||
3651 | // optimizations. | |||
3652 | if (region.empty()) | |||
3653 | return success(); | |||
3654 | ||||
3655 | // The last block is the merge block. | |||
3656 | if (!isMergeBlock(region.back())) | |||
3657 | return emitOpError("last block must be the merge block with only one " | |||
3658 | "'spirv.mlir.merge' op"); | |||
3659 | ||||
3660 | if (std::next(region.begin()) == region.end()) | |||
3661 | return emitOpError("must have a selection header block"); | |||
3662 | ||||
3663 | return success(); | |||
3664 | } | |||
3665 | ||||
3666 | Block *spirv::SelectionOp::getHeaderBlock() { | |||
3667 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3667, __extension__ __PRETTY_FUNCTION__)); | |||
3668 | // The first block is the loop header block. | |||
3669 | return &getBody().front(); | |||
3670 | } | |||
3671 | ||||
3672 | Block *spirv::SelectionOp::getMergeBlock() { | |||
3673 | assert(!getBody().empty() && "op region should not be empty!")(static_cast <bool> (!getBody().empty() && "op region should not be empty!" ) ? void (0) : __assert_fail ("!getBody().empty() && \"op region should not be empty!\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3673, __extension__ __PRETTY_FUNCTION__)); | |||
3674 | // The last block is the loop merge block. | |||
3675 | return &getBody().back(); | |||
3676 | } | |||
3677 | ||||
3678 | void spirv::SelectionOp::addMergeBlock() { | |||
3679 | assert(getBody().empty() && "entry and merge block already exist")(static_cast <bool> (getBody().empty() && "entry and merge block already exist" ) ? void (0) : __assert_fail ("getBody().empty() && \"entry and merge block already exist\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 3679, __extension__ __PRETTY_FUNCTION__)); | |||
3680 | auto *mergeBlock = new Block(); | |||
3681 | getBody().push_back(mergeBlock); | |||
3682 | OpBuilder builder = OpBuilder::atBlockEnd(mergeBlock); | |||
3683 | ||||
3684 | // Add a spirv.mlir.merge op into the merge block. | |||
3685 | builder.create<spirv::MergeOp>(getLoc()); | |||
3686 | } | |||
3687 | ||||
3688 | spirv::SelectionOp spirv::SelectionOp::createIfThen( | |||
3689 | Location loc, Value condition, | |||
3690 | function_ref<void(OpBuilder &builder)> thenBody, OpBuilder &builder) { | |||
3691 | auto selectionOp = | |||
3692 | builder.create<spirv::SelectionOp>(loc, spirv::SelectionControl::None); | |||
3693 | ||||
3694 | selectionOp.addMergeBlock(); | |||
3695 | Block *mergeBlock = selectionOp.getMergeBlock(); | |||
3696 | Block *thenBlock = nullptr; | |||
3697 | ||||
3698 | // Build the "then" block. | |||
3699 | { | |||
3700 | OpBuilder::InsertionGuard guard(builder); | |||
3701 | thenBlock = builder.createBlock(mergeBlock); | |||
3702 | thenBody(builder); | |||
3703 | builder.create<spirv::BranchOp>(loc, mergeBlock); | |||
3704 | } | |||
3705 | ||||
3706 | // Build the header block. | |||
3707 | { | |||
3708 | OpBuilder::InsertionGuard guard(builder); | |||
3709 | builder.createBlock(thenBlock); | |||
3710 | builder.create<spirv::BranchConditionalOp>( | |||
3711 | loc, condition, thenBlock, | |||
3712 | /*trueArguments=*/ArrayRef<Value>(), mergeBlock, | |||
3713 | /*falseArguments=*/ArrayRef<Value>()); | |||
3714 | } | |||
3715 | ||||
3716 | return selectionOp; | |||
3717 | } | |||
3718 | ||||
3719 | //===----------------------------------------------------------------------===// | |||
3720 | // spirv.SpecConstant | |||
3721 | //===----------------------------------------------------------------------===// | |||
3722 | ||||
3723 | ParseResult spirv::SpecConstantOp::parse(OpAsmParser &parser, | |||
3724 | OperationState &result) { | |||
3725 | StringAttr nameAttr; | |||
3726 | Attribute valueAttr; | |||
3727 | ||||
3728 | if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), | |||
3729 | result.attributes)) | |||
3730 | return failure(); | |||
3731 | ||||
3732 | // Parse optional spec_id. | |||
3733 | if (succeeded(parser.parseOptionalKeyword(kSpecIdAttrName))) { | |||
3734 | IntegerAttr specIdAttr; | |||
3735 | if (parser.parseLParen() || | |||
3736 | parser.parseAttribute(specIdAttr, kSpecIdAttrName, result.attributes) || | |||
3737 | parser.parseRParen()) | |||
3738 | return failure(); | |||
3739 | } | |||
3740 | ||||
3741 | if (parser.parseEqual() || | |||
3742 | parser.parseAttribute(valueAttr, kDefaultValueAttrName, | |||
3743 | result.attributes)) | |||
3744 | return failure(); | |||
3745 | ||||
3746 | return success(); | |||
3747 | } | |||
3748 | ||||
3749 | void spirv::SpecConstantOp::print(OpAsmPrinter &printer) { | |||
3750 | printer << ' '; | |||
3751 | printer.printSymbolName(getSymName()); | |||
3752 | if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) | |||
3753 | printer << ' ' << kSpecIdAttrName << '(' << specID.getInt() << ')'; | |||
3754 | printer << " = " << getDefaultValue(); | |||
3755 | } | |||
3756 | ||||
3757 | LogicalResult spirv::SpecConstantOp::verify() { | |||
3758 | if (auto specID = (*this)->getAttrOfType<IntegerAttr>(kSpecIdAttrName)) | |||
3759 | if (specID.getValue().isNegative()) | |||
3760 | return emitOpError("SpecId cannot be negative"); | |||
3761 | ||||
3762 | auto value = getDefaultValue(); | |||
3763 | if (value.isa<IntegerAttr, FloatAttr>()) { | |||
3764 | // Make sure bitwidth is allowed. | |||
3765 | if (!value.getType().isa<spirv::SPIRVType>()) | |||
3766 | return emitOpError("default value bitwidth disallowed"); | |||
3767 | return success(); | |||
3768 | } | |||
3769 | return emitOpError( | |||
3770 | "default value can only be a bool, integer, or float scalar"); | |||
3771 | } | |||
3772 | ||||
3773 | //===----------------------------------------------------------------------===// | |||
3774 | // spirv.StoreOp | |||
3775 | //===----------------------------------------------------------------------===// | |||
3776 | ||||
3777 | ParseResult spirv::StoreOp::parse(OpAsmParser &parser, OperationState &result) { | |||
3778 | // Parse the storage class specification | |||
3779 | spirv::StorageClass storageClass; | |||
3780 | SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfo; | |||
3781 | auto loc = parser.getCurrentLocation(); | |||
3782 | Type elementType; | |||
3783 | if (parseEnumStrAttr(storageClass, parser) || | |||
3784 | parser.parseOperandList(operandInfo, 2) || | |||
3785 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
3786 | parser.parseType(elementType)) { | |||
3787 | return failure(); | |||
3788 | } | |||
3789 | ||||
3790 | auto ptrType = spirv::PointerType::get(elementType, storageClass); | |||
3791 | if (parser.resolveOperands(operandInfo, {ptrType, elementType}, loc, | |||
3792 | result.operands)) { | |||
3793 | return failure(); | |||
3794 | } | |||
3795 | return success(); | |||
3796 | } | |||
3797 | ||||
3798 | void spirv::StoreOp::print(OpAsmPrinter &printer) { | |||
3799 | SmallVector<StringRef, 4> elidedAttrs; | |||
3800 | StringRef sc = stringifyStorageClass( | |||
3801 | getPtr().getType().cast<spirv::PointerType>().getStorageClass()); | |||
3802 | printer << " \"" << sc << "\" " << getPtr() << ", " << getValue(); | |||
3803 | ||||
3804 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
3805 | ||||
3806 | printer << " : " << getValue().getType(); | |||
3807 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
3808 | } | |||
3809 | ||||
3810 | LogicalResult spirv::StoreOp::verify() { | |||
3811 | // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an | |||
3812 | // OpTypePointer whose Type operand is the same as the type of Object." | |||
3813 | if (failed(verifyLoadStorePtrAndValTypes(*this, getPtr(), getValue()))) | |||
3814 | return failure(); | |||
3815 | return verifyMemoryAccessAttribute(*this); | |||
3816 | } | |||
3817 | ||||
3818 | //===----------------------------------------------------------------------===// | |||
3819 | // spirv.Unreachable | |||
3820 | //===----------------------------------------------------------------------===// | |||
3821 | ||||
3822 | LogicalResult spirv::UnreachableOp::verify() { | |||
3823 | auto *block = (*this)->getBlock(); | |||
3824 | // Fast track: if this is in entry block, its invalid. Otherwise, if no | |||
3825 | // predecessors, it's valid. | |||
3826 | if (block->isEntryBlock()) | |||
3827 | return emitOpError("cannot be used in reachable block"); | |||
3828 | if (block->hasNoPredecessors()) | |||
3829 | return success(); | |||
3830 | ||||
3831 | // TODO: further verification needs to analyze reachability from | |||
3832 | // the entry block. | |||
3833 | ||||
3834 | return success(); | |||
3835 | } | |||
3836 | ||||
3837 | //===----------------------------------------------------------------------===// | |||
3838 | // spirv.Variable | |||
3839 | //===----------------------------------------------------------------------===// | |||
3840 | ||||
3841 | ParseResult spirv::VariableOp::parse(OpAsmParser &parser, | |||
3842 | OperationState &result) { | |||
3843 | // Parse optional initializer | |||
3844 | std::optional<OpAsmParser::UnresolvedOperand> initInfo; | |||
3845 | if (succeeded(parser.parseOptionalKeyword("init"))) { | |||
3846 | initInfo = OpAsmParser::UnresolvedOperand(); | |||
3847 | if (parser.parseLParen() || parser.parseOperand(*initInfo) || | |||
3848 | parser.parseRParen()) | |||
3849 | return failure(); | |||
3850 | } | |||
3851 | ||||
3852 | if (parseVariableDecorations(parser, result)) { | |||
3853 | return failure(); | |||
3854 | } | |||
3855 | ||||
3856 | // Parse result pointer type | |||
3857 | Type type; | |||
3858 | if (parser.parseColon()) | |||
3859 | return failure(); | |||
3860 | auto loc = parser.getCurrentLocation(); | |||
3861 | if (parser.parseType(type)) | |||
3862 | return failure(); | |||
3863 | ||||
3864 | auto ptrType = type.dyn_cast<spirv::PointerType>(); | |||
3865 | if (!ptrType) | |||
3866 | return parser.emitError(loc, "expected spirv.ptr type"); | |||
3867 | result.addTypes(ptrType); | |||
3868 | ||||
3869 | // Resolve the initializer operand | |||
3870 | if (initInfo) { | |||
3871 | if (parser.resolveOperand(*initInfo, ptrType.getPointeeType(), | |||
3872 | result.operands)) | |||
3873 | return failure(); | |||
3874 | } | |||
3875 | ||||
3876 | auto attr = parser.getBuilder().getAttr<spirv::StorageClassAttr>( | |||
3877 | ptrType.getStorageClass()); | |||
3878 | result.addAttribute(spirv::attributeName<spirv::StorageClass>(), attr); | |||
3879 | ||||
3880 | return success(); | |||
3881 | } | |||
3882 | ||||
3883 | void spirv::VariableOp::print(OpAsmPrinter &printer) { | |||
3884 | SmallVector<StringRef, 4> elidedAttrs{ | |||
3885 | spirv::attributeName<spirv::StorageClass>()}; | |||
3886 | // Print optional initializer | |||
3887 | if (getNumOperands() != 0) | |||
3888 | printer << " init(" << getInitializer() << ")"; | |||
3889 | ||||
3890 | printVariableDecorations(*this, printer, elidedAttrs); | |||
3891 | printer << " : " << getType(); | |||
3892 | } | |||
3893 | ||||
3894 | LogicalResult spirv::VariableOp::verify() { | |||
3895 | // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the | |||
3896 | // object. It cannot be Generic. It must be the same as the Storage Class | |||
3897 | // operand of the Result Type." | |||
3898 | if (getStorageClass() != spirv::StorageClass::Function) { | |||
3899 | return emitOpError( | |||
3900 | "can only be used to model function-level variables. Use " | |||
3901 | "spirv.GlobalVariable for module-level variables."); | |||
3902 | } | |||
3903 | ||||
3904 | auto pointerType = getPointer().getType().cast<spirv::PointerType>(); | |||
3905 | if (getStorageClass() != pointerType.getStorageClass()) | |||
3906 | return emitOpError( | |||
3907 | "storage class must match result pointer's storage class"); | |||
3908 | ||||
3909 | if (getNumOperands() != 0) { | |||
3910 | // SPIR-V spec: "Initializer must be an <id> from a constant instruction or | |||
3911 | // a global (module scope) OpVariable instruction". | |||
3912 | auto *initOp = getOperand(0).getDefiningOp(); | |||
3913 | if (!initOp || !isa<spirv::ConstantOp, // for normal constant | |||
3914 | spirv::ReferenceOfOp, // for spec constant | |||
3915 | spirv::AddressOfOp>(initOp)) | |||
3916 | return emitOpError("initializer must be the result of a " | |||
3917 | "constant or spirv.GlobalVariable op"); | |||
3918 | } | |||
3919 | ||||
3920 | // TODO: generate these strings using ODS. | |||
3921 | auto *op = getOperation(); | |||
3922 | auto descriptorSetName = llvm::convertToSnakeFromCamelCase( | |||
3923 | stringifyDecoration(spirv::Decoration::DescriptorSet)); | |||
3924 | auto bindingName = llvm::convertToSnakeFromCamelCase( | |||
3925 | stringifyDecoration(spirv::Decoration::Binding)); | |||
3926 | auto builtInName = llvm::convertToSnakeFromCamelCase( | |||
3927 | stringifyDecoration(spirv::Decoration::BuiltIn)); | |||
3928 | ||||
3929 | for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { | |||
3930 | if (op->getAttr(attr)) | |||
3931 | return emitOpError("cannot have '") | |||
3932 | << attr << "' attribute (only allowed in spirv.GlobalVariable)"; | |||
3933 | } | |||
3934 | ||||
3935 | return success(); | |||
3936 | } | |||
3937 | ||||
3938 | //===----------------------------------------------------------------------===// | |||
3939 | // spirv.VectorShuffle | |||
3940 | //===----------------------------------------------------------------------===// | |||
3941 | ||||
3942 | LogicalResult spirv::VectorShuffleOp::verify() { | |||
3943 | VectorType resultType = getType().cast<VectorType>(); | |||
3944 | ||||
3945 | size_t numResultElements = resultType.getNumElements(); | |||
3946 | if (numResultElements != getComponents().size()) | |||
3947 | return emitOpError("result type element count (") | |||
3948 | << numResultElements | |||
3949 | << ") mismatch with the number of component selectors (" | |||
3950 | << getComponents().size() << ")"; | |||
3951 | ||||
3952 | size_t totalSrcElements = | |||
3953 | getVector1().getType().cast<VectorType>().getNumElements() + | |||
3954 | getVector2().getType().cast<VectorType>().getNumElements(); | |||
3955 | ||||
3956 | for (const auto &selector : getComponents().getAsValueRange<IntegerAttr>()) { | |||
3957 | uint32_t index = selector.getZExtValue(); | |||
3958 | if (index >= totalSrcElements && | |||
3959 | index != std::numeric_limits<uint32_t>().max()) | |||
3960 | return emitOpError("component selector ") | |||
3961 | << index << " out of range: expected to be in [0, " | |||
3962 | << totalSrcElements << ") or 0xffffffff"; | |||
3963 | } | |||
3964 | return success(); | |||
3965 | } | |||
3966 | ||||
3967 | //===----------------------------------------------------------------------===// | |||
3968 | // spirv.NV.CooperativeMatrixLoad | |||
3969 | //===----------------------------------------------------------------------===// | |||
3970 | ||||
3971 | ParseResult spirv::NVCooperativeMatrixLoadOp::parse(OpAsmParser &parser, | |||
3972 | OperationState &result) { | |||
3973 | SmallVector<OpAsmParser::UnresolvedOperand, 3> operandInfo; | |||
3974 | Type strideType = parser.getBuilder().getIntegerType(32); | |||
3975 | Type columnMajorType = parser.getBuilder().getIntegerType(1); | |||
3976 | Type ptrType; | |||
3977 | Type elementType; | |||
3978 | if (parser.parseOperandList(operandInfo, 3) || | |||
3979 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
3980 | parser.parseType(ptrType) || parser.parseKeywordType("as", elementType)) { | |||
3981 | return failure(); | |||
3982 | } | |||
3983 | if (parser.resolveOperands(operandInfo, | |||
3984 | {ptrType, strideType, columnMajorType}, | |||
3985 | parser.getNameLoc(), result.operands)) { | |||
3986 | return failure(); | |||
3987 | } | |||
3988 | ||||
3989 | result.addTypes(elementType); | |||
3990 | return success(); | |||
3991 | } | |||
3992 | ||||
3993 | void spirv::NVCooperativeMatrixLoadOp::print(OpAsmPrinter &printer) { | |||
3994 | printer << " " << getPointer() << ", " << getStride() << ", " | |||
3995 | << getColumnmajor(); | |||
3996 | // Print optional memory access attribute. | |||
3997 | if (auto memAccess = getMemoryAccess()) | |||
3998 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; | |||
3999 | printer << " : " << getPointer().getType() << " as " << getType(); | |||
4000 | } | |||
4001 | ||||
4002 | static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, | |||
4003 | Type coopMatrix) { | |||
4004 | Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType(); | |||
4005 | if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>()) | |||
4006 | return op->emitError( | |||
4007 | "Pointer must point to a scalar or vector type but provided ") | |||
4008 | << pointeeType; | |||
4009 | spirv::StorageClass storage = | |||
4010 | pointer.cast<spirv::PointerType>().getStorageClass(); | |||
4011 | if (storage != spirv::StorageClass::Workgroup && | |||
4012 | storage != spirv::StorageClass::StorageBuffer && | |||
4013 | storage != spirv::StorageClass::PhysicalStorageBuffer) | |||
4014 | return op->emitError( | |||
4015 | "Pointer storage class must be Workgroup, StorageBuffer or " | |||
4016 | "PhysicalStorageBufferEXT but provided ") | |||
4017 | << stringifyStorageClass(storage); | |||
4018 | return success(); | |||
4019 | } | |||
4020 | ||||
4021 | LogicalResult spirv::NVCooperativeMatrixLoadOp::verify() { | |||
4022 | return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), | |||
4023 | getResult().getType()); | |||
4024 | } | |||
4025 | ||||
4026 | //===----------------------------------------------------------------------===// | |||
4027 | // spirv.NV.CooperativeMatrixStore | |||
4028 | //===----------------------------------------------------------------------===// | |||
4029 | ||||
4030 | ParseResult spirv::NVCooperativeMatrixStoreOp::parse(OpAsmParser &parser, | |||
4031 | OperationState &result) { | |||
4032 | SmallVector<OpAsmParser::UnresolvedOperand, 4> operandInfo; | |||
4033 | Type strideType = parser.getBuilder().getIntegerType(32); | |||
4034 | Type columnMajorType = parser.getBuilder().getIntegerType(1); | |||
4035 | Type ptrType; | |||
4036 | Type elementType; | |||
4037 | if (parser.parseOperandList(operandInfo, 4) || | |||
4038 | parseMemoryAccessAttributes(parser, result) || parser.parseColon() || | |||
4039 | parser.parseType(ptrType) || parser.parseComma() || | |||
4040 | parser.parseType(elementType)) { | |||
4041 | return failure(); | |||
4042 | } | |||
4043 | if (parser.resolveOperands( | |||
4044 | operandInfo, {ptrType, elementType, strideType, columnMajorType}, | |||
4045 | parser.getNameLoc(), result.operands)) { | |||
4046 | return failure(); | |||
4047 | } | |||
4048 | ||||
4049 | return success(); | |||
4050 | } | |||
4051 | ||||
4052 | void spirv::NVCooperativeMatrixStoreOp::print(OpAsmPrinter &printer) { | |||
4053 | printer << " " << getPointer() << ", " << getObject() << ", " << getStride() | |||
4054 | << ", " << getColumnmajor(); | |||
4055 | // Print optional memory access attribute. | |||
4056 | if (auto memAccess = getMemoryAccess()) | |||
4057 | printer << " [\"" << stringifyMemoryAccess(*memAccess) << "\"]"; | |||
4058 | printer << " : " << getPointer().getType() << ", " << getOperand(1).getType(); | |||
4059 | } | |||
4060 | ||||
4061 | LogicalResult spirv::NVCooperativeMatrixStoreOp::verify() { | |||
4062 | return verifyPointerAndCoopMatrixType(*this, getPointer().getType(), | |||
4063 | getObject().getType()); | |||
4064 | } | |||
4065 | ||||
4066 | //===----------------------------------------------------------------------===// | |||
4067 | // spirv.NV.CooperativeMatrixMulAdd | |||
4068 | //===----------------------------------------------------------------------===// | |||
4069 | ||||
4070 | static LogicalResult | |||
4071 | verifyCoopMatrixMulAdd(spirv::NVCooperativeMatrixMulAddOp op) { | |||
4072 | if (op.getC().getType() != op.getResult().getType()) | |||
4073 | return op.emitOpError("result and third operand must have the same type"); | |||
4074 | auto typeA = op.getA().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4075 | auto typeB = op.getB().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4076 | auto typeC = op.getC().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4077 | auto typeR = op.getResult().getType().cast<spirv::CooperativeMatrixNVType>(); | |||
4078 | if (typeA.getRows() != typeR.getRows() || | |||
4079 | typeA.getColumns() != typeB.getRows() || | |||
4080 | typeB.getColumns() != typeR.getColumns()) | |||
4081 | return op.emitOpError("matrix size must match"); | |||
4082 | if (typeR.getScope() != typeA.getScope() || | |||
4083 | typeR.getScope() != typeB.getScope() || | |||
4084 | typeR.getScope() != typeC.getScope()) | |||
4085 | return op.emitOpError("matrix scope must match"); | |||
4086 | auto elementTypeA = typeA.getElementType(); | |||
4087 | auto elementTypeB = typeB.getElementType(); | |||
4088 | if (isa<IntegerType>(elementTypeA) && isa<IntegerType>(elementTypeB)) { | |||
4089 | if (elementTypeA.cast<IntegerType>().getWidth() != | |||
4090 | elementTypeB.cast<IntegerType>().getWidth()) | |||
4091 | return op.emitOpError( | |||
4092 | "matrix A and B integer element types must be the same bit width"); | |||
4093 | } else if (elementTypeA != elementTypeB) { | |||
4094 | return op.emitOpError( | |||
4095 | "matrix A and B non-integer element types must match"); | |||
4096 | } | |||
4097 | if (typeR.getElementType() != typeC.getElementType()) | |||
4098 | return op.emitOpError("matrix accumulator element type must match"); | |||
4099 | return success(); | |||
4100 | } | |||
4101 | ||||
4102 | LogicalResult spirv::NVCooperativeMatrixMulAddOp::verify() { | |||
4103 | return verifyCoopMatrixMulAdd(*this); | |||
4104 | } | |||
4105 | ||||
4106 | static LogicalResult | |||
4107 | verifyPointerAndJointMatrixType(Operation *op, Type pointer, Type jointMatrix) { | |||
4108 | Type pointeeType = pointer.cast<spirv::PointerType>().getPointeeType(); | |||
4109 | if (!pointeeType.isa<spirv::ScalarType>() && !pointeeType.isa<VectorType>()) | |||
4110 | return op->emitError( | |||
4111 | "Pointer must point to a scalar or vector type but provided ") | |||
4112 | << pointeeType; | |||
4113 | spirv::StorageClass storage = | |||
4114 | pointer.cast<spirv::PointerType>().getStorageClass(); | |||
4115 | if (storage != spirv::StorageClass::Workgroup && | |||
4116 | storage != spirv::StorageClass::CrossWorkgroup && | |||
4117 | storage != spirv::StorageClass::UniformConstant && | |||
4118 | storage != spirv::StorageClass::Generic) | |||
4119 | return op->emitError("Pointer storage class must be Workgroup or " | |||
4120 | "CrossWorkgroup but provided ") | |||
4121 | << stringifyStorageClass(storage); | |||
4122 | return success(); | |||
4123 | } | |||
4124 | ||||
4125 | //===----------------------------------------------------------------------===// | |||
4126 | // spirv.INTEL.JointMatrixLoad | |||
4127 | //===----------------------------------------------------------------------===// | |||
4128 | ||||
4129 | LogicalResult spirv::INTELJointMatrixLoadOp::verify() { | |||
4130 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), | |||
4131 | getResult().getType()); | |||
4132 | } | |||
4133 | ||||
4134 | //===----------------------------------------------------------------------===// | |||
4135 | // spirv.INTEL.JointMatrixStore | |||
4136 | //===----------------------------------------------------------------------===// | |||
4137 | ||||
4138 | LogicalResult spirv::INTELJointMatrixStoreOp::verify() { | |||
4139 | return verifyPointerAndJointMatrixType(*this, getPointer().getType(), | |||
4140 | getObject().getType()); | |||
4141 | } | |||
4142 | ||||
4143 | //===----------------------------------------------------------------------===// | |||
4144 | // spirv.INTEL.JointMatrixMad | |||
4145 | //===----------------------------------------------------------------------===// | |||
4146 | ||||
4147 | static LogicalResult verifyJointMatrixMad(spirv::INTELJointMatrixMadOp op) { | |||
4148 | if (op.getC().getType() != op.getResult().getType()) | |||
4149 | return op.emitOpError("result and third operand must have the same type"); | |||
4150 | auto typeA = op.getA().getType().cast<spirv::JointMatrixINTELType>(); | |||
4151 | auto typeB = op.getB().getType().cast<spirv::JointMatrixINTELType>(); | |||
4152 | auto typeC = op.getC().getType().cast<spirv::JointMatrixINTELType>(); | |||
4153 | auto typeR = op.getResult().getType().cast<spirv::JointMatrixINTELType>(); | |||
4154 | if (typeA.getRows() != typeR.getRows() || | |||
4155 | typeA.getColumns() != typeB.getRows() || | |||
4156 | typeB.getColumns() != typeR.getColumns()) | |||
4157 | return op.emitOpError("matrix size must match"); | |||
4158 | if (typeR.getScope() != typeA.getScope() || | |||
4159 | typeR.getScope() != typeB.getScope() || | |||
4160 | typeR.getScope() != typeC.getScope()) | |||
4161 | return op.emitOpError("matrix scope must match"); | |||
4162 | if (typeA.getElementType() != typeB.getElementType() || | |||
4163 | typeR.getElementType() != typeC.getElementType()) | |||
4164 | return op.emitOpError("matrix element type must match"); | |||
4165 | return success(); | |||
4166 | } | |||
4167 | ||||
4168 | LogicalResult spirv::INTELJointMatrixMadOp::verify() { | |||
4169 | return verifyJointMatrixMad(*this); | |||
4170 | } | |||
4171 | ||||
4172 | //===----------------------------------------------------------------------===// | |||
4173 | // spirv.MatrixTimesScalar | |||
4174 | //===----------------------------------------------------------------------===// | |||
4175 | ||||
4176 | LogicalResult spirv::MatrixTimesScalarOp::verify() { | |||
4177 | if (auto inputCoopmat = | |||
4178 | getMatrix().getType().dyn_cast<spirv::CooperativeMatrixNVType>()) { | |||
4179 | if (inputCoopmat.getElementType() != getScalar().getType()) | |||
4180 | return emitError("input matrix components' type and scaling value must " | |||
4181 | "have the same type"); | |||
4182 | return success(); | |||
4183 | } | |||
4184 | ||||
4185 | // Check that the scalar type is the same as the matrix element type. | |||
4186 | auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>(); | |||
4187 | if (getScalar().getType() != inputMatrix.getElementType()) | |||
4188 | return emitError("input matrix components' type and scaling value must " | |||
4189 | "have the same type"); | |||
4190 | ||||
4191 | return success(); | |||
4192 | } | |||
4193 | ||||
4194 | //===----------------------------------------------------------------------===// | |||
4195 | // spirv.CopyMemory | |||
4196 | //===----------------------------------------------------------------------===// | |||
4197 | ||||
4198 | void spirv::CopyMemoryOp::print(OpAsmPrinter &printer) { | |||
4199 | printer << ' '; | |||
4200 | ||||
4201 | StringRef targetStorageClass = stringifyStorageClass( | |||
4202 | getTarget().getType().cast<spirv::PointerType>().getStorageClass()); | |||
4203 | printer << " \"" << targetStorageClass << "\" " << getTarget() << ", "; | |||
4204 | ||||
4205 | StringRef sourceStorageClass = stringifyStorageClass( | |||
4206 | getSource().getType().cast<spirv::PointerType>().getStorageClass()); | |||
4207 | printer << " \"" << sourceStorageClass << "\" " << getSource(); | |||
4208 | ||||
4209 | SmallVector<StringRef, 4> elidedAttrs; | |||
4210 | printMemoryAccessAttribute(*this, printer, elidedAttrs); | |||
4211 | printSourceMemoryAccessAttribute(*this, printer, elidedAttrs, | |||
4212 | getSourceMemoryAccess(), | |||
4213 | getSourceAlignment()); | |||
4214 | ||||
4215 | printer.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs); | |||
4216 | ||||
4217 | Type pointeeType = | |||
4218 | getTarget().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4219 | printer << " : " << pointeeType; | |||
4220 | } | |||
4221 | ||||
4222 | ParseResult spirv::CopyMemoryOp::parse(OpAsmParser &parser, | |||
4223 | OperationState &result) { | |||
4224 | spirv::StorageClass targetStorageClass; | |||
4225 | OpAsmParser::UnresolvedOperand targetPtrInfo; | |||
4226 | ||||
4227 | spirv::StorageClass sourceStorageClass; | |||
4228 | OpAsmParser::UnresolvedOperand sourcePtrInfo; | |||
4229 | ||||
4230 | Type elementType; | |||
4231 | ||||
4232 | if (parseEnumStrAttr(targetStorageClass, parser) || | |||
4233 | parser.parseOperand(targetPtrInfo) || parser.parseComma() || | |||
4234 | parseEnumStrAttr(sourceStorageClass, parser) || | |||
4235 | parser.parseOperand(sourcePtrInfo) || | |||
4236 | parseMemoryAccessAttributes(parser, result)) { | |||
4237 | return failure(); | |||
4238 | } | |||
4239 | ||||
4240 | if (!parser.parseOptionalComma()) { | |||
4241 | // Parse 2nd memory access attributes. | |||
4242 | if (parseSourceMemoryAccessAttributes(parser, result)) { | |||
4243 | return failure(); | |||
4244 | } | |||
4245 | } | |||
4246 | ||||
4247 | if (parser.parseColon() || parser.parseType(elementType)) | |||
4248 | return failure(); | |||
4249 | ||||
4250 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
4251 | return failure(); | |||
4252 | ||||
4253 | auto targetPtrType = spirv::PointerType::get(elementType, targetStorageClass); | |||
4254 | auto sourcePtrType = spirv::PointerType::get(elementType, sourceStorageClass); | |||
4255 | ||||
4256 | if (parser.resolveOperand(targetPtrInfo, targetPtrType, result.operands) || | |||
4257 | parser.resolveOperand(sourcePtrInfo, sourcePtrType, result.operands)) { | |||
4258 | return failure(); | |||
4259 | } | |||
4260 | ||||
4261 | return success(); | |||
4262 | } | |||
4263 | ||||
4264 | LogicalResult spirv::CopyMemoryOp::verify() { | |||
4265 | Type targetType = | |||
4266 | getTarget().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4267 | ||||
4268 | Type sourceType = | |||
4269 | getSource().getType().cast<spirv::PointerType>().getPointeeType(); | |||
4270 | ||||
4271 | if (targetType != sourceType) | |||
4272 | return emitOpError("both operands must be pointers to the same type"); | |||
4273 | ||||
4274 | if (failed(verifyMemoryAccessAttribute(*this))) | |||
4275 | return failure(); | |||
4276 | ||||
4277 | // TODO - According to the spec: | |||
4278 | // | |||
4279 | // If two masks are present, the first applies to Target and cannot include | |||
4280 | // MakePointerVisible, and the second applies to Source and cannot include | |||
4281 | // MakePointerAvailable. | |||
4282 | // | |||
4283 | // Add such verification here. | |||
4284 | ||||
4285 | return verifySourceMemoryAccessAttribute(*this); | |||
4286 | } | |||
4287 | ||||
4288 | //===----------------------------------------------------------------------===// | |||
4289 | // spirv.Transpose | |||
4290 | //===----------------------------------------------------------------------===// | |||
4291 | ||||
4292 | LogicalResult spirv::TransposeOp::verify() { | |||
4293 | auto inputMatrix = getMatrix().getType().cast<spirv::MatrixType>(); | |||
4294 | auto resultMatrix = getResult().getType().cast<spirv::MatrixType>(); | |||
4295 | ||||
4296 | // Verify that the input and output matrices have correct shapes. | |||
4297 | if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) | |||
4298 | return emitError("input matrix rows count must be equal to " | |||
4299 | "output matrix columns count"); | |||
4300 | ||||
4301 | if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) | |||
4302 | return emitError("input matrix columns count must be equal to " | |||
4303 | "output matrix rows count"); | |||
4304 | ||||
4305 | // Verify that the input and output matrices have the same component type | |||
4306 | if (inputMatrix.getElementType() != resultMatrix.getElementType()) | |||
4307 | return emitError("input and output matrices must have the same " | |||
4308 | "component type"); | |||
4309 | ||||
4310 | return success(); | |||
4311 | } | |||
4312 | ||||
4313 | //===----------------------------------------------------------------------===// | |||
4314 | // spirv.MatrixTimesMatrix | |||
4315 | //===----------------------------------------------------------------------===// | |||
4316 | ||||
4317 | LogicalResult spirv::MatrixTimesMatrixOp::verify() { | |||
4318 | auto leftMatrix = getLeftmatrix().getType().cast<spirv::MatrixType>(); | |||
4319 | auto rightMatrix = getRightmatrix().getType().cast<spirv::MatrixType>(); | |||
4320 | auto resultMatrix = getResult().getType().cast<spirv::MatrixType>(); | |||
4321 | ||||
4322 | // left matrix columns' count and right matrix rows' count must be equal | |||
4323 | if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) | |||
4324 | return emitError("left matrix columns' count must be equal to " | |||
4325 | "the right matrix rows' count"); | |||
4326 | ||||
4327 | // right and result matrices columns' count must be the same | |||
4328 | if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) | |||
4329 | return emitError( | |||
4330 | "right and result matrices must have equal columns' count"); | |||
4331 | ||||
4332 | // right and result matrices component type must be the same | |||
4333 | if (rightMatrix.getElementType() != resultMatrix.getElementType()) | |||
4334 | return emitError("right and result matrices' component type must" | |||
4335 | " be the same"); | |||
4336 | ||||
4337 | // left and result matrices component type must be the same | |||
4338 | if (leftMatrix.getElementType() != resultMatrix.getElementType()) | |||
4339 | return emitError("left and result matrices' component type" | |||
4340 | " must be the same"); | |||
4341 | ||||
4342 | // left and result matrices rows count must be the same | |||
4343 | if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) | |||
4344 | return emitError("left and result matrices must have equal rows' count"); | |||
4345 | ||||
4346 | return success(); | |||
4347 | } | |||
4348 | ||||
4349 | //===----------------------------------------------------------------------===// | |||
4350 | // spirv.SpecConstantComposite | |||
4351 | //===----------------------------------------------------------------------===// | |||
4352 | ||||
4353 | ParseResult spirv::SpecConstantCompositeOp::parse(OpAsmParser &parser, | |||
4354 | OperationState &result) { | |||
4355 | ||||
4356 | StringAttr compositeName; | |||
4357 | if (parser.parseSymbolName(compositeName, SymbolTable::getSymbolAttrName(), | |||
4358 | result.attributes)) | |||
4359 | return failure(); | |||
4360 | ||||
4361 | if (parser.parseLParen()) | |||
4362 | return failure(); | |||
4363 | ||||
4364 | SmallVector<Attribute, 4> constituents; | |||
4365 | ||||
4366 | do { | |||
4367 | // The name of the constituent attribute isn't important | |||
4368 | const char *attrName = "spec_const"; | |||
4369 | FlatSymbolRefAttr specConstRef; | |||
4370 | NamedAttrList attrs; | |||
4371 | ||||
4372 | if (parser.parseAttribute(specConstRef, Type(), attrName, attrs)) | |||
4373 | return failure(); | |||
4374 | ||||
4375 | constituents.push_back(specConstRef); | |||
4376 | } while (!parser.parseOptionalComma()); | |||
4377 | ||||
4378 | if (parser.parseRParen()) | |||
4379 | return failure(); | |||
4380 | ||||
4381 | result.addAttribute(kCompositeSpecConstituentsName, | |||
4382 | parser.getBuilder().getArrayAttr(constituents)); | |||
4383 | ||||
4384 | Type type; | |||
4385 | if (parser.parseColonType(type)) | |||
4386 | return failure(); | |||
4387 | ||||
4388 | result.addAttribute(kTypeAttrName, TypeAttr::get(type)); | |||
4389 | ||||
4390 | return success(); | |||
4391 | } | |||
4392 | ||||
4393 | void spirv::SpecConstantCompositeOp::print(OpAsmPrinter &printer) { | |||
4394 | printer << " "; | |||
4395 | printer.printSymbolName(getSymName()); | |||
4396 | printer << " ("; | |||
4397 | auto constituents = this->getConstituents().getValue(); | |||
4398 | ||||
4399 | if (!constituents.empty()) | |||
4400 | llvm::interleaveComma(constituents, printer); | |||
4401 | ||||
4402 | printer << ") : " << getType(); | |||
4403 | } | |||
4404 | ||||
4405 | LogicalResult spirv::SpecConstantCompositeOp::verify() { | |||
4406 | auto cType = getType().dyn_cast<spirv::CompositeType>(); | |||
4407 | auto constituents = this->getConstituents().getValue(); | |||
4408 | ||||
4409 | if (!cType) | |||
4410 | return emitError("result type must be a composite type, but provided ") | |||
4411 | << getType(); | |||
4412 | ||||
4413 | if (cType.isa<spirv::CooperativeMatrixNVType>()) | |||
4414 | return emitError("unsupported composite type ") << cType; | |||
4415 | if (cType.isa<spirv::JointMatrixINTELType>()) | |||
4416 | return emitError("unsupported composite type ") << cType; | |||
4417 | if (constituents.size() != cType.getNumElements()) | |||
4418 | return emitError("has incorrect number of operands: expected ") | |||
4419 | << cType.getNumElements() << ", but provided " | |||
4420 | << constituents.size(); | |||
4421 | ||||
4422 | for (auto index : llvm::seq<uint32_t>(0, constituents.size())) { | |||
4423 | auto constituent = constituents[index].cast<FlatSymbolRefAttr>(); | |||
4424 | ||||
4425 | auto constituentSpecConstOp = | |||
4426 | dyn_cast<spirv::SpecConstantOp>(SymbolTable::lookupNearestSymbolFrom( | |||
4427 | (*this)->getParentOp(), constituent.getAttr())); | |||
4428 | ||||
4429 | if (constituentSpecConstOp.getDefaultValue().getType() != | |||
4430 | cType.getElementType(index)) | |||
4431 | return emitError("has incorrect types of operands: expected ") | |||
4432 | << cType.getElementType(index) << ", but provided " | |||
4433 | << constituentSpecConstOp.getDefaultValue().getType(); | |||
4434 | } | |||
4435 | ||||
4436 | return success(); | |||
4437 | } | |||
4438 | ||||
4439 | //===----------------------------------------------------------------------===// | |||
4440 | // spirv.SpecConstantOperation | |||
4441 | //===----------------------------------------------------------------------===// | |||
4442 | ||||
4443 | ParseResult spirv::SpecConstantOperationOp::parse(OpAsmParser &parser, | |||
4444 | OperationState &result) { | |||
4445 | Region *body = result.addRegion(); | |||
4446 | ||||
4447 | if (parser.parseKeyword("wraps")) | |||
4448 | return failure(); | |||
4449 | ||||
4450 | body->push_back(new Block); | |||
4451 | Block &block = body->back(); | |||
4452 | Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin()); | |||
4453 | ||||
4454 | if (!wrappedOp) | |||
4455 | return failure(); | |||
4456 | ||||
4457 | OpBuilder builder(parser.getContext()); | |||
4458 | builder.setInsertionPointToEnd(&block); | |||
4459 | builder.create<spirv::YieldOp>(wrappedOp->getLoc(), wrappedOp->getResult(0)); | |||
4460 | result.location = wrappedOp->getLoc(); | |||
4461 | ||||
4462 | result.addTypes(wrappedOp->getResult(0).getType()); | |||
4463 | ||||
4464 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
4465 | return failure(); | |||
4466 | ||||
4467 | return success(); | |||
4468 | } | |||
4469 | ||||
4470 | void spirv::SpecConstantOperationOp::print(OpAsmPrinter &printer) { | |||
4471 | printer << " wraps "; | |||
4472 | printer.printGenericOp(&getBody().front().front()); | |||
4473 | } | |||
4474 | ||||
4475 | LogicalResult spirv::SpecConstantOperationOp::verifyRegions() { | |||
4476 | Block &block = getRegion().getBlocks().front(); | |||
4477 | ||||
4478 | if (block.getOperations().size() != 2) | |||
4479 | return emitOpError("expected exactly 2 nested ops"); | |||
4480 | ||||
4481 | Operation &enclosedOp = block.getOperations().front(); | |||
4482 | ||||
4483 | if (!enclosedOp.hasTrait<OpTrait::spirv::UsableInSpecConstantOp>()) | |||
4484 | return emitOpError("invalid enclosed op"); | |||
4485 | ||||
4486 | for (auto operand : enclosedOp.getOperands()) | |||
4487 | if (!isa<spirv::ConstantOp, spirv::ReferenceOfOp, | |||
4488 | spirv::SpecConstantOperationOp>(operand.getDefiningOp())) | |||
4489 | return emitOpError( | |||
4490 | "invalid operand, must be defined by a constant operation"); | |||
4491 | ||||
4492 | return success(); | |||
4493 | } | |||
4494 | ||||
4495 | //===----------------------------------------------------------------------===// | |||
4496 | // spirv.GL.FrexpStruct | |||
4497 | //===----------------------------------------------------------------------===// | |||
4498 | ||||
4499 | LogicalResult spirv::GLFrexpStructOp::verify() { | |||
4500 | spirv::StructType structTy = | |||
4501 | getResult().getType().dyn_cast<spirv::StructType>(); | |||
4502 | ||||
4503 | if (structTy.getNumElements() != 2) | |||
4504 | return emitError("result type must be a struct type with two memebers"); | |||
4505 | ||||
4506 | Type significandTy = structTy.getElementType(0); | |||
4507 | Type exponentTy = structTy.getElementType(1); | |||
4508 | VectorType exponentVecTy = exponentTy.dyn_cast<VectorType>(); | |||
4509 | IntegerType exponentIntTy = exponentTy.dyn_cast<IntegerType>(); | |||
4510 | ||||
4511 | Type operandTy = getOperand().getType(); | |||
4512 | VectorType operandVecTy = operandTy.dyn_cast<VectorType>(); | |||
4513 | FloatType operandFTy = operandTy.dyn_cast<FloatType>(); | |||
4514 | ||||
4515 | if (significandTy != operandTy) | |||
4516 | return emitError("member zero of the resulting struct type must be the " | |||
4517 | "same type as the operand"); | |||
4518 | ||||
4519 | if (exponentVecTy) { | |||
4520 | IntegerType componentIntTy = | |||
4521 | exponentVecTy.getElementType().dyn_cast<IntegerType>(); | |||
4522 | if (!componentIntTy || componentIntTy.getWidth() != 32) | |||
4523 | return emitError("member one of the resulting struct type must" | |||
4524 | "be a scalar or vector of 32 bit integer type"); | |||
4525 | } else if (!exponentIntTy || exponentIntTy.getWidth() != 32) { | |||
4526 | return emitError("member one of the resulting struct type " | |||
4527 | "must be a scalar or vector of 32 bit integer type"); | |||
4528 | } | |||
4529 | ||||
4530 | // Check that the two member types have the same number of components | |||
4531 | if (operandVecTy && exponentVecTy && | |||
4532 | (exponentVecTy.getNumElements() == operandVecTy.getNumElements())) | |||
4533 | return success(); | |||
4534 | ||||
4535 | if (operandFTy && exponentIntTy) | |||
4536 | return success(); | |||
4537 | ||||
4538 | return emitError("member one of the resulting struct type must have the same " | |||
4539 | "number of components as the operand type"); | |||
4540 | } | |||
4541 | ||||
4542 | //===----------------------------------------------------------------------===// | |||
4543 | // spirv.GL.Ldexp | |||
4544 | //===----------------------------------------------------------------------===// | |||
4545 | ||||
4546 | LogicalResult spirv::GLLdexpOp::verify() { | |||
4547 | Type significandType = getX().getType(); | |||
4548 | Type exponentType = getExp().getType(); | |||
4549 | ||||
4550 | if (significandType.isa<FloatType>() != exponentType.isa<IntegerType>()) | |||
4551 | return emitOpError("operands must both be scalars or vectors"); | |||
4552 | ||||
4553 | auto getNumElements = [](Type type) -> unsigned { | |||
4554 | if (auto vectorType = type.dyn_cast<VectorType>()) | |||
4555 | return vectorType.getNumElements(); | |||
4556 | return 1; | |||
4557 | }; | |||
4558 | ||||
4559 | if (getNumElements(significandType) != getNumElements(exponentType)) | |||
4560 | return emitOpError("operands must have the same number of elements"); | |||
4561 | ||||
4562 | return success(); | |||
4563 | } | |||
4564 | ||||
4565 | //===----------------------------------------------------------------------===// | |||
4566 | // spirv.ImageDrefGather | |||
4567 | //===----------------------------------------------------------------------===// | |||
4568 | ||||
4569 | LogicalResult spirv::ImageDrefGatherOp::verify() { | |||
4570 | VectorType resultType = getResult().getType().cast<VectorType>(); | |||
4571 | auto sampledImageType = | |||
4572 | getSampledimage().getType().cast<spirv::SampledImageType>(); | |||
4573 | auto imageType = sampledImageType.getImageType().cast<spirv::ImageType>(); | |||
4574 | ||||
4575 | if (resultType.getNumElements() != 4) | |||
4576 | return emitOpError("result type must be a vector of four components"); | |||
4577 | ||||
4578 | Type elementType = resultType.getElementType(); | |||
4579 | Type sampledElementType = imageType.getElementType(); | |||
4580 | if (!sampledElementType.isa<NoneType>() && elementType != sampledElementType) | |||
4581 | return emitOpError( | |||
4582 | "the component type of result must be the same as sampled type of the " | |||
4583 | "underlying image type"); | |||
4584 | ||||
4585 | spirv::Dim imageDim = imageType.getDim(); | |||
4586 | spirv::ImageSamplingInfo imageMS = imageType.getSamplingInfo(); | |||
4587 | ||||
4588 | if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && | |||
4589 | imageDim != spirv::Dim::Rect) | |||
4590 | return emitOpError( | |||
4591 | "the Dim operand of the underlying image type must be 2D, Cube, or " | |||
4592 | "Rect"); | |||
4593 | ||||
4594 | if (imageMS != spirv::ImageSamplingInfo::SingleSampled) | |||
4595 | return emitOpError("the MS operand of the underlying image type must be 0"); | |||
4596 | ||||
4597 | spirv::ImageOperandsAttr attr = getImageoperandsAttr(); | |||
4598 | auto operandArguments = getOperandArguments(); | |||
4599 | ||||
4600 | return verifyImageOperands(*this, attr, operandArguments); | |||
4601 | } | |||
4602 | ||||
4603 | //===----------------------------------------------------------------------===// | |||
4604 | // spirv.ShiftLeftLogicalOp | |||
4605 | //===----------------------------------------------------------------------===// | |||
4606 | ||||
4607 | LogicalResult spirv::ShiftLeftLogicalOp::verify() { | |||
4608 | return verifyShiftOp(*this); | |||
4609 | } | |||
4610 | ||||
4611 | //===----------------------------------------------------------------------===// | |||
4612 | // spirv.ShiftRightArithmeticOp | |||
4613 | //===----------------------------------------------------------------------===// | |||
4614 | ||||
4615 | LogicalResult spirv::ShiftRightArithmeticOp::verify() { | |||
4616 | return verifyShiftOp(*this); | |||
4617 | } | |||
4618 | ||||
4619 | //===----------------------------------------------------------------------===// | |||
4620 | // spirv.ShiftRightLogicalOp | |||
4621 | //===----------------------------------------------------------------------===// | |||
4622 | ||||
4623 | LogicalResult spirv::ShiftRightLogicalOp::verify() { | |||
4624 | return verifyShiftOp(*this); | |||
4625 | } | |||
4626 | ||||
4627 | //===----------------------------------------------------------------------===// | |||
4628 | // spirv.ImageQuerySize | |||
4629 | //===----------------------------------------------------------------------===// | |||
4630 | ||||
4631 | LogicalResult spirv::ImageQuerySizeOp::verify() { | |||
4632 | spirv::ImageType imageType = getImage().getType().cast<spirv::ImageType>(); | |||
4633 | Type resultType = getResult().getType(); | |||
4634 | ||||
4635 | spirv::Dim dim = imageType.getDim(); | |||
4636 | spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); | |||
4637 | spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo(); | |||
4638 | switch (dim) { | |||
4639 | case spirv::Dim::Dim1D: | |||
4640 | case spirv::Dim::Dim2D: | |||
4641 | case spirv::Dim::Dim3D: | |||
4642 | case spirv::Dim::Cube: | |||
4643 | if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled && | |||
4644 | samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown && | |||
4645 | samplerInfo != spirv::ImageSamplerUseInfo::NoSampler) | |||
4646 | return emitError( | |||
4647 | "if Dim is 1D, 2D, 3D, or Cube, " | |||
4648 | "it must also have either an MS of 1 or a Sampled of 0 or 2"); | |||
4649 | break; | |||
4650 | case spirv::Dim::Buffer: | |||
4651 | case spirv::Dim::Rect: | |||
4652 | break; | |||
4653 | default: | |||
4654 | return emitError("the Dim operand of the image type must " | |||
4655 | "be 1D, 2D, 3D, Buffer, Cube, or Rect"); | |||
4656 | } | |||
4657 | ||||
4658 | unsigned componentNumber = 0; | |||
4659 | switch (dim) { | |||
4660 | case spirv::Dim::Dim1D: | |||
4661 | case spirv::Dim::Buffer: | |||
4662 | componentNumber = 1; | |||
4663 | break; | |||
4664 | case spirv::Dim::Dim2D: | |||
4665 | case spirv::Dim::Cube: | |||
4666 | case spirv::Dim::Rect: | |||
4667 | componentNumber = 2; | |||
4668 | break; | |||
4669 | case spirv::Dim::Dim3D: | |||
4670 | componentNumber = 3; | |||
4671 | break; | |||
4672 | default: | |||
4673 | break; | |||
4674 | } | |||
4675 | ||||
4676 | if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed) | |||
4677 | componentNumber += 1; | |||
4678 | ||||
4679 | unsigned resultComponentNumber = 1; | |||
4680 | if (auto resultVectorType = resultType.dyn_cast<VectorType>()) | |||
4681 | resultComponentNumber = resultVectorType.getNumElements(); | |||
4682 | ||||
4683 | if (componentNumber != resultComponentNumber) | |||
4684 | return emitError("expected the result to have ") | |||
4685 | << componentNumber << " component(s), but found " | |||
4686 | << resultComponentNumber << " component(s)"; | |||
4687 | ||||
4688 | return success(); | |||
4689 | } | |||
4690 | ||||
4691 | static ParseResult parsePtrAccessChainOpImpl(StringRef opName, | |||
4692 | OpAsmParser &parser, | |||
4693 | OperationState &state) { | |||
4694 | OpAsmParser::UnresolvedOperand ptrInfo; | |||
4695 | SmallVector<OpAsmParser::UnresolvedOperand, 4> indicesInfo; | |||
4696 | Type type; | |||
4697 | auto loc = parser.getCurrentLocation(); | |||
4698 | SmallVector<Type, 4> indicesTypes; | |||
4699 | ||||
4700 | if (parser.parseOperand(ptrInfo) || | |||
4701 | parser.parseOperandList(indicesInfo, OpAsmParser::Delimiter::Square) || | |||
4702 | parser.parseColonType(type) || | |||
4703 | parser.resolveOperand(ptrInfo, type, state.operands)) | |||
4704 | return failure(); | |||
4705 | ||||
4706 | // Check that the provided indices list is not empty before parsing their | |||
4707 | // type list. | |||
4708 | if (indicesInfo.empty()) | |||
4709 | return emitError(state.location) << opName << " expected element"; | |||
4710 | ||||
4711 | if (parser.parseComma() || parser.parseTypeList(indicesTypes)) | |||
4712 | return failure(); | |||
4713 | ||||
4714 | // Check that the indices types list is not empty and that it has a one-to-one | |||
4715 | // mapping to the provided indices. | |||
4716 | if (indicesTypes.size() != indicesInfo.size()) | |||
4717 | return emitError(state.location) | |||
4718 | << opName | |||
4719 | << " indices types' count must be equal to indices info count"; | |||
4720 | ||||
4721 | if (parser.resolveOperands(indicesInfo, indicesTypes, loc, state.operands)) | |||
4722 | return failure(); | |||
4723 | ||||
4724 | auto resultType = getElementPtrType( | |||
4725 | type, llvm::ArrayRef(state.operands).drop_front(2), state.location); | |||
4726 | if (!resultType) | |||
4727 | return failure(); | |||
4728 | ||||
4729 | state.addTypes(resultType); | |||
4730 | return success(); | |||
4731 | } | |||
4732 | ||||
4733 | template <typename Op> | |||
4734 | static auto concatElemAndIndices(Op op) { | |||
4735 | SmallVector<Value> ret(op.getIndices().size() + 1); | |||
4736 | ret[0] = op.getElement(); | |||
4737 | llvm::copy(op.getIndices(), ret.begin() + 1); | |||
4738 | return ret; | |||
4739 | } | |||
4740 | ||||
4741 | //===----------------------------------------------------------------------===// | |||
4742 | // spirv.InBoundsPtrAccessChainOp | |||
4743 | //===----------------------------------------------------------------------===// | |||
4744 | ||||
4745 | void spirv::InBoundsPtrAccessChainOp::build(OpBuilder &builder, | |||
4746 | OperationState &state, | |||
4747 | Value basePtr, Value element, | |||
4748 | ValueRange indices) { | |||
4749 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
4750 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4750, __extension__ __PRETTY_FUNCTION__)); | |||
4751 | build(builder, state, type, basePtr, element, indices); | |||
4752 | } | |||
4753 | ||||
4754 | ParseResult spirv::InBoundsPtrAccessChainOp::parse(OpAsmParser &parser, | |||
4755 | OperationState &result) { | |||
4756 | return parsePtrAccessChainOpImpl( | |||
4757 | spirv::InBoundsPtrAccessChainOp::getOperationName(), parser, result); | |||
4758 | } | |||
4759 | ||||
4760 | void spirv::InBoundsPtrAccessChainOp::print(OpAsmPrinter &printer) { | |||
4761 | printAccessChain(*this, concatElemAndIndices(*this), printer); | |||
4762 | } | |||
4763 | ||||
4764 | LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { | |||
4765 | return verifyAccessChain(*this, getIndices()); | |||
4766 | } | |||
4767 | ||||
4768 | //===----------------------------------------------------------------------===// | |||
4769 | // spirv.PtrAccessChainOp | |||
4770 | //===----------------------------------------------------------------------===// | |||
4771 | ||||
4772 | void spirv::PtrAccessChainOp::build(OpBuilder &builder, OperationState &state, | |||
4773 | Value basePtr, Value element, | |||
4774 | ValueRange indices) { | |||
4775 | auto type = getElementPtrType(basePtr.getType(), indices, state.location); | |||
4776 | assert(type && "Unable to deduce return type based on basePtr and indices")(static_cast <bool> (type && "Unable to deduce return type based on basePtr and indices" ) ? void (0) : __assert_fail ("type && \"Unable to deduce return type based on basePtr and indices\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4776, __extension__ __PRETTY_FUNCTION__)); | |||
4777 | build(builder, state, type, basePtr, element, indices); | |||
4778 | } | |||
4779 | ||||
4780 | ParseResult spirv::PtrAccessChainOp::parse(OpAsmParser &parser, | |||
4781 | OperationState &result) { | |||
4782 | return parsePtrAccessChainOpImpl(spirv::PtrAccessChainOp::getOperationName(), | |||
4783 | parser, result); | |||
4784 | } | |||
4785 | ||||
4786 | void spirv::PtrAccessChainOp::print(OpAsmPrinter &printer) { | |||
4787 | printAccessChain(*this, concatElemAndIndices(*this), printer); | |||
4788 | } | |||
4789 | ||||
4790 | LogicalResult spirv::PtrAccessChainOp::verify() { | |||
4791 | return verifyAccessChain(*this, getIndices()); | |||
4792 | } | |||
4793 | ||||
4794 | //===----------------------------------------------------------------------===// | |||
4795 | // spirv.VectorTimesScalarOp | |||
4796 | //===----------------------------------------------------------------------===// | |||
4797 | ||||
4798 | LogicalResult spirv::VectorTimesScalarOp::verify() { | |||
4799 | if (getVector().getType() != getType()) | |||
4800 | return emitOpError("vector operand and result type mismatch"); | |||
4801 | auto scalarType = getType().cast<VectorType>().getElementType(); | |||
4802 | if (getScalar().getType() != scalarType) | |||
4803 | return emitOpError("scalar operand and result element type match"); | |||
4804 | return success(); | |||
4805 | } | |||
4806 | ||||
4807 | //===----------------------------------------------------------------------===// | |||
4808 | // Group ops | |||
4809 | //===----------------------------------------------------------------------===// | |||
4810 | ||||
4811 | template <typename Op> | |||
4812 | static LogicalResult verifyGroupOp(Op op) { | |||
4813 | spirv::Scope scope = op.getExecutionScope(); | |||
4814 | if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) | |||
4815 | return op.emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); | |||
4816 | ||||
4817 | return success(); | |||
4818 | } | |||
4819 | ||||
4820 | LogicalResult spirv::GroupIAddOp::verify() { return verifyGroupOp(*this); } | |||
4821 | ||||
4822 | LogicalResult spirv::GroupFAddOp::verify() { return verifyGroupOp(*this); } | |||
4823 | ||||
4824 | LogicalResult spirv::GroupFMinOp::verify() { return verifyGroupOp(*this); } | |||
4825 | ||||
4826 | LogicalResult spirv::GroupUMinOp::verify() { return verifyGroupOp(*this); } | |||
4827 | ||||
4828 | LogicalResult spirv::GroupSMinOp::verify() { return verifyGroupOp(*this); } | |||
4829 | ||||
4830 | LogicalResult spirv::GroupFMaxOp::verify() { return verifyGroupOp(*this); } | |||
4831 | ||||
4832 | LogicalResult spirv::GroupUMaxOp::verify() { return verifyGroupOp(*this); } | |||
4833 | ||||
4834 | LogicalResult spirv::GroupSMaxOp::verify() { return verifyGroupOp(*this); } | |||
4835 | ||||
4836 | LogicalResult spirv::GroupIMulKHROp::verify() { return verifyGroupOp(*this); } | |||
4837 | ||||
4838 | LogicalResult spirv::GroupFMulKHROp::verify() { return verifyGroupOp(*this); } | |||
4839 | ||||
4840 | //===----------------------------------------------------------------------===// | |||
4841 | // Integer Dot Product ops | |||
4842 | //===----------------------------------------------------------------------===// | |||
4843 | ||||
4844 | static LogicalResult verifyIntegerDotProduct(Operation *op) { | |||
4845 | assert(llvm::is_contained({2u, 3u}, op->getNumOperands()) &&(static_cast <bool> (llvm::is_contained({2u, 3u}, op-> getNumOperands()) && "Not an integer dot product op?" ) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4846, __extension__ __PRETTY_FUNCTION__)) | |||
4846 | "Not an integer dot product op?")(static_cast <bool> (llvm::is_contained({2u, 3u}, op-> getNumOperands()) && "Not an integer dot product op?" ) ? void (0) : __assert_fail ("llvm::is_contained({2u, 3u}, op->getNumOperands()) && \"Not an integer dot product op?\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4846, __extension__ __PRETTY_FUNCTION__)); | |||
4847 | assert(op->getNumResults() == 1 && "Expected a single result")(static_cast <bool> (op->getNumResults() == 1 && "Expected a single result") ? void (0) : __assert_fail ("op->getNumResults() == 1 && \"Expected a single result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4847, __extension__ __PRETTY_FUNCTION__)); | |||
4848 | ||||
4849 | Type factorTy = op->getOperand(0).getType(); | |||
4850 | if (op->getOperand(1).getType() != factorTy) | |||
4851 | return op->emitOpError("requires the same type for both vector operands"); | |||
4852 | ||||
4853 | unsigned expectedNumAttrs = 0; | |||
4854 | if (auto intTy = factorTy.dyn_cast<IntegerType>()) { | |||
4855 | ++expectedNumAttrs; | |||
4856 | auto packedVectorFormat = | |||
4857 | op->getAttr(kPackedVectorFormatAttrName) | |||
4858 | .dyn_cast_or_null<spirv::PackedVectorFormatAttr>(); | |||
4859 | if (!packedVectorFormat) | |||
4860 | return op->emitOpError("requires Packed Vector Format attribute for " | |||
4861 | "integer vector operands"); | |||
4862 | ||||
4863 | assert(packedVectorFormat.getValue() ==(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__ __PRETTY_FUNCTION__)) | |||
4864 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit &&(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__ __PRETTY_FUNCTION__)) | |||
4865 | "Unknown Packed Vector Format")(static_cast <bool> (packedVectorFormat.getValue() == spirv ::PackedVectorFormat::PackedVectorFormat4x8Bit && "Unknown Packed Vector Format" ) ? void (0) : __assert_fail ("packedVectorFormat.getValue() == spirv::PackedVectorFormat::PackedVectorFormat4x8Bit && \"Unknown Packed Vector Format\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp", 4865, __extension__ __PRETTY_FUNCTION__)); | |||
4866 | if (intTy.getWidth() != 32) | |||
4867 | return op->emitOpError( | |||
4868 | llvm::formatv("with specified Packed Vector Format ({0}) requires " | |||
4869 | "integer vector operands to be 32-bits wide", | |||
4870 | packedVectorFormat.getValue())); | |||
4871 | } else { | |||
4872 | if (op->hasAttr(kPackedVectorFormatAttrName)) | |||
4873 | return op->emitOpError(llvm::formatv( | |||
4874 | "with invalid format attribute for vector operands of type '{0}'", | |||
4875 | factorTy)); | |||
4876 | } | |||
4877 | ||||
4878 | if (op->getAttrs().size() > expectedNumAttrs) | |||
4879 | return op->emitError( | |||
4880 | "op only supports the 'format' #spirv.packed_vector_format attribute"); | |||
4881 | ||||
4882 | Type resultTy = op->getResultTypes().front(); | |||
4883 | bool hasAccumulator = op->getNumOperands() == 3; | |||
4884 | if (hasAccumulator && op->getOperand(2).getType() != resultTy) | |||
4885 | return op->emitOpError( | |||
4886 | "requires the same accumulator operand and result types"); | |||
4887 | ||||
4888 | unsigned factorBitWidth = getBitWidth(factorTy); | |||
4889 | unsigned resultBitWidth = getBitWidth(resultTy); | |||
4890 | if (factorBitWidth > resultBitWidth) | |||
4891 | return op->emitOpError( | |||
4892 | llvm::formatv("result type has insufficient bit-width ({0} bits) " | |||
4893 | "for the specified vector operand type ({1} bits)", | |||
4894 | resultBitWidth, factorBitWidth)); | |||
4895 | ||||
4896 | return success(); | |||
4897 | } | |||
4898 | ||||
4899 | static std::optional<spirv::Version> getIntegerDotProductMinVersion() { | |||
4900 | return spirv::Version::V_1_0; // Available in SPIR-V >= 1.0. | |||
4901 | } | |||
4902 | ||||
4903 | static std::optional<spirv::Version> getIntegerDotProductMaxVersion() { | |||
4904 | return spirv::Version::V_1_6; // Available in SPIR-V <= 1.6. | |||
4905 | } | |||
4906 | ||||
4907 | static SmallVector<ArrayRef<spirv::Extension>, 1> | |||
4908 | getIntegerDotProductExtensions() { | |||
4909 | // Requires the SPV_KHR_integer_dot_product extension, specified either | |||
4910 | // explicitly or implied by target env's SPIR-V version >= 1.6. | |||
4911 | static const auto extension = spirv::Extension::SPV_KHR_integer_dot_product; | |||
4912 | return {extension}; | |||
4913 | } | |||
4914 | ||||
4915 | static SmallVector<ArrayRef<spirv::Capability>, 1> | |||
4916 | getIntegerDotProductCapabilities(Operation *op) { | |||
4917 | // Requires the the DotProduct capability and capabilities that depend on | |||
4918 | // exact op types. | |||
4919 | static const auto dotProductCap = spirv::Capability::DotProduct; | |||
4920 | static const auto dotProductInput4x8BitPackedCap = | |||
4921 | spirv::Capability::DotProductInput4x8BitPacked; | |||
4922 | static const auto dotProductInput4x8BitCap = | |||
4923 | spirv::Capability::DotProductInput4x8Bit; | |||
4924 | static const auto dotProductInputAllCap = | |||
4925 | spirv::Capability::DotProductInputAll; | |||
4926 | ||||
4927 | SmallVector<ArrayRef<spirv::Capability>, 1> capabilities = {dotProductCap}; | |||
4928 | ||||
4929 | Type factorTy = op->getOperand(0).getType(); | |||
4930 | if (auto intTy = factorTy.dyn_cast<IntegerType>()) { | |||
4931 | auto formatAttr = op->getAttr(kPackedVectorFormatAttrName) | |||
4932 | .cast<spirv::PackedVectorFormatAttr>(); | |||
4933 | if (formatAttr.getValue() == | |||
4934 | spirv::PackedVectorFormat::PackedVectorFormat4x8Bit) | |||
4935 | capabilities.push_back(dotProductInput4x8BitPackedCap); | |||
4936 | ||||
4937 | return capabilities; | |||
4938 | } | |||
4939 | ||||
4940 | auto vecTy = factorTy.cast<VectorType>(); | |||
4941 | if (vecTy.getElementTypeBitWidth() == 8) { | |||
4942 | capabilities.push_back(dotProductInput4x8BitCap); | |||
4943 | return capabilities; | |||
4944 | } | |||
4945 | ||||
4946 | capabilities.push_back(dotProductInputAllCap); | |||
4947 | return capabilities; | |||
4948 | } | |||
4949 | ||||
4950 | #define SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(OpName) \ | |||
4951 | LogicalResult OpName::verify() { return verifyIntegerDotProduct(*this); } \ | |||
4952 | SmallVector<ArrayRef<spirv::Extension>, 1> OpName::getExtensions() { \ | |||
4953 | return getIntegerDotProductExtensions(); \ | |||
4954 | } \ | |||
4955 | SmallVector<ArrayRef<spirv::Capability>, 1> OpName::getCapabilities() { \ | |||
4956 | return getIntegerDotProductCapabilities(*this); \ | |||
4957 | } \ | |||
4958 | std::optional<spirv::Version> OpName::getMinVersion() { \ | |||
4959 | return getIntegerDotProductMinVersion(); \ | |||
4960 | } \ | |||
4961 | std::optional<spirv::Version> OpName::getMaxVersion() { \ | |||
4962 | return getIntegerDotProductMaxVersion(); \ | |||
4963 | } | |||
4964 | ||||
4965 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotOp) | |||
4966 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotOp) | |||
4967 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotOp) | |||
4968 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SDotAccSatOp) | |||
4969 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::SUDotAccSatOp) | |||
4970 | SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP(spirv::UDotAccSatOp) | |||
4971 | ||||
4972 | #undef SPIRV_IMPL_INTEGER_DOT_PRODUCT_OP | |||
4973 | ||||
4974 | // TableGen'erated operation interfaces for querying versions, extensions, and | |||
4975 | // capabilities. | |||
4976 | #include "mlir/Dialect/SPIRV/IR/SPIRVAvailability.cpp.inc" | |||
4977 | ||||
4978 | // TablenGen'erated operation definitions. | |||
4979 | #define GET_OP_CLASSES | |||
4980 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" | |||
4981 | ||||
4982 | namespace mlir { | |||
4983 | namespace spirv { | |||
4984 | // TableGen'erated operation availability interface implementations. | |||
4985 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpAvailabilityImpl.inc" | |||
4986 | } // namespace spirv | |||
4987 | } // namespace mlir |
1 | //===- Builders.h - Helpers for constructing MLIR Classes -------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #ifndef MLIR_IR_BUILDERS_H | |||
10 | #define MLIR_IR_BUILDERS_H | |||
11 | ||||
12 | #include "mlir/IR/OpDefinition.h" | |||
13 | #include "llvm/Support/Compiler.h" | |||
14 | #include <optional> | |||
15 | ||||
16 | namespace mlir { | |||
17 | ||||
18 | class AffineExpr; | |||
19 | class IRMapping; | |||
20 | class UnknownLoc; | |||
21 | class FileLineColLoc; | |||
22 | class Type; | |||
23 | class PrimitiveType; | |||
24 | class IntegerType; | |||
25 | class FloatType; | |||
26 | class FunctionType; | |||
27 | class IndexType; | |||
28 | class MemRefType; | |||
29 | class VectorType; | |||
30 | class RankedTensorType; | |||
31 | class UnrankedTensorType; | |||
32 | class TupleType; | |||
33 | class NoneType; | |||
34 | class BoolAttr; | |||
35 | class IntegerAttr; | |||
36 | class FloatAttr; | |||
37 | class StringAttr; | |||
38 | class TypeAttr; | |||
39 | class ArrayAttr; | |||
40 | class SymbolRefAttr; | |||
41 | class ElementsAttr; | |||
42 | class DenseElementsAttr; | |||
43 | class DenseIntElementsAttr; | |||
44 | class AffineMapAttr; | |||
45 | class AffineMap; | |||
46 | class UnitAttr; | |||
47 | ||||
48 | /// This class is a general helper class for creating context-global objects | |||
49 | /// like types, attributes, and affine expressions. | |||
50 | class Builder { | |||
51 | public: | |||
52 | explicit Builder(MLIRContext *context) : context(context) {} | |||
53 | explicit Builder(Operation *op) : Builder(op->getContext()) {} | |||
54 | ||||
55 | MLIRContext *getContext() const { return context; } | |||
56 | ||||
57 | // Locations. | |||
58 | Location getUnknownLoc(); | |||
59 | Location getFusedLoc(ArrayRef<Location> locs, | |||
60 | Attribute metadata = Attribute()); | |||
61 | ||||
62 | // Types. | |||
63 | FloatType getFloat8E5M2Type(); | |||
64 | FloatType getFloat8E4M3FNType(); | |||
65 | FloatType getFloat8E5M2FNUZType(); | |||
66 | FloatType getFloat8E4M3FNUZType(); | |||
67 | FloatType getFloat8E4M3B11FNUZType(); | |||
68 | FloatType getBF16Type(); | |||
69 | FloatType getF16Type(); | |||
70 | FloatType getF32Type(); | |||
71 | FloatType getF64Type(); | |||
72 | FloatType getF80Type(); | |||
73 | FloatType getF128Type(); | |||
74 | ||||
75 | IndexType getIndexType(); | |||
76 | ||||
77 | IntegerType getI1Type(); | |||
78 | IntegerType getI2Type(); | |||
79 | IntegerType getI4Type(); | |||
80 | IntegerType getI8Type(); | |||
81 | IntegerType getI16Type(); | |||
82 | IntegerType getI32Type(); | |||
83 | IntegerType getI64Type(); | |||
84 | IntegerType getIntegerType(unsigned width); | |||
85 | IntegerType getIntegerType(unsigned width, bool isSigned); | |||
86 | FunctionType getFunctionType(TypeRange inputs, TypeRange results); | |||
87 | TupleType getTupleType(TypeRange elementTypes); | |||
88 | NoneType getNoneType(); | |||
89 | ||||
90 | /// Get or construct an instance of the type `Ty` with provided arguments. | |||
91 | template <typename Ty, typename... Args> | |||
92 | Ty getType(Args &&...args) { | |||
93 | return Ty::get(context, std::forward<Args>(args)...); | |||
94 | } | |||
95 | ||||
96 | /// Get or construct an instance of the attribute `Attr` with provided | |||
97 | /// arguments. | |||
98 | template <typename Attr, typename... Args> | |||
99 | Attr getAttr(Args &&...args) { | |||
100 | return Attr::get(context, std::forward<Args>(args)...); | |||
| ||||
101 | } | |||
102 | ||||
103 | // Attributes. | |||
104 | NamedAttribute getNamedAttr(StringRef name, Attribute val); | |||
105 | ||||
106 | UnitAttr getUnitAttr(); | |||
107 | BoolAttr getBoolAttr(bool value); | |||
108 | DictionaryAttr getDictionaryAttr(ArrayRef<NamedAttribute> value); | |||
109 | IntegerAttr getIntegerAttr(Type type, int64_t value); | |||
110 | IntegerAttr getIntegerAttr(Type type, const APInt &value); | |||
111 | FloatAttr getFloatAttr(Type type, double value); | |||
112 | FloatAttr getFloatAttr(Type type, const APFloat &value); | |||
113 | StringAttr getStringAttr(const Twine &bytes); | |||
114 | ArrayAttr getArrayAttr(ArrayRef<Attribute> value); | |||
115 | ||||
116 | // Returns a 0-valued attribute of the given `type`. This function only | |||
117 | // supports boolean, integer, and 16-/32-/64-bit float types, and vector or | |||
118 | // ranked tensor of them. Returns null attribute otherwise. | |||
119 | TypedAttr getZeroAttr(Type type); | |||
120 | ||||
121 | // Convenience methods for fixed types. | |||
122 | FloatAttr getF16FloatAttr(float value); | |||
123 | FloatAttr getF32FloatAttr(float value); | |||
124 | FloatAttr getF64FloatAttr(double value); | |||
125 | ||||
126 | IntegerAttr getI8IntegerAttr(int8_t value); | |||
127 | IntegerAttr getI16IntegerAttr(int16_t value); | |||
128 | IntegerAttr getI32IntegerAttr(int32_t value); | |||
129 | IntegerAttr getI64IntegerAttr(int64_t value); | |||
130 | IntegerAttr getIndexAttr(int64_t value); | |||
131 | ||||
132 | /// Signed and unsigned integer attribute getters. | |||
133 | IntegerAttr getSI32IntegerAttr(int32_t value); | |||
134 | IntegerAttr getUI32IntegerAttr(uint32_t value); | |||
135 | ||||
136 | /// Vector-typed DenseIntElementsAttr getters. `values` must not be empty. | |||
137 | DenseIntElementsAttr getBoolVectorAttr(ArrayRef<bool> values); | |||
138 | DenseIntElementsAttr getI32VectorAttr(ArrayRef<int32_t> values); | |||
139 | DenseIntElementsAttr getI64VectorAttr(ArrayRef<int64_t> values); | |||
140 | DenseIntElementsAttr getIndexVectorAttr(ArrayRef<int64_t> values); | |||
141 | ||||
142 | /// Tensor-typed DenseIntElementsAttr getters. `values` can be empty. | |||
143 | /// These are generally preferable for representing general lists of integers | |||
144 | /// as attributes. | |||
145 | DenseIntElementsAttr getI32TensorAttr(ArrayRef<int32_t> values); | |||
146 | DenseIntElementsAttr getI64TensorAttr(ArrayRef<int64_t> values); | |||
147 | DenseIntElementsAttr getIndexTensorAttr(ArrayRef<int64_t> values); | |||
148 | ||||
149 | /// Tensor-typed DenseArrayAttr getters. | |||
150 | DenseBoolArrayAttr getDenseBoolArrayAttr(ArrayRef<bool> values); | |||
151 | DenseI8ArrayAttr getDenseI8ArrayAttr(ArrayRef<int8_t> values); | |||
152 | DenseI16ArrayAttr getDenseI16ArrayAttr(ArrayRef<int16_t> values); | |||
153 | DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef<int32_t> values); | |||
154 | DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef<int64_t> values); | |||
155 | DenseF32ArrayAttr getDenseF32ArrayAttr(ArrayRef<float> values); | |||
156 | DenseF64ArrayAttr getDenseF64ArrayAttr(ArrayRef<double> values); | |||
157 | ||||
158 | ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values); | |||
159 | ArrayAttr getBoolArrayAttr(ArrayRef<bool> values); | |||
160 | ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values); | |||
161 | ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values); | |||
162 | ArrayAttr getIndexArrayAttr(ArrayRef<int64_t> values); | |||
163 | ArrayAttr getF32ArrayAttr(ArrayRef<float> values); | |||
164 | ArrayAttr getF64ArrayAttr(ArrayRef<double> values); | |||
165 | ArrayAttr getStrArrayAttr(ArrayRef<StringRef> values); | |||
166 | ArrayAttr getTypeArrayAttr(TypeRange values); | |||
167 | ||||
168 | // Affine expressions and affine maps. | |||
169 | AffineExpr getAffineDimExpr(unsigned position); | |||
170 | AffineExpr getAffineSymbolExpr(unsigned position); | |||
171 | AffineExpr getAffineConstantExpr(int64_t constant); | |||
172 | ||||
173 | // Special cases of affine maps and integer sets | |||
174 | /// Returns a zero result affine map with no dimensions or symbols: () -> (). | |||
175 | AffineMap getEmptyAffineMap(); | |||
176 | /// Returns a single constant result affine map with 0 dimensions and 0 | |||
177 | /// symbols. One constant result: () -> (val). | |||
178 | AffineMap getConstantAffineMap(int64_t val); | |||
179 | // One dimension id identity map: (i) -> (i). | |||
180 | AffineMap getDimIdentityMap(); | |||
181 | // Multi-dimensional identity map: (d0, d1, d2) -> (d0, d1, d2). | |||
182 | AffineMap getMultiDimIdentityMap(unsigned rank); | |||
183 | // One symbol identity map: ()[s] -> (s). | |||
184 | AffineMap getSymbolIdentityMap(); | |||
185 | ||||
186 | /// Returns a map that shifts its (single) input dimension by 'shift'. | |||
187 | /// (d0) -> (d0 + shift) | |||
188 | AffineMap getSingleDimShiftAffineMap(int64_t shift); | |||
189 | ||||
190 | /// Returns an affine map that is a translation (shift) of all result | |||
191 | /// expressions in 'map' by 'shift'. | |||
192 | /// Eg: input: (d0, d1)[s0] -> (d0, d1 + s0), shift = 2 | |||
193 | /// returns: (d0, d1)[s0] -> (d0 + 2, d1 + s0 + 2) | |||
194 | AffineMap getShiftedAffineMap(AffineMap map, int64_t shift); | |||
195 | ||||
196 | protected: | |||
197 | MLIRContext *context; | |||
198 | }; | |||
199 | ||||
200 | /// This class helps build Operations. Operations that are created are | |||
201 | /// automatically inserted at an insertion point. The builder is copyable. | |||
202 | class OpBuilder : public Builder { | |||
203 | public: | |||
204 | struct Listener; | |||
205 | ||||
206 | /// Create a builder with the given context. | |||
207 | explicit OpBuilder(MLIRContext *ctx, Listener *listener = nullptr) | |||
208 | : Builder(ctx), listener(listener) {} | |||
209 | ||||
210 | /// Create a builder and set the insertion point to the start of the region. | |||
211 | explicit OpBuilder(Region *region, Listener *listener = nullptr) | |||
212 | : OpBuilder(region->getContext(), listener) { | |||
213 | if (!region->empty()) | |||
214 | setInsertionPoint(®ion->front(), region->front().begin()); | |||
215 | } | |||
216 | explicit OpBuilder(Region ®ion, Listener *listener = nullptr) | |||
217 | : OpBuilder(®ion, listener) {} | |||
218 | ||||
219 | /// Create a builder and set insertion point to the given operation, which | |||
220 | /// will cause subsequent insertions to go right before it. | |||
221 | explicit OpBuilder(Operation *op, Listener *listener = nullptr) | |||
222 | : OpBuilder(op->getContext(), listener) { | |||
223 | setInsertionPoint(op); | |||
224 | } | |||
225 | ||||
226 | OpBuilder(Block *block, Block::iterator insertPoint, | |||
227 | Listener *listener = nullptr) | |||
228 | : OpBuilder(block->getParent()->getContext(), listener) { | |||
229 | setInsertionPoint(block, insertPoint); | |||
230 | } | |||
231 | ||||
232 | /// Create a builder and set the insertion point to before the first operation | |||
233 | /// in the block but still inside the block. | |||
234 | static OpBuilder atBlockBegin(Block *block, Listener *listener = nullptr) { | |||
235 | return OpBuilder(block, block->begin(), listener); | |||
236 | } | |||
237 | ||||
238 | /// Create a builder and set the insertion point to after the last operation | |||
239 | /// in the block but still inside the block. | |||
240 | static OpBuilder atBlockEnd(Block *block, Listener *listener = nullptr) { | |||
241 | return OpBuilder(block, block->end(), listener); | |||
242 | } | |||
243 | ||||
244 | /// Create a builder and set the insertion point to before the block | |||
245 | /// terminator. | |||
246 | static OpBuilder atBlockTerminator(Block *block, | |||
247 | Listener *listener = nullptr) { | |||
248 | auto *terminator = block->getTerminator(); | |||
249 | assert(terminator != nullptr && "the block has no terminator")(static_cast <bool> (terminator != nullptr && "the block has no terminator" ) ? void (0) : __assert_fail ("terminator != nullptr && \"the block has no terminator\"" , "mlir/include/mlir/IR/Builders.h", 249, __extension__ __PRETTY_FUNCTION__ )); | |||
250 | return OpBuilder(block, Block::iterator(terminator), listener); | |||
251 | } | |||
252 | ||||
//===--------------------------------------------------------------------===//
// Listeners
//===--------------------------------------------------------------------===//

/// Root of the builder listener hierarchy. It records which listener family
/// a concrete listener belongs to, queryable through getKind().
struct ListenerBase {
  /// Listener family tag.
  enum class Kind {
    /// OpBuilder::Listener or a class derived from it.
    OpBuilderListener = 0,

    /// RewriterBase::Listener or a class derived from it.
    RewriterBaseListener = 1
  };

  /// Returns the family tag given at construction.
  Kind getKind() const { return kind; }

protected:
  ListenerBase(Kind kind) : kind(kind) {}

private:
  /// Set once by the constructor and never changed.
  const Kind kind;
};
276 | ||||
277 | /// This class represents a listener that may be used to hook into various | |||
278 | /// actions within an OpBuilder. | |||
279 | struct Listener : public ListenerBase { | |||
280 | Listener() : ListenerBase(ListenerBase::Kind::OpBuilderListener) {} | |||
281 | ||||
282 | virtual ~Listener() = default; | |||
283 | ||||
284 | /// Notification handler for when an operation is inserted into the builder. | |||
285 | /// `op` is the operation that was inserted. | |||
286 | virtual void notifyOperationInserted(Operation *op) {} | |||
287 | ||||
288 | /// Notification handler for when a block is created using the builder. | |||
289 | /// `block` is the block that was created. | |||
290 | virtual void notifyBlockCreated(Block *block) {} | |||
291 | ||||
292 | protected: | |||
293 | Listener(Kind kind) : ListenerBase(kind) {} | |||
294 | }; | |||
295 | ||||
296 | /// Sets the listener of this builder to the one provided. | |||
297 | void setListener(Listener *newListener) { listener = newListener; } | |||
298 | ||||
299 | /// Returns the current listener of this builder, or nullptr if this builder | |||
300 | /// doesn't have a listener. | |||
301 | Listener *getListener() const { return listener; } | |||
302 | ||||
303 | //===--------------------------------------------------------------------===// | |||
304 | // Insertion Point Management | |||
305 | //===--------------------------------------------------------------------===// | |||
306 | ||||
307 | /// This class represents a saved insertion point. | |||
308 | class InsertPoint { | |||
309 | public: | |||
310 | /// Creates a new insertion point which doesn't point to anything. | |||
311 | InsertPoint() = default; | |||
312 | ||||
313 | /// Creates a new insertion point at the given location. | |||
314 | InsertPoint(Block *insertBlock, Block::iterator insertPt) | |||
315 | : block(insertBlock), point(insertPt) {} | |||
316 | ||||
317 | /// Returns true if this insert point is set. | |||
318 | bool isSet() const { return (block != nullptr); } | |||
319 | ||||
320 | Block *getBlock() const { return block; } | |||
321 | Block::iterator getPoint() const { return point; } | |||
322 | ||||
323 | private: | |||
324 | Block *block = nullptr; | |||
325 | Block::iterator point; | |||
326 | }; | |||
327 | ||||
328 | /// RAII guard to reset the insertion point of the builder when destroyed. | |||
329 | class InsertionGuard { | |||
330 | public: | |||
331 | InsertionGuard(OpBuilder &builder) | |||
332 | : builder(&builder), ip(builder.saveInsertionPoint()) {} | |||
333 | ||||
334 | ~InsertionGuard() { | |||
335 | if (builder) | |||
336 | builder->restoreInsertionPoint(ip); | |||
337 | } | |||
338 | ||||
339 | InsertionGuard(const InsertionGuard &) = delete; | |||
340 | InsertionGuard &operator=(const InsertionGuard &) = delete; | |||
341 | ||||
342 | /// Implement the move constructor to clear the builder field of `other`. | |||
343 | /// That way it does not restore the insertion point upon destruction as | |||
344 | /// that should be done exclusively by the just constructed InsertionGuard. | |||
345 | InsertionGuard(InsertionGuard &&other) noexcept | |||
346 | : builder(other.builder), ip(other.ip) { | |||
347 | other.builder = nullptr; | |||
348 | } | |||
349 | ||||
350 | InsertionGuard &operator=(InsertionGuard &&other) = delete; | |||
351 | ||||
352 | private: | |||
353 | OpBuilder *builder; | |||
354 | OpBuilder::InsertPoint ip; | |||
355 | }; | |||
356 | ||||
357 | /// Reset the insertion point to no location. Creating an operation without a | |||
358 | /// set insertion point is an error, but this can still be useful when the | |||
359 | /// current insertion point a builder refers to is being removed. | |||
360 | void clearInsertionPoint() { | |||
361 | this->block = nullptr; | |||
362 | insertPoint = Block::iterator(); | |||
363 | } | |||
364 | ||||
365 | /// Return a saved insertion point. | |||
366 | InsertPoint saveInsertionPoint() const { | |||
367 | return InsertPoint(getInsertionBlock(), getInsertionPoint()); | |||
368 | } | |||
369 | ||||
370 | /// Restore the insert point to a previously saved point. | |||
371 | void restoreInsertionPoint(InsertPoint ip) { | |||
372 | if (ip.isSet()) | |||
373 | setInsertionPoint(ip.getBlock(), ip.getPoint()); | |||
374 | else | |||
375 | clearInsertionPoint(); | |||
376 | } | |||
377 | ||||
378 | /// Set the insertion point to the specified location. | |||
379 | void setInsertionPoint(Block *block, Block::iterator insertPoint) { | |||
380 | // TODO: check that insertPoint is in this rather than some other block. | |||
381 | this->block = block; | |||
382 | this->insertPoint = insertPoint; | |||
383 | } | |||
384 | ||||
385 | /// Sets the insertion point to the specified operation, which will cause | |||
386 | /// subsequent insertions to go right before it. | |||
387 | void setInsertionPoint(Operation *op) { | |||
388 | setInsertionPoint(op->getBlock(), Block::iterator(op)); | |||
389 | } | |||
390 | ||||
391 | /// Sets the insertion point to the node after the specified operation, which | |||
392 | /// will cause subsequent insertions to go right after it. | |||
393 | void setInsertionPointAfter(Operation *op) { | |||
394 | setInsertionPoint(op->getBlock(), ++Block::iterator(op)); | |||
395 | } | |||
396 | ||||
397 | /// Sets the insertion point to the node after the specified value. If value | |||
398 | /// has a defining operation, sets the insertion point to the node after such | |||
399 | /// defining operation. This will cause subsequent insertions to go right | |||
400 | /// after it. Otherwise, value is a BlockArgument. Sets the insertion point to | |||
401 | /// the start of its block. | |||
402 | void setInsertionPointAfterValue(Value val) { | |||
403 | if (Operation *op = val.getDefiningOp()) { | |||
404 | setInsertionPointAfter(op); | |||
405 | } else { | |||
406 | auto blockArg = val.cast<BlockArgument>(); | |||
407 | setInsertionPointToStart(blockArg.getOwner()); | |||
408 | } | |||
409 | } | |||
410 | ||||
411 | /// Sets the insertion point to the start of the specified block. | |||
412 | void setInsertionPointToStart(Block *block) { | |||
413 | setInsertionPoint(block, block->begin()); | |||
414 | } | |||
415 | ||||
416 | /// Sets the insertion point to the end of the specified block. | |||
417 | void setInsertionPointToEnd(Block *block) { | |||
418 | setInsertionPoint(block, block->end()); | |||
419 | } | |||
420 | ||||
421 | /// Return the block the current insertion point belongs to. Note that the | |||
422 | /// insertion point is not necessarily the end of the block. | |||
423 | Block *getInsertionBlock() const { return block; } | |||
424 | ||||
425 | /// Returns the current insertion point of the builder. | |||
426 | Block::iterator getInsertionPoint() const { return insertPoint; } | |||
427 | ||||
428 | /// Returns the current block of the builder. | |||
429 | Block *getBlock() const { return block; } | |||
430 | ||||
431 | //===--------------------------------------------------------------------===// | |||
432 | // Block Creation | |||
433 | //===--------------------------------------------------------------------===// | |||
434 | ||||
435 | /// Add new block with 'argTypes' arguments and set the insertion point to the | |||
436 | /// end of it. The block is inserted at the provided insertion point of | |||
437 | /// 'parent'. `locs` contains the locations of the inserted arguments, and | |||
438 | /// should match the size of `argTypes`. | |||
439 | Block *createBlock(Region *parent, Region::iterator insertPt = {}, | |||
440 | TypeRange argTypes = std::nullopt, | |||
441 | ArrayRef<Location> locs = std::nullopt); | |||
442 | ||||
443 | /// Add new block with 'argTypes' arguments and set the insertion point to the | |||
444 | /// end of it. The block is placed before 'insertBefore'. `locs` contains the | |||
445 | /// locations of the inserted arguments, and should match the size of | |||
446 | /// `argTypes`. | |||
447 | Block *createBlock(Block *insertBefore, TypeRange argTypes = std::nullopt, | |||
448 | ArrayRef<Location> locs = std::nullopt); | |||
449 | ||||
450 | //===--------------------------------------------------------------------===// | |||
451 | // Operation Creation | |||
452 | //===--------------------------------------------------------------------===// | |||
453 | ||||
454 | /// Insert the given operation at the current insertion point and return it. | |||
455 | Operation *insert(Operation *op); | |||
456 | ||||
457 | /// Creates an operation given the fields represented as an OperationState. | |||
458 | Operation *create(const OperationState &state); | |||
459 | ||||
460 | /// Creates an operation with the given fields. | |||
461 | Operation *create(Location loc, StringAttr opName, ValueRange operands, | |||
462 | TypeRange types = {}, | |||
463 | ArrayRef<NamedAttribute> attributes = {}, | |||
464 | BlockRange successors = {}, | |||
465 | MutableArrayRef<std::unique_ptr<Region>> regions = {}); | |||
466 | ||||
467 | private: | |||
468 | /// Helper for sanity checking preconditions for create* methods below. | |||
469 | template <typename OpT> | |||
470 | RegisteredOperationName getCheckRegisteredInfo(MLIRContext *ctx) { | |||
471 | std::optional<RegisteredOperationName> opName = | |||
472 | RegisteredOperationName::lookup(OpT::getOperationName(), ctx); | |||
473 | if (LLVM_UNLIKELY(!opName)__builtin_expect((bool)(!opName), false)) { | |||
474 | llvm::report_fatal_error( | |||
475 | "Building op `" + OpT::getOperationName() + | |||
476 | "` but it isn't registered in this MLIRContext: the dialect may not " | |||
477 | "be loaded or this operation isn't registered by the dialect. See " | |||
478 | "also https://mlir.llvm.org/getting_started/Faq/" | |||
479 | "#registered-loaded-dependent-whats-up-with-dialects-management"); | |||
480 | } | |||
481 | return *opName; | |||
482 | } | |||
483 | ||||
484 | public: | |||
485 | /// Create an operation of specific op type at the current insertion point. | |||
486 | template <typename OpTy, typename... Args> | |||
487 | OpTy create(Location location, Args &&...args) { | |||
488 | OperationState state(location, | |||
489 | getCheckRegisteredInfo<OpTy>(location.getContext())); | |||
490 | OpTy::build(*this, state, std::forward<Args>(args)...); | |||
491 | auto *op = create(state); | |||
492 | auto result = dyn_cast<OpTy>(op); | |||
493 | assert(result && "builder didn't return the right type")(static_cast <bool> (result && "builder didn't return the right type" ) ? void (0) : __assert_fail ("result && \"builder didn't return the right type\"" , "mlir/include/mlir/IR/Builders.h", 493, __extension__ __PRETTY_FUNCTION__ )); | |||
494 | return result; | |||
495 | } | |||
496 | ||||
497 | /// Create an operation of specific op type at the current insertion point, | |||
498 | /// and immediately try to fold it. This functions populates 'results' with | |||
499 | /// the results after folding the operation. | |||
500 | template <typename OpTy, typename... Args> | |||
501 | void createOrFold(SmallVectorImpl<Value> &results, Location location, | |||
502 | Args &&...args) { | |||
503 | // Create the operation without using 'create' as we don't want to | |||
504 | // insert it yet. | |||
505 | OperationState state(location, | |||
506 | getCheckRegisteredInfo<OpTy>(location.getContext())); | |||
507 | OpTy::build(*this, state, std::forward<Args>(args)...); | |||
508 | Operation *op = Operation::create(state); | |||
509 | ||||
510 | // Fold the operation. If successful destroy it, otherwise insert it. | |||
511 | if (succeeded(tryFold(op, results))) | |||
512 | op->destroy(); | |||
513 | else | |||
514 | insert(op); | |||
515 | } | |||
516 | ||||
517 | /// Overload to create or fold a single result operation. | |||
518 | template <typename OpTy, typename... Args> | |||
519 | std::enable_if_t<OpTy::template hasTrait<OpTrait::OneResult>(), Value> | |||
520 | createOrFold(Location location, Args &&...args) { | |||
521 | SmallVector<Value, 1> results; | |||
522 | createOrFold<OpTy>(results, location, std::forward<Args>(args)...); | |||
523 | return results.front(); | |||
524 | } | |||
525 | ||||
526 | /// Overload to create or fold a zero result operation. | |||
527 | template <typename OpTy, typename... Args> | |||
528 | std::enable_if_t<OpTy::template hasTrait<OpTrait::ZeroResults>(), OpTy> | |||
529 | createOrFold(Location location, Args &&...args) { | |||
530 | auto op = create<OpTy>(location, std::forward<Args>(args)...); | |||
531 | SmallVector<Value, 0> unused; | |||
532 | (void)tryFold(op.getOperation(), unused); | |||
533 | ||||
534 | // Folding cannot remove a zero-result operation, so for convenience we | |||
535 | // continue to return it. | |||
536 | return op; | |||
537 | } | |||
538 | ||||
539 | /// Attempts to fold the given operation and places new results within | |||
540 | /// 'results'. Returns success if the operation was folded, failure otherwise. | |||
541 | /// Note: This function does not erase the operation on a successful fold. | |||
542 | LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results); | |||
543 | ||||
544 | /// Creates a deep copy of the specified operation, remapping any operands | |||
545 | /// that use values outside of the operation using the map that is provided | |||
546 | /// ( leaving them alone if no entry is present). Replaces references to | |||
547 | /// cloned sub-operations to the corresponding operation that is copied, | |||
548 | /// and adds those mappings to the map. | |||
549 | Operation *clone(Operation &op, IRMapping &mapper); | |||
550 | Operation *clone(Operation &op); | |||
551 | ||||
552 | /// Creates a deep copy of this operation but keep the operation regions | |||
553 | /// empty. Operands are remapped using `mapper` (if present), and `mapper` is | |||
554 | /// updated to contain the results. | |||
555 | Operation *cloneWithoutRegions(Operation &op, IRMapping &mapper) { | |||
556 | return insert(op.cloneWithoutRegions(mapper)); | |||
557 | } | |||
558 | Operation *cloneWithoutRegions(Operation &op) { | |||
559 | return insert(op.cloneWithoutRegions()); | |||
560 | } | |||
561 | template <typename OpT> | |||
562 | OpT cloneWithoutRegions(OpT op) { | |||
563 | return cast<OpT>(cloneWithoutRegions(*op.getOperation())); | |||
564 | } | |||
565 | ||||
566 | protected: | |||
567 | /// The optional listener for events of this builder. | |||
568 | Listener *listener; | |||
569 | ||||
570 | private: | |||
571 | /// The current block this builder is inserting into. | |||
572 | Block *block = nullptr; | |||
573 | /// The insertion point within the block that this builder is inserting | |||
574 | /// before. | |||
575 | Block::iterator insertPoint; | |||
576 | }; | |||
577 | ||||
578 | } // namespace mlir | |||
579 | ||||
580 | #endif |