File: | build/source/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp |
Warning: | line 349, column 10 2nd function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- SPIRVDialect.cpp - MLIR SPIR-V dialect -----------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file defines the SPIR-V dialect in MLIR. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | |||
14 | #include "mlir/Dialect/SPIRV/IR/ParserUtils.h" | |||
15 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" | |||
16 | #include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h" | |||
17 | #include "mlir/Dialect/SPIRV/IR/TargetAndABI.h" | |||
18 | #include "mlir/IR/Builders.h" | |||
19 | #include "mlir/IR/BuiltinTypes.h" | |||
20 | #include "mlir/IR/DialectImplementation.h" | |||
21 | #include "mlir/IR/MLIRContext.h" | |||
22 | #include "mlir/Parser/Parser.h" | |||
23 | #include "mlir/Transforms/InliningUtils.h" | |||
24 | #include "llvm/ADT/DenseMap.h" | |||
25 | #include "llvm/ADT/Sequence.h" | |||
26 | #include "llvm/ADT/SetVector.h" | |||
27 | #include "llvm/ADT/StringExtras.h" | |||
28 | #include "llvm/ADT/StringMap.h" | |||
29 | #include "llvm/ADT/TypeSwitch.h" | |||
30 | #include "llvm/Support/raw_ostream.h" | |||
31 | ||||
32 | using namespace mlir; | |||
33 | using namespace mlir::spirv; | |||
34 | ||||
35 | #include "mlir/Dialect/SPIRV/IR/SPIRVOpsDialect.cpp.inc" | |||
36 | ||||
37 | //===----------------------------------------------------------------------===// | |||
38 | // InlinerInterface | |||
39 | //===----------------------------------------------------------------------===// | |||
40 | ||||
41 | /// Returns true if the given region contains spirv.Return or spirv.ReturnValue | |||
42 | /// ops. | |||
43 | static inline bool containsReturn(Region ®ion) { | |||
44 | return llvm::any_of(region, [](Block &block) { | |||
45 | Operation *terminator = block.getTerminator(); | |||
46 | return isa<spirv::ReturnOp, spirv::ReturnValueOp>(terminator); | |||
47 | }); | |||
48 | } | |||
49 | ||||
50 | namespace { | |||
51 | /// This class defines the interface for inlining within the SPIR-V dialect. | |||
52 | struct SPIRVInlinerInterface : public DialectInlinerInterface { | |||
53 | using DialectInlinerInterface::DialectInlinerInterface; | |||
54 | ||||
55 | /// All call operations within SPIRV can be inlined. | |||
56 | bool isLegalToInline(Operation *call, Operation *callable, | |||
57 | bool wouldBeCloned) const final { | |||
58 | return true; | |||
59 | } | |||
60 | ||||
61 | /// Returns true if the given region 'src' can be inlined into the region | |||
62 | /// 'dest' that is attached to an operation registered to the current dialect. | |||
63 | bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, | |||
64 | BlockAndValueMapping &) const final { | |||
65 | // Return true here when inlining into spirv.func, spirv.mlir.selection, and | |||
66 | // spirv.mlir.loop operations. | |||
67 | auto *op = dest->getParentOp(); | |||
68 | return isa<spirv::FuncOp, spirv::SelectionOp, spirv::LoopOp>(op); | |||
69 | } | |||
70 | ||||
71 | /// Returns true if the given operation 'op', that is registered to this | |||
72 | /// dialect, can be inlined into the region 'dest' that is attached to an | |||
73 | /// operation registered to the current dialect. | |||
74 | bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned, | |||
75 | BlockAndValueMapping &) const final { | |||
76 | // TODO: Enable inlining structured control flows with return. | |||
77 | if ((isa<spirv::SelectionOp, spirv::LoopOp>(op)) && | |||
78 | containsReturn(op->getRegion(0))) | |||
79 | return false; | |||
80 | // TODO: we need to filter OpKill here to avoid inlining it to | |||
81 | // a loop continue construct: | |||
82 | // https://github.com/KhronosGroup/SPIRV-Headers/issues/86 | |||
83 | // However OpKill is fragment shader specific and we don't support it yet. | |||
84 | return true; | |||
85 | } | |||
86 | ||||
87 | /// Handle the given inlined terminator by replacing it with a new operation | |||
88 | /// as necessary. | |||
89 | void handleTerminator(Operation *op, Block *newDest) const final { | |||
90 | if (auto returnOp = dyn_cast<spirv::ReturnOp>(op)) { | |||
91 | OpBuilder(op).create<spirv::BranchOp>(op->getLoc(), newDest); | |||
92 | op->erase(); | |||
93 | } else if (auto retValOp = dyn_cast<spirv::ReturnValueOp>(op)) { | |||
94 | llvm_unreachable("unimplemented spirv.ReturnValue in inliner")::llvm::llvm_unreachable_internal("unimplemented spirv.ReturnValue in inliner" , "mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp", 94); | |||
95 | } | |||
96 | } | |||
97 | ||||
98 | /// Handle the given inlined terminator by replacing it with a new operation | |||
99 | /// as necessary. | |||
100 | void handleTerminator(Operation *op, | |||
101 | ArrayRef<Value> valuesToRepl) const final { | |||
102 | // Only spirv.ReturnValue needs to be handled here. | |||
103 | auto retValOp = dyn_cast<spirv::ReturnValueOp>(op); | |||
104 | if (!retValOp) | |||
105 | return; | |||
106 | ||||
107 | // Replace the values directly with the return operands. | |||
108 | assert(valuesToRepl.size() == 1 &&(static_cast <bool> (valuesToRepl.size() == 1 && "spirv.ReturnValue expected to only handle one result") ? void (0) : __assert_fail ("valuesToRepl.size() == 1 && \"spirv.ReturnValue expected to only handle one result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp", 109, __extension__ __PRETTY_FUNCTION__)) | |||
109 | "spirv.ReturnValue expected to only handle one result")(static_cast <bool> (valuesToRepl.size() == 1 && "spirv.ReturnValue expected to only handle one result") ? void (0) : __assert_fail ("valuesToRepl.size() == 1 && \"spirv.ReturnValue expected to only handle one result\"" , "mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp", 109, __extension__ __PRETTY_FUNCTION__)); | |||
110 | valuesToRepl.front().replaceAllUsesWith(retValOp.getValue()); | |||
111 | } | |||
112 | }; | |||
113 | } // namespace | |||
114 | ||||
115 | //===----------------------------------------------------------------------===// | |||
116 | // SPIR-V Dialect | |||
117 | //===----------------------------------------------------------------------===// | |||
118 | ||||
119 | void SPIRVDialect::initialize() { | |||
120 | registerAttributes(); | |||
121 | registerTypes(); | |||
122 | ||||
123 | // Add SPIR-V ops. | |||
124 | addOperations< | |||
125 | #define GET_OP_LIST | |||
126 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.cpp.inc" | |||
127 | >(); | |||
128 | ||||
129 | addInterfaces<SPIRVInlinerInterface>(); | |||
130 | ||||
131 | // Allow unknown operations because SPIR-V is extensible. | |||
132 | allowUnknownOperations(); | |||
133 | } | |||
134 | ||||
135 | std::string SPIRVDialect::getAttributeName(Decoration decoration) { | |||
136 | return llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)); | |||
137 | } | |||
138 | ||||
139 | //===----------------------------------------------------------------------===// | |||
140 | // Type Parsing | |||
141 | //===----------------------------------------------------------------------===// | |||
142 | ||||
143 | // Forward declarations. | |||
144 | template <typename ValTy> | |||
145 | static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, | |||
146 | DialectAsmParser &parser); | |||
147 | template <> | |||
148 | std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, | |||
149 | DialectAsmParser &parser); | |||
150 | ||||
151 | template <> | |||
152 | std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, | |||
153 | DialectAsmParser &parser); | |||
154 | ||||
155 | static Type parseAndVerifyType(SPIRVDialect const &dialect, | |||
156 | DialectAsmParser &parser) { | |||
157 | Type type; | |||
158 | SMLoc typeLoc = parser.getCurrentLocation(); | |||
159 | if (parser.parseType(type)) | |||
160 | return Type(); | |||
161 | ||||
162 | // Allow SPIR-V dialect types | |||
163 | if (&type.getDialect() == &dialect) | |||
164 | return type; | |||
165 | ||||
166 | // Check other allowed types | |||
167 | if (auto t = type.dyn_cast<FloatType>()) { | |||
168 | if (type.isBF16()) { | |||
169 | parser.emitError(typeLoc, "cannot use 'bf16' to compose SPIR-V types"); | |||
170 | return Type(); | |||
171 | } | |||
172 | } else if (auto t = type.dyn_cast<IntegerType>()) { | |||
173 | if (!ScalarType::isValid(t)) { | |||
174 | parser.emitError(typeLoc, | |||
175 | "only 1/8/16/32/64-bit integer type allowed but found ") | |||
176 | << type; | |||
177 | return Type(); | |||
178 | } | |||
179 | } else if (auto t = type.dyn_cast<VectorType>()) { | |||
180 | if (t.getRank() != 1) { | |||
181 | parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; | |||
182 | return Type(); | |||
183 | } | |||
184 | if (t.getNumElements() > 4) { | |||
185 | parser.emitError( | |||
186 | typeLoc, "vector length has to be less than or equal to 4 but found ") | |||
187 | << t.getNumElements(); | |||
188 | return Type(); | |||
189 | } | |||
190 | } else { | |||
191 | parser.emitError(typeLoc, "cannot use ") | |||
192 | << type << " to compose SPIR-V types"; | |||
193 | return Type(); | |||
194 | } | |||
195 | ||||
196 | return type; | |||
197 | } | |||
198 | ||||
199 | static Type parseAndVerifyMatrixType(SPIRVDialect const &dialect, | |||
200 | DialectAsmParser &parser) { | |||
201 | Type type; | |||
202 | SMLoc typeLoc = parser.getCurrentLocation(); | |||
203 | if (parser.parseType(type)) | |||
204 | return Type(); | |||
205 | ||||
206 | if (auto t = type.dyn_cast<VectorType>()) { | |||
207 | if (t.getRank() != 1) { | |||
208 | parser.emitError(typeLoc, "only 1-D vector allowed but found ") << t; | |||
209 | return Type(); | |||
210 | } | |||
211 | if (t.getNumElements() > 4 || t.getNumElements() < 2) { | |||
212 | parser.emitError(typeLoc, | |||
213 | "matrix columns size has to be less than or equal " | |||
214 | "to 4 and greater than or equal 2, but found ") | |||
215 | << t.getNumElements(); | |||
216 | return Type(); | |||
217 | } | |||
218 | ||||
219 | if (!t.getElementType().isa<FloatType>()) { | |||
220 | parser.emitError(typeLoc, "matrix columns' elements must be of " | |||
221 | "Float type, got ") | |||
222 | << t.getElementType(); | |||
223 | return Type(); | |||
224 | } | |||
225 | } else { | |||
226 | parser.emitError(typeLoc, "matrix must be composed using vector " | |||
227 | "type, got ") | |||
228 | << type; | |||
229 | return Type(); | |||
230 | } | |||
231 | ||||
232 | return type; | |||
233 | } | |||
234 | ||||
235 | static Type parseAndVerifySampledImageType(SPIRVDialect const &dialect, | |||
236 | DialectAsmParser &parser) { | |||
237 | Type type; | |||
238 | SMLoc typeLoc = parser.getCurrentLocation(); | |||
239 | if (parser.parseType(type)) | |||
240 | return Type(); | |||
241 | ||||
242 | if (!type.isa<ImageType>()) { | |||
243 | parser.emitError(typeLoc, | |||
244 | "sampled image must be composed using image type, got ") | |||
245 | << type; | |||
246 | return Type(); | |||
247 | } | |||
248 | ||||
249 | return type; | |||
250 | } | |||
251 | ||||
252 | /// Parses an optional `, stride = N` assembly segment. If no parsing failure | |||
253 | /// occurs, writes `N` to `stride` if existing and writes 0 to `stride` if | |||
254 | /// missing. | |||
255 | static LogicalResult parseOptionalArrayStride(const SPIRVDialect &dialect, | |||
256 | DialectAsmParser &parser, | |||
257 | unsigned &stride) { | |||
258 | if (failed(parser.parseOptionalComma())) { | |||
259 | stride = 0; | |||
260 | return success(); | |||
261 | } | |||
262 | ||||
263 | if (parser.parseKeyword("stride") || parser.parseEqual()) | |||
264 | return failure(); | |||
265 | ||||
266 | SMLoc strideLoc = parser.getCurrentLocation(); | |||
267 | std::optional<unsigned> optStride = parseAndVerify<unsigned>(dialect, parser); | |||
268 | if (!optStride) | |||
269 | return failure(); | |||
270 | ||||
271 | if (!(stride = *optStride)) { | |||
272 | parser.emitError(strideLoc, "ArrayStride must be greater than zero"); | |||
273 | return failure(); | |||
274 | } | |||
275 | return success(); | |||
276 | } | |||
277 | ||||
278 | // element-type ::= integer-type | |||
279 | // | floating-point-type | |||
280 | // | vector-type | |||
281 | // | spirv-type | |||
282 | // | |||
283 | // array-type ::= `!spirv.array` `<` integer-literal `x` element-type | |||
284 | // (`,` `stride` `=` integer-literal)? `>` | |||
285 | static Type parseArrayType(SPIRVDialect const &dialect, | |||
286 | DialectAsmParser &parser) { | |||
287 | if (parser.parseLess()) | |||
288 | return Type(); | |||
289 | ||||
290 | SmallVector<int64_t, 1> countDims; | |||
291 | SMLoc countLoc = parser.getCurrentLocation(); | |||
292 | if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) | |||
293 | return Type(); | |||
294 | if (countDims.size() != 1) { | |||
295 | parser.emitError(countLoc, | |||
296 | "expected single integer for array element count"); | |||
297 | return Type(); | |||
298 | } | |||
299 | ||||
300 | // According to the SPIR-V spec: | |||
301 | // "Length is the number of elements in the array. It must be at least 1." | |||
302 | int64_t count = countDims[0]; | |||
303 | if (count == 0) { | |||
304 | parser.emitError(countLoc, "expected array length greater than 0"); | |||
305 | return Type(); | |||
306 | } | |||
307 | ||||
308 | Type elementType = parseAndVerifyType(dialect, parser); | |||
309 | if (!elementType) | |||
310 | return Type(); | |||
311 | ||||
312 | unsigned stride = 0; | |||
313 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) | |||
314 | return Type(); | |||
315 | ||||
316 | if (parser.parseGreater()) | |||
317 | return Type(); | |||
318 | return ArrayType::get(elementType, count, stride); | |||
319 | } | |||
320 | ||||
321 | // cooperative-matrix-type ::= `!spirv.coopmatrix` `<` element-type ',' scope | |||
322 | // ',' | |||
323 | // rows ',' columns>` | |||
324 | static Type parseCooperativeMatrixType(SPIRVDialect const &dialect, | |||
325 | DialectAsmParser &parser) { | |||
326 | if (parser.parseLess()) | |||
327 | return Type(); | |||
328 | ||||
329 | SmallVector<int64_t, 2> dims; | |||
330 | SMLoc countLoc = parser.getCurrentLocation(); | |||
331 | if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) | |||
332 | return Type(); | |||
333 | ||||
334 | if (dims.size() != 2) { | |||
335 | parser.emitError(countLoc, "expected rows and columns size"); | |||
336 | return Type(); | |||
337 | } | |||
338 | ||||
339 | auto elementTy = parseAndVerifyType(dialect, parser); | |||
340 | if (!elementTy) | |||
341 | return Type(); | |||
342 | ||||
343 | Scope scope; | |||
344 | if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>")) | |||
345 | return Type(); | |||
346 | ||||
347 | if (parser.parseGreater()) | |||
348 | return Type(); | |||
349 | return CooperativeMatrixNVType::get(elementTy, scope, dims[0], dims[1]); | |||
| ||||
350 | } | |||
351 | ||||
352 | // joint-matrix-type ::= `!spirv.jointmatrix` `<`rows `x` columns `x` | |||
353 | // element-type | |||
354 | // `,` layout `,` scope`>` | |||
355 | static Type parseJointMatrixType(SPIRVDialect const &dialect, | |||
356 | DialectAsmParser &parser) { | |||
357 | if (parser.parseLess()) | |||
358 | return Type(); | |||
359 | ||||
360 | SmallVector<int64_t, 2> dims; | |||
361 | SMLoc countLoc = parser.getCurrentLocation(); | |||
362 | if (parser.parseDimensionList(dims, /*allowDynamic=*/false)) | |||
363 | return Type(); | |||
364 | ||||
365 | if (dims.size() != 2) { | |||
366 | parser.emitError(countLoc, "expected rows and columns size"); | |||
367 | return Type(); | |||
368 | } | |||
369 | ||||
370 | auto elementTy = parseAndVerifyType(dialect, parser); | |||
371 | if (!elementTy) | |||
372 | return Type(); | |||
373 | MatrixLayout matrixLayout; | |||
374 | if (parser.parseComma() || | |||
375 | parseEnumKeywordAttr(matrixLayout, parser, "matrixLayout <id>")) | |||
376 | return Type(); | |||
377 | Scope scope; | |||
378 | if (parser.parseComma() || parseEnumKeywordAttr(scope, parser, "scope <id>")) | |||
379 | return Type(); | |||
380 | if (parser.parseGreater()) | |||
381 | return Type(); | |||
382 | return JointMatrixINTELType::get(elementTy, scope, dims[0], dims[1], | |||
383 | matrixLayout); | |||
384 | } | |||
385 | ||||
386 | // TODO: Reorder methods to be utilities first and parse*Type | |||
387 | // methods in alphabetical order | |||
388 | // | |||
389 | // storage-class ::= `UniformConstant` | |||
390 | // | `Uniform` | |||
391 | // | `Workgroup` | |||
392 | // | <and other storage classes...> | |||
393 | // | |||
394 | // pointer-type ::= `!spirv.ptr<` element-type `,` storage-class `>` | |||
395 | static Type parsePointerType(SPIRVDialect const &dialect, | |||
396 | DialectAsmParser &parser) { | |||
397 | if (parser.parseLess()) | |||
398 | return Type(); | |||
399 | ||||
400 | auto pointeeType = parseAndVerifyType(dialect, parser); | |||
401 | if (!pointeeType) | |||
402 | return Type(); | |||
403 | ||||
404 | StringRef storageClassSpec; | |||
405 | SMLoc storageClassLoc = parser.getCurrentLocation(); | |||
406 | if (parser.parseComma() || parser.parseKeyword(&storageClassSpec)) | |||
407 | return Type(); | |||
408 | ||||
409 | auto storageClass = symbolizeStorageClass(storageClassSpec); | |||
410 | if (!storageClass) { | |||
411 | parser.emitError(storageClassLoc, "unknown storage class: ") | |||
412 | << storageClassSpec; | |||
413 | return Type(); | |||
414 | } | |||
415 | if (parser.parseGreater()) | |||
416 | return Type(); | |||
417 | return PointerType::get(pointeeType, *storageClass); | |||
418 | } | |||
419 | ||||
420 | // runtime-array-type ::= `!spirv.rtarray` `<` element-type | |||
421 | // (`,` `stride` `=` integer-literal)? `>` | |||
422 | static Type parseRuntimeArrayType(SPIRVDialect const &dialect, | |||
423 | DialectAsmParser &parser) { | |||
424 | if (parser.parseLess()) | |||
425 | return Type(); | |||
426 | ||||
427 | Type elementType = parseAndVerifyType(dialect, parser); | |||
428 | if (!elementType) | |||
429 | return Type(); | |||
430 | ||||
431 | unsigned stride = 0; | |||
432 | if (failed(parseOptionalArrayStride(dialect, parser, stride))) | |||
433 | return Type(); | |||
434 | ||||
435 | if (parser.parseGreater()) | |||
436 | return Type(); | |||
437 | return RuntimeArrayType::get(elementType, stride); | |||
438 | } | |||
439 | ||||
440 | // matrix-type ::= `!spirv.matrix` `<` integer-literal `x` element-type `>` | |||
441 | static Type parseMatrixType(SPIRVDialect const &dialect, | |||
442 | DialectAsmParser &parser) { | |||
443 | if (parser.parseLess()) | |||
444 | return Type(); | |||
445 | ||||
446 | SmallVector<int64_t, 1> countDims; | |||
447 | SMLoc countLoc = parser.getCurrentLocation(); | |||
448 | if (parser.parseDimensionList(countDims, /*allowDynamic=*/false)) | |||
449 | return Type(); | |||
450 | if (countDims.size() != 1) { | |||
451 | parser.emitError(countLoc, "expected single unsigned " | |||
452 | "integer for number of columns"); | |||
453 | return Type(); | |||
454 | } | |||
455 | ||||
456 | int64_t columnCount = countDims[0]; | |||
457 | // According to the specification, Matrices can have 2, 3, or 4 columns | |||
458 | if (columnCount < 2 || columnCount > 4) { | |||
459 | parser.emitError(countLoc, "matrix is expected to have 2, 3, or 4 " | |||
460 | "columns"); | |||
461 | return Type(); | |||
462 | } | |||
463 | ||||
464 | Type columnType = parseAndVerifyMatrixType(dialect, parser); | |||
465 | if (!columnType) | |||
466 | return Type(); | |||
467 | ||||
468 | if (parser.parseGreater()) | |||
469 | return Type(); | |||
470 | ||||
471 | return MatrixType::get(columnType, columnCount); | |||
472 | } | |||
473 | ||||
474 | // Specialize this function to parse each of the parameters that define an | |||
475 | // ImageType. By default it assumes this is an enum type. | |||
476 | template <typename ValTy> | |||
477 | static std::optional<ValTy> parseAndVerify(SPIRVDialect const &dialect, | |||
478 | DialectAsmParser &parser) { | |||
479 | StringRef enumSpec; | |||
480 | SMLoc enumLoc = parser.getCurrentLocation(); | |||
481 | if (parser.parseKeyword(&enumSpec)) { | |||
482 | return std::nullopt; | |||
483 | } | |||
484 | ||||
485 | auto val = spirv::symbolizeEnum<ValTy>(enumSpec); | |||
486 | if (!val) | |||
487 | parser.emitError(enumLoc, "unknown attribute: '") << enumSpec << "'"; | |||
488 | return val; | |||
489 | } | |||
490 | ||||
491 | template <> | |||
492 | std::optional<Type> parseAndVerify<Type>(SPIRVDialect const &dialect, | |||
493 | DialectAsmParser &parser) { | |||
494 | // TODO: Further verify that the element type can be sampled | |||
495 | auto ty = parseAndVerifyType(dialect, parser); | |||
496 | if (!ty) | |||
497 | return std::nullopt; | |||
498 | return ty; | |||
499 | } | |||
500 | ||||
501 | template <typename IntTy> | |||
502 | static std::optional<IntTy> parseAndVerifyInteger(SPIRVDialect const &dialect, | |||
503 | DialectAsmParser &parser) { | |||
504 | IntTy offsetVal = std::numeric_limits<IntTy>::max(); | |||
505 | if (parser.parseInteger(offsetVal)) | |||
506 | return std::nullopt; | |||
507 | return offsetVal; | |||
508 | } | |||
509 | ||||
510 | template <> | |||
511 | std::optional<unsigned> parseAndVerify<unsigned>(SPIRVDialect const &dialect, | |||
512 | DialectAsmParser &parser) { | |||
513 | return parseAndVerifyInteger<unsigned>(dialect, parser); | |||
514 | } | |||
515 | ||||
516 | namespace { | |||
517 | // Functor object to parse a comma separated list of specs. The function | |||
518 | // parseAndVerify does the actual parsing and verification of individual | |||
519 | // elements. This is a functor since parsing the last element of the list | |||
520 | // (termination condition) needs partial specialization. | |||
521 | template <typename ParseType, typename... Args> | |||
522 | struct ParseCommaSeparatedList { | |||
523 | std::optional<std::tuple<ParseType, Args...>> | |||
524 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { | |||
525 | auto parseVal = parseAndVerify<ParseType>(dialect, parser); | |||
526 | if (!parseVal) | |||
527 | return std::nullopt; | |||
528 | ||||
529 | auto numArgs = std::tuple_size<std::tuple<Args...>>::value; | |||
530 | if (numArgs != 0 && failed(parser.parseComma())) | |||
531 | return std::nullopt; | |||
532 | auto remainingValues = ParseCommaSeparatedList<Args...>{}(dialect, parser); | |||
533 | if (!remainingValues) | |||
534 | return std::nullopt; | |||
535 | return std::tuple_cat(std::tuple<ParseType>(parseVal.value()), | |||
536 | remainingValues.value()); | |||
537 | } | |||
538 | }; | |||
539 | ||||
540 | // Partial specialization of the function to parse a comma separated list of | |||
541 | // specs to parse the last element of the list. | |||
542 | template <typename ParseType> | |||
543 | struct ParseCommaSeparatedList<ParseType> { | |||
544 | std::optional<std::tuple<ParseType>> | |||
545 | operator()(SPIRVDialect const &dialect, DialectAsmParser &parser) const { | |||
546 | if (auto value = parseAndVerify<ParseType>(dialect, parser)) | |||
547 | return std::tuple<ParseType>(*value); | |||
548 | return std::nullopt; | |||
549 | } | |||
550 | }; | |||
551 | } // namespace | |||
552 | ||||
553 | // dim ::= `1D` | `2D` | `3D` | `Cube` | <and other SPIR-V Dim specifiers...> | |||
554 | // | |||
555 | // depth-info ::= `NoDepth` | `IsDepth` | `DepthUnknown` | |||
556 | // | |||
557 | // arrayed-info ::= `NonArrayed` | `Arrayed` | |||
558 | // | |||
559 | // sampling-info ::= `SingleSampled` | `MultiSampled` | |||
560 | // | |||
561 | // sampler-use-info ::= `SamplerUnknown` | `NeedSampler` | `NoSampler` | |||
562 | // | |||
563 | // format ::= `Unknown` | `Rgba32f` | <and other SPIR-V Image formats...> | |||
564 | // | |||
565 | // image-type ::= `!spirv.image<` element-type `,` dim `,` depth-info `,` | |||
566 | // arrayed-info `,` sampling-info `,` | |||
567 | // sampler-use-info `,` format `>` | |||
568 | static Type parseImageType(SPIRVDialect const &dialect, | |||
569 | DialectAsmParser &parser) { | |||
570 | if (parser.parseLess()) | |||
571 | return Type(); | |||
572 | ||||
573 | auto value = | |||
574 | ParseCommaSeparatedList<Type, Dim, ImageDepthInfo, ImageArrayedInfo, | |||
575 | ImageSamplingInfo, ImageSamplerUseInfo, | |||
576 | ImageFormat>{}(dialect, parser); | |||
577 | if (!value) | |||
578 | return Type(); | |||
579 | ||||
580 | if (parser.parseGreater()) | |||
581 | return Type(); | |||
582 | return ImageType::get(*value); | |||
583 | } | |||
584 | ||||
585 | // sampledImage-type :: = `!spirv.sampledImage<` image-type `>` | |||
586 | static Type parseSampledImageType(SPIRVDialect const &dialect, | |||
587 | DialectAsmParser &parser) { | |||
588 | if (parser.parseLess()) | |||
589 | return Type(); | |||
590 | ||||
591 | Type parsedType = parseAndVerifySampledImageType(dialect, parser); | |||
592 | if (!parsedType) | |||
593 | return Type(); | |||
594 | ||||
595 | if (parser.parseGreater()) | |||
596 | return Type(); | |||
597 | return SampledImageType::get(parsedType); | |||
598 | } | |||
599 | ||||
600 | // Parse decorations associated with a member. | |||
601 | static ParseResult parseStructMemberDecorations( | |||
602 | SPIRVDialect const &dialect, DialectAsmParser &parser, | |||
603 | ArrayRef<Type> memberTypes, | |||
604 | SmallVectorImpl<StructType::OffsetInfo> &offsetInfo, | |||
605 | SmallVectorImpl<StructType::MemberDecorationInfo> &memberDecorationInfo) { | |||
606 | ||||
607 | // Check if the first element is offset. | |||
608 | SMLoc offsetLoc = parser.getCurrentLocation(); | |||
609 | StructType::OffsetInfo offset = 0; | |||
610 | OptionalParseResult offsetParseResult = parser.parseOptionalInteger(offset); | |||
611 | if (offsetParseResult.has_value()) { | |||
612 | if (failed(*offsetParseResult)) | |||
613 | return failure(); | |||
614 | ||||
615 | if (offsetInfo.size() != memberTypes.size() - 1) { | |||
616 | return parser.emitError(offsetLoc, | |||
617 | "offset specification must be given for " | |||
618 | "all members"); | |||
619 | } | |||
620 | offsetInfo.push_back(offset); | |||
621 | } | |||
622 | ||||
623 | // Check for no spirv::Decorations. | |||
624 | if (succeeded(parser.parseOptionalRSquare())) | |||
625 | return success(); | |||
626 | ||||
627 | // If there was an offset, make sure to parse the comma. | |||
628 | if (offsetParseResult.has_value() && parser.parseComma()) | |||
629 | return failure(); | |||
630 | ||||
631 | // Check for spirv::Decorations. | |||
632 | auto parseDecorations = [&]() { | |||
633 | auto memberDecoration = parseAndVerify<spirv::Decoration>(dialect, parser); | |||
634 | if (!memberDecoration) | |||
635 | return failure(); | |||
636 | ||||
637 | // Parse member decoration value if it exists. | |||
638 | if (succeeded(parser.parseOptionalEqual())) { | |||
639 | auto memberDecorationValue = | |||
640 | parseAndVerifyInteger<uint32_t>(dialect, parser); | |||
641 | ||||
642 | if (!memberDecorationValue) | |||
643 | return failure(); | |||
644 | ||||
645 | memberDecorationInfo.emplace_back( | |||
646 | static_cast<uint32_t>(memberTypes.size() - 1), 1, | |||
647 | memberDecoration.value(), memberDecorationValue.value()); | |||
648 | } else { | |||
649 | memberDecorationInfo.emplace_back( | |||
650 | static_cast<uint32_t>(memberTypes.size() - 1), 0, | |||
651 | memberDecoration.value(), 0); | |||
652 | } | |||
653 | return success(); | |||
654 | }; | |||
655 | if (failed(parser.parseCommaSeparatedList(parseDecorations)) || | |||
656 | failed(parser.parseRSquare())) | |||
657 | return failure(); | |||
658 | ||||
659 | return success(); | |||
660 | } | |||
661 | ||||
662 | // struct-member-decoration ::= integer-literal? spirv-decoration* | |||
663 | // struct-type ::= | |||
664 | // `!spirv.struct<` (id `,`)? | |||
665 | // `(` | |||
666 | // (spirv-type (`[` struct-member-decoration `]`)?)* | |||
667 | // `)>` | |||
668 | static Type parseStructType(SPIRVDialect const &dialect, | |||
669 | DialectAsmParser &parser) { | |||
670 | // TODO: This function is quite lengthy. Break it down into smaller chunks. | |||
671 | ||||
672 | // To properly resolve recursive references while parsing recursive struct | |||
673 | // types, we need to maintain a list of enclosing struct type names. This set | |||
674 | // maintains the names of struct types in which the type we are about to parse | |||
675 | // is nested. | |||
676 | // | |||
677 | // Note: This has to be thread_local to enable multiple threads to safely | |||
678 | // parse concurrently. | |||
679 | thread_local SetVector<StringRef> structContext; | |||
680 | ||||
681 | static auto removeIdentifierAndFail = [](SetVector<StringRef> &structContext, | |||
682 | StringRef identifier) { | |||
683 | if (!identifier.empty()) | |||
684 | structContext.remove(identifier); | |||
685 | ||||
686 | return Type(); | |||
687 | }; | |||
688 | ||||
689 | if (parser.parseLess()) | |||
690 | return Type(); | |||
691 | ||||
692 | StringRef identifier; | |||
693 | ||||
694 | // Check if this is an identified struct type. | |||
695 | if (succeeded(parser.parseOptionalKeyword(&identifier))) { | |||
696 | // Check if this is a possible recursive reference. | |||
697 | if (succeeded(parser.parseOptionalGreater())) { | |||
698 | if (structContext.count(identifier) == 0) { | |||
699 | parser.emitError( | |||
700 | parser.getNameLoc(), | |||
701 | "recursive struct reference not nested in struct definition"); | |||
702 | ||||
703 | return Type(); | |||
704 | } | |||
705 | ||||
706 | return StructType::getIdentified(dialect.getContext(), identifier); | |||
707 | } | |||
708 | ||||
709 | if (failed(parser.parseComma())) | |||
710 | return Type(); | |||
711 | ||||
712 | if (structContext.count(identifier) != 0) { | |||
713 | parser.emitError(parser.getNameLoc(), | |||
714 | "identifier already used for an enclosing struct"); | |||
715 | ||||
716 | return removeIdentifierAndFail(structContext, identifier); | |||
717 | } | |||
718 | ||||
719 | structContext.insert(identifier); | |||
720 | } | |||
721 | ||||
722 | if (failed(parser.parseLParen())) | |||
723 | return removeIdentifierAndFail(structContext, identifier); | |||
724 | ||||
725 | if (succeeded(parser.parseOptionalRParen()) && | |||
726 | succeeded(parser.parseOptionalGreater())) { | |||
727 | if (!identifier.empty()) | |||
728 | structContext.remove(identifier); | |||
729 | ||||
730 | return StructType::getEmpty(dialect.getContext(), identifier); | |||
731 | } | |||
732 | ||||
733 | StructType idStructTy; | |||
734 | ||||
735 | if (!identifier.empty()) | |||
736 | idStructTy = StructType::getIdentified(dialect.getContext(), identifier); | |||
737 | ||||
738 | SmallVector<Type, 4> memberTypes; | |||
739 | SmallVector<StructType::OffsetInfo, 4> offsetInfo; | |||
740 | SmallVector<StructType::MemberDecorationInfo, 4> memberDecorationInfo; | |||
741 | ||||
742 | do { | |||
743 | Type memberType; | |||
744 | if (parser.parseType(memberType)) | |||
745 | return removeIdentifierAndFail(structContext, identifier); | |||
746 | memberTypes.push_back(memberType); | |||
747 | ||||
748 | if (succeeded(parser.parseOptionalLSquare())) | |||
749 | if (parseStructMemberDecorations(dialect, parser, memberTypes, offsetInfo, | |||
750 | memberDecorationInfo)) | |||
751 | return removeIdentifierAndFail(structContext, identifier); | |||
752 | } while (succeeded(parser.parseOptionalComma())); | |||
753 | ||||
754 | if (!offsetInfo.empty() && memberTypes.size() != offsetInfo.size()) { | |||
755 | parser.emitError(parser.getNameLoc(), | |||
756 | "offset specification must be given for all members"); | |||
757 | return removeIdentifierAndFail(structContext, identifier); | |||
758 | } | |||
759 | ||||
760 | if (failed(parser.parseRParen()) || failed(parser.parseGreater())) | |||
761 | return removeIdentifierAndFail(structContext, identifier); | |||
762 | ||||
763 | if (!identifier.empty()) { | |||
764 | if (failed(idStructTy.trySetBody(memberTypes, offsetInfo, | |||
765 | memberDecorationInfo))) | |||
766 | return Type(); | |||
767 | ||||
768 | structContext.remove(identifier); | |||
769 | return idStructTy; | |||
770 | } | |||
771 | ||||
772 | return StructType::get(memberTypes, offsetInfo, memberDecorationInfo); | |||
773 | } | |||
774 | ||||
775 | // spirv-type ::= array-type | |||
776 | // | element-type | |||
777 | // | image-type | |||
778 | // | pointer-type | |||
779 | // | runtime-array-type | |||
780 | // | sampled-image-type | |||
781 | // | struct-type | |||
782 | Type SPIRVDialect::parseType(DialectAsmParser &parser) const { | |||
783 | StringRef keyword; | |||
784 | if (parser.parseKeyword(&keyword)) | |||
| ||||
785 | return Type(); | |||
786 | ||||
787 | if (keyword == "array") | |||
788 | return parseArrayType(*this, parser); | |||
789 | if (keyword == "coopmatrix") | |||
790 | return parseCooperativeMatrixType(*this, parser); | |||
791 | if (keyword == "jointmatrix") | |||
792 | return parseJointMatrixType(*this, parser); | |||
793 | if (keyword == "image") | |||
794 | return parseImageType(*this, parser); | |||
795 | if (keyword == "ptr") | |||
796 | return parsePointerType(*this, parser); | |||
797 | if (keyword == "rtarray") | |||
798 | return parseRuntimeArrayType(*this, parser); | |||
799 | if (keyword == "sampled_image") | |||
800 | return parseSampledImageType(*this, parser); | |||
801 | if (keyword == "struct") | |||
802 | return parseStructType(*this, parser); | |||
803 | if (keyword == "matrix") | |||
804 | return parseMatrixType(*this, parser); | |||
805 | parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword; | |||
806 | return Type(); | |||
807 | } | |||
808 | ||||
809 | //===----------------------------------------------------------------------===// | |||
810 | // Type Printing | |||
811 | //===----------------------------------------------------------------------===// | |||
812 | ||||
813 | static void print(ArrayType type, DialectAsmPrinter &os) { | |||
814 | os << "array<" << type.getNumElements() << " x " << type.getElementType(); | |||
815 | if (unsigned stride = type.getArrayStride()) | |||
816 | os << ", stride=" << stride; | |||
817 | os << ">"; | |||
818 | } | |||
819 | ||||
820 | static void print(RuntimeArrayType type, DialectAsmPrinter &os) { | |||
821 | os << "rtarray<" << type.getElementType(); | |||
822 | if (unsigned stride = type.getArrayStride()) | |||
823 | os << ", stride=" << stride; | |||
824 | os << ">"; | |||
825 | } | |||
826 | ||||
827 | static void print(PointerType type, DialectAsmPrinter &os) { | |||
828 | os << "ptr<" << type.getPointeeType() << ", " | |||
829 | << stringifyStorageClass(type.getStorageClass()) << ">"; | |||
830 | } | |||
831 | ||||
832 | static void print(ImageType type, DialectAsmPrinter &os) { | |||
833 | os << "image<" << type.getElementType() << ", " << stringifyDim(type.getDim()) | |||
834 | << ", " << stringifyImageDepthInfo(type.getDepthInfo()) << ", " | |||
835 | << stringifyImageArrayedInfo(type.getArrayedInfo()) << ", " | |||
836 | << stringifyImageSamplingInfo(type.getSamplingInfo()) << ", " | |||
837 | << stringifyImageSamplerUseInfo(type.getSamplerUseInfo()) << ", " | |||
838 | << stringifyImageFormat(type.getImageFormat()) << ">"; | |||
839 | } | |||
840 | ||||
841 | static void print(SampledImageType type, DialectAsmPrinter &os) { | |||
842 | os << "sampled_image<" << type.getImageType() << ">"; | |||
843 | } | |||
844 | ||||
845 | static void print(StructType type, DialectAsmPrinter &os) { | |||
846 | thread_local SetVector<StringRef> structContext; | |||
847 | ||||
848 | os << "struct<"; | |||
849 | ||||
850 | if (type.isIdentified()) { | |||
851 | os << type.getIdentifier(); | |||
852 | ||||
853 | if (structContext.count(type.getIdentifier())) { | |||
854 | os << ">"; | |||
855 | return; | |||
856 | } | |||
857 | ||||
858 | os << ", "; | |||
859 | structContext.insert(type.getIdentifier()); | |||
860 | } | |||
861 | ||||
862 | os << "("; | |||
863 | ||||
864 | auto printMember = [&](unsigned i) { | |||
865 | os << type.getElementType(i); | |||
866 | SmallVector<spirv::StructType::MemberDecorationInfo, 0> decorations; | |||
867 | type.getMemberDecorations(i, decorations); | |||
868 | if (type.hasOffset() || !decorations.empty()) { | |||
869 | os << " ["; | |||
870 | if (type.hasOffset()) { | |||
871 | os << type.getMemberOffset(i); | |||
872 | if (!decorations.empty()) | |||
873 | os << ", "; | |||
874 | } | |||
875 | auto eachFn = [&os](spirv::StructType::MemberDecorationInfo decoration) { | |||
876 | os << stringifyDecoration(decoration.decoration); | |||
877 | if (decoration.hasValue) { | |||
878 | os << "=" << decoration.decorationValue; | |||
879 | } | |||
880 | }; | |||
881 | llvm::interleaveComma(decorations, os, eachFn); | |||
882 | os << "]"; | |||
883 | } | |||
884 | }; | |||
885 | llvm::interleaveComma(llvm::seq<unsigned>(0, type.getNumElements()), os, | |||
886 | printMember); | |||
887 | os << ")>"; | |||
888 | ||||
889 | if (type.isIdentified()) | |||
890 | structContext.remove(type.getIdentifier()); | |||
891 | } | |||
892 | ||||
893 | static void print(CooperativeMatrixNVType type, DialectAsmPrinter &os) { | |||
894 | os << "coopmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; | |||
895 | os << type.getElementType() << ", " << stringifyScope(type.getScope()); | |||
896 | os << ">"; | |||
897 | } | |||
898 | ||||
899 | static void print(JointMatrixINTELType type, DialectAsmPrinter &os) { | |||
900 | os << "jointmatrix<" << type.getRows() << "x" << type.getColumns() << "x"; | |||
901 | os << type.getElementType() << ", " | |||
902 | << stringifyMatrixLayout(type.getMatrixLayout()); | |||
903 | os << ", " << stringifyScope(type.getScope()) << ">"; | |||
904 | } | |||
905 | ||||
906 | static void print(MatrixType type, DialectAsmPrinter &os) { | |||
907 | os << "matrix<" << type.getNumColumns() << " x " << type.getColumnType(); | |||
908 | os << ">"; | |||
909 | } | |||
910 | ||||
911 | void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const { | |||
912 | TypeSwitch<Type>(type) | |||
913 | .Case<ArrayType, CooperativeMatrixNVType, JointMatrixINTELType, | |||
914 | PointerType, RuntimeArrayType, ImageType, SampledImageType, | |||
915 | StructType, MatrixType>([&](auto type) { print(type, os); }) | |||
916 | .Default([](Type) { llvm_unreachable("unhandled SPIR-V type")::llvm::llvm_unreachable_internal("unhandled SPIR-V type", "mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp" , 916); }); | |||
917 | } | |||
918 | ||||
919 | //===----------------------------------------------------------------------===// | |||
920 | // Constant | |||
921 | //===----------------------------------------------------------------------===// | |||
922 | ||||
923 | Operation *SPIRVDialect::materializeConstant(OpBuilder &builder, | |||
924 | Attribute value, Type type, | |||
925 | Location loc) { | |||
926 | if (!spirv::ConstantOp::isBuildableWith(type)) | |||
927 | return nullptr; | |||
928 | ||||
929 | return builder.create<spirv::ConstantOp>(loc, type, value); | |||
930 | } | |||
931 | ||||
932 | //===----------------------------------------------------------------------===// | |||
933 | // Shader Interface ABI | |||
934 | //===----------------------------------------------------------------------===// | |||
935 | ||||
936 | LogicalResult SPIRVDialect::verifyOperationAttribute(Operation *op, | |||
937 | NamedAttribute attribute) { | |||
938 | StringRef symbol = attribute.getName().strref(); | |||
939 | Attribute attr = attribute.getValue(); | |||
940 | ||||
941 | if (symbol == spirv::getEntryPointABIAttrName()) { | |||
942 | if (!attr.isa<spirv::EntryPointABIAttr>()) { | |||
943 | return op->emitError("'") | |||
944 | << symbol << "' attribute must be an entry point ABI attribute"; | |||
945 | } | |||
946 | } else if (symbol == spirv::getTargetEnvAttrName()) { | |||
947 | if (!attr.isa<spirv::TargetEnvAttr>()) | |||
948 | return op->emitError("'") << symbol << "' must be a spirv::TargetEnvAttr"; | |||
949 | } else { | |||
950 | return op->emitError("found unsupported '") | |||
951 | << symbol << "' attribute on operation"; | |||
952 | } | |||
953 | ||||
954 | return success(); | |||
955 | } | |||
956 | ||||
957 | /// Verifies the given SPIR-V `attribute` attached to a value of the given | |||
958 | /// `valueType` is valid. | |||
959 | static LogicalResult verifyRegionAttribute(Location loc, Type valueType, | |||
960 | NamedAttribute attribute) { | |||
961 | StringRef symbol = attribute.getName().strref(); | |||
962 | Attribute attr = attribute.getValue(); | |||
963 | ||||
964 | if (symbol != spirv::getInterfaceVarABIAttrName()) | |||
965 | return emitError(loc, "found unsupported '") | |||
966 | << symbol << "' attribute on region argument"; | |||
967 | ||||
968 | auto varABIAttr = attr.dyn_cast<spirv::InterfaceVarABIAttr>(); | |||
969 | if (!varABIAttr) | |||
970 | return emitError(loc, "'") | |||
971 | << symbol << "' must be a spirv::InterfaceVarABIAttr"; | |||
972 | ||||
973 | if (varABIAttr.getStorageClass() && !valueType.isIntOrIndexOrFloat()) | |||
974 | return emitError(loc, "'") << symbol | |||
975 | << "' attribute cannot specify storage class " | |||
976 | "when attaching to a non-scalar value"; | |||
977 | ||||
978 | return success(); | |||
979 | } | |||
980 | ||||
981 | LogicalResult SPIRVDialect::verifyRegionArgAttribute(Operation *op, | |||
982 | unsigned regionIndex, | |||
983 | unsigned argIndex, | |||
984 | NamedAttribute attribute) { | |||
985 | return verifyRegionAttribute( | |||
986 | op->getLoc(), op->getRegion(regionIndex).getArgument(argIndex).getType(), | |||
987 | attribute); | |||
988 | } | |||
989 | ||||
990 | LogicalResult SPIRVDialect::verifyRegionResultAttribute( | |||
991 | Operation *op, unsigned /*regionIndex*/, unsigned /*resultIndex*/, | |||
992 | NamedAttribute attribute) { | |||
993 | return op->emitError("cannot attach SPIR-V attributes to region result"); | |||
994 | } |
1 | //===------------ ParserUtils.h - Parse text to SPIR-V ops ----------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines utilities used for parsing types and ops for SPIR-V |
10 | // dialect. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ |
15 | #define MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ |
16 | |
17 | #include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "mlir/IR/OpImplementation.h" |
20 | |
21 | namespace mlir { |
22 | |
23 | /// Parses the next keyword in `parser` as an enumerant of the given |
24 | /// `EnumClass`. |
25 | template <typename EnumClass, typename ParserType> |
26 | static ParseResult |
27 | parseEnumKeywordAttr(EnumClass &value, ParserType &parser, |
28 | StringRef attrName = spirv::attributeName<EnumClass>()) { |
29 | StringRef keyword; |
30 | SmallVector<NamedAttribute, 1> attr; |
31 | auto loc = parser.getCurrentLocation(); |
32 | if (parser.parseKeyword(&keyword)) |
33 | return failure(); |
34 | if (std::optional<EnumClass> attr = |
35 | spirv::symbolizeEnum<EnumClass>(keyword)) { |
36 | value = *attr; |
37 | return success(); |
38 | } |
39 | return parser.emitError(loc, "invalid ") |
40 | << attrName << " attribute specification: " << keyword; |
41 | } |
42 | |
43 | } // namespace mlir |
44 | |
45 | #endif // MLIR_DIALECT_SPIRV_IR_PARSERUTILS_H_ |