File: | build/source/mlir/lib/Dialect/Quant/IR/TypeParser.cpp |
Warning: | line 87, column 22 The left operand of '>' is a garbage value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | |||||
9 | #include "mlir/Dialect/Quant/QuantOps.h" | ||||
10 | #include "mlir/Dialect/Quant/QuantTypes.h" | ||||
11 | #include "mlir/IR/BuiltinTypes.h" | ||||
12 | #include "mlir/IR/DialectImplementation.h" | ||||
13 | #include "mlir/IR/Location.h" | ||||
14 | #include "mlir/IR/Types.h" | ||||
15 | #include "llvm/ADT/APFloat.h" | ||||
16 | #include "llvm/Support/Format.h" | ||||
17 | #include "llvm/Support/MathExtras.h" | ||||
18 | #include "llvm/Support/SourceMgr.h" | ||||
19 | #include "llvm/Support/raw_ostream.h" | ||||
20 | |||||
21 | using namespace mlir; | ||||
22 | using namespace quant; | ||||
23 | |||||
24 | static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { | ||||
25 | auto typeLoc = parser.getCurrentLocation(); | ||||
26 | IntegerType type; | ||||
27 | |||||
28 | // Parse storage type (alpha_ident, integer_literal). | ||||
29 | StringRef identifier; | ||||
30 | unsigned storageTypeWidth = 0; | ||||
31 | OptionalParseResult result = parser.parseOptionalType(type); | ||||
32 | if (result.has_value()) { | ||||
33 | if (!succeeded(*result)) | ||||
34 | return nullptr; | ||||
35 | isSigned = !type.isUnsigned(); | ||||
36 | storageTypeWidth = type.getWidth(); | ||||
37 | } else if (succeeded(parser.parseKeyword(&identifier))) { | ||||
38 | // Otherwise, this must be an unsigned integer (`u` integer-literal). | ||||
39 | if (!identifier.consume_front("u")) { | ||||
40 | parser.emitError(typeLoc, "illegal storage type prefix"); | ||||
41 | return nullptr; | ||||
42 | } | ||||
43 | if (identifier.getAsInteger(10, storageTypeWidth)) { | ||||
44 | parser.emitError(typeLoc, "expected storage type width"); | ||||
45 | return nullptr; | ||||
46 | } | ||||
47 | isSigned = false; | ||||
48 | type = parser.getBuilder().getIntegerType(storageTypeWidth); | ||||
49 | } else { | ||||
50 | return nullptr; | ||||
51 | } | ||||
52 | |||||
53 | if (storageTypeWidth == 0 || | ||||
54 | storageTypeWidth > QuantizedType::MaxStorageBits) { | ||||
55 | parser.emitError(typeLoc, "illegal storage type size: ") | ||||
56 | << storageTypeWidth; | ||||
57 | return nullptr; | ||||
58 | } | ||||
59 | |||||
60 | return type; | ||||
61 | } | ||||
62 | |||||
63 | static ParseResult parseStorageRange(DialectAsmParser &parser, | ||||
64 | IntegerType storageType, bool isSigned, | ||||
65 | int64_t &storageTypeMin, | ||||
66 | int64_t &storageTypeMax) { | ||||
67 | int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( | ||||
68 | isSigned, storageType.getWidth()); | ||||
69 | int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( | ||||
70 | isSigned, storageType.getWidth()); | ||||
71 | if (failed(parser.parseOptionalLess())) { | ||||
72 | storageTypeMin = defaultIntegerMin; | ||||
73 | storageTypeMax = defaultIntegerMax; | ||||
74 | return success(); | ||||
75 | } | ||||
76 | |||||
77 | // Explicit storage min and storage max. | ||||
78 | SMLoc minLoc = parser.getCurrentLocation(), maxLoc; | ||||
79 | if (parser.parseInteger(storageTypeMin) || parser.parseColon() || | ||||
80 | parser.getCurrentLocation(&maxLoc) || | ||||
81 | parser.parseInteger(storageTypeMax) || parser.parseGreater()) | ||||
82 | return failure(); | ||||
83 | if (storageTypeMin < defaultIntegerMin) { | ||||
84 | return parser.emitError(minLoc, "illegal storage type minimum: ") | ||||
85 | << storageTypeMin; | ||||
86 | } | ||||
87 | if (storageTypeMax > defaultIntegerMax) { | ||||
| |||||
88 | return parser.emitError(maxLoc, "illegal storage type maximum: ") | ||||
89 | << storageTypeMax; | ||||
90 | } | ||||
91 | return success(); | ||||
92 | } | ||||
93 | |||||
94 | static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, | ||||
95 | double &min, double &max) { | ||||
96 | auto typeLoc = parser.getCurrentLocation(); | ||||
97 | FloatType type; | ||||
98 | |||||
99 | if (failed(parser.parseType(type))) { | ||||
100 | parser.emitError(typeLoc, "expecting float expressed type"); | ||||
101 | return nullptr; | ||||
102 | } | ||||
103 | |||||
104 | // Calibrated min and max values. | ||||
105 | if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() || | ||||
106 | parser.parseFloat(max) || parser.parseGreater()) { | ||||
107 | parser.emitError(typeLoc, "calibrated values must be present"); | ||||
108 | return nullptr; | ||||
109 | } | ||||
110 | return type; | ||||
111 | } | ||||
112 | |||||
113 | /// Parses an AnyQuantizedType. | ||||
114 | /// | ||||
115 | /// any ::= `any<` storage-spec (expressed-type-spec)?`>` | ||||
116 | /// storage-spec ::= storage-type (`<` storage-range `>`)? | ||||
117 | /// storage-range ::= integer-literal `:` integer-literal | ||||
118 | /// storage-type ::= (`i` | `u`) integer-literal | ||||
119 | /// expressed-type-spec ::= `:` `f` integer-literal | ||||
120 | static Type parseAnyType(DialectAsmParser &parser) { | ||||
121 | IntegerType storageType; | ||||
122 | FloatType expressedType; | ||||
123 | unsigned typeFlags = 0; | ||||
124 | int64_t storageTypeMin; | ||||
125 | int64_t storageTypeMax; | ||||
126 | |||||
127 | // Type specification. | ||||
128 | if (parser.parseLess()) | ||||
129 | return nullptr; | ||||
130 | |||||
131 | // Storage type. | ||||
132 | bool isSigned = false; | ||||
133 | storageType = parseStorageType(parser, isSigned); | ||||
134 | if (!storageType) { | ||||
135 | return nullptr; | ||||
136 | } | ||||
137 | if (isSigned) { | ||||
138 | typeFlags |= QuantizationFlags::Signed; | ||||
139 | } | ||||
140 | |||||
141 | // Storage type range. | ||||
142 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, | ||||
143 | storageTypeMax)) { | ||||
144 | return nullptr; | ||||
145 | } | ||||
146 | |||||
147 | // Optional expressed type. | ||||
148 | if (succeeded(parser.parseOptionalColon())) { | ||||
149 | if (parser.parseType(expressedType)) { | ||||
150 | return nullptr; | ||||
151 | } | ||||
152 | } | ||||
153 | |||||
154 | if (parser.parseGreater()) { | ||||
155 | return nullptr; | ||||
156 | } | ||||
157 | |||||
158 | return parser.getChecked<AnyQuantizedType>( | ||||
159 | typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); | ||||
160 | } | ||||
161 | |||||
162 | static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, | ||||
163 | int64_t &zeroPoint) { | ||||
164 | // scale[:zeroPoint]? | ||||
165 | // scale. | ||||
166 | if (parser.parseFloat(scale)) | ||||
167 | return failure(); | ||||
168 | |||||
169 | // zero point. | ||||
170 | zeroPoint = 0; | ||||
171 | if (failed(parser.parseOptionalColon())) { | ||||
172 | // Default zero point. | ||||
173 | return success(); | ||||
174 | } | ||||
175 | |||||
176 | return parser.parseInteger(zeroPoint); | ||||
177 | } | ||||
178 | |||||
179 | /// Parses a UniformQuantizedType. | ||||
180 | /// | ||||
181 | /// uniform_type ::= uniform_per_layer | ||||
182 | /// | uniform_per_axis | ||||
183 | /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec | ||||
184 | /// `,` scale-zero `>` | ||||
185 | /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec | ||||
186 | /// axis-spec `,` scale-zero-list `>` | ||||
187 | /// storage-spec ::= storage-type (`<` storage-range `>`)? | ||||
188 | /// storage-range ::= integer-literal `:` integer-literal | ||||
189 | /// storage-type ::= (`i` | `u`) integer-literal | ||||
190 | /// expressed-type-spec ::= `:` `f` integer-literal | ||||
191 | /// axis-spec ::= `:` integer-literal | ||||
192 | /// scale-zero ::= float-literal `:` integer-literal | ||||
193 | /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` | ||||
194 | static Type parseUniformType(DialectAsmParser &parser) { | ||||
195 | IntegerType storageType; | ||||
196 | FloatType expressedType; | ||||
197 | unsigned typeFlags = 0; | ||||
198 | int64_t storageTypeMin; | ||||
199 | int64_t storageTypeMax; | ||||
200 | bool isPerAxis = false; | ||||
201 | int32_t quantizedDimension; | ||||
202 | SmallVector<double, 1> scales; | ||||
203 | SmallVector<int64_t, 1> zeroPoints; | ||||
204 | |||||
205 | // Type specification. | ||||
206 | if (parser.parseLess()) { | ||||
207 | return nullptr; | ||||
208 | } | ||||
209 | |||||
210 | // Storage type. | ||||
211 | bool isSigned = false; | ||||
212 | storageType = parseStorageType(parser, isSigned); | ||||
213 | if (!storageType) { | ||||
214 | return nullptr; | ||||
215 | } | ||||
216 | if (isSigned
| ||||
217 | typeFlags |= QuantizationFlags::Signed; | ||||
218 | } | ||||
219 | |||||
220 | // Storage type range. | ||||
221 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, | ||||
222 | storageTypeMax)) { | ||||
223 | return nullptr; | ||||
224 | } | ||||
225 | |||||
226 | // Expressed type. | ||||
227 | if (parser.parseColon() || parser.parseType(expressedType)) { | ||||
228 | return nullptr; | ||||
229 | } | ||||
230 | |||||
231 | // Optionally parse quantized dimension for per-axis quantization. | ||||
232 | if (succeeded(parser.parseOptionalColon())) { | ||||
233 | if (parser.parseInteger(quantizedDimension)) | ||||
234 | return nullptr; | ||||
235 | isPerAxis = true; | ||||
236 | } | ||||
237 | |||||
238 | // Comma leading into range_spec. | ||||
239 | if (parser.parseComma()) { | ||||
240 | return nullptr; | ||||
241 | } | ||||
242 | |||||
243 | // Parameter specification. | ||||
244 | // For per-axis, ranges are in a {} delimitted list. | ||||
245 | if (isPerAxis) { | ||||
246 | if (parser.parseLBrace()) { | ||||
247 | return nullptr; | ||||
248 | } | ||||
249 | } | ||||
250 | |||||
251 | // Parse scales/zeroPoints. | ||||
252 | SMLoc scaleZPLoc = parser.getCurrentLocation(); | ||||
253 | do { | ||||
254 | scales.resize(scales.size() + 1); | ||||
255 | zeroPoints.resize(zeroPoints.size() + 1); | ||||
256 | if (parseQuantParams(parser, scales.back(), zeroPoints.back())) { | ||||
257 | return nullptr; | ||||
258 | } | ||||
259 | } while (isPerAxis && succeeded(parser.parseOptionalComma())); | ||||
260 | |||||
261 | if (isPerAxis) { | ||||
262 | if (parser.parseRBrace()) { | ||||
263 | return nullptr; | ||||
264 | } | ||||
265 | } | ||||
266 | |||||
267 | if (parser.parseGreater()) { | ||||
268 | return nullptr; | ||||
269 | } | ||||
270 | |||||
271 | if (!isPerAxis && scales.size() > 1) { | ||||
272 | return (parser.emitError(scaleZPLoc, | ||||
273 | "multiple scales/zeroPoints provided, but " | ||||
274 | "quantizedDimension wasn't specified"), | ||||
275 | nullptr); | ||||
276 | } | ||||
277 | |||||
278 | if (isPerAxis) { | ||||
279 | ArrayRef<double> scalesRef(scales.begin(), scales.end()); | ||||
280 | ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); | ||||
281 | return parser.getChecked<UniformQuantizedPerAxisType>( | ||||
282 | typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, | ||||
283 | quantizedDimension, storageTypeMin, storageTypeMax); | ||||
284 | } | ||||
285 | |||||
286 | return parser.getChecked<UniformQuantizedType>( | ||||
287 | typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), | ||||
288 | storageTypeMin, storageTypeMax); | ||||
289 | } | ||||
290 | |||||
291 | /// Parses an CalibratedQuantizedType. | ||||
292 | /// | ||||
293 | /// calibrated ::= `calibrated<` expressed-spec `>` | ||||
294 | /// expressed-spec ::= expressed-type `<` calibrated-range `>` | ||||
295 | /// expressed-type ::= `f` integer-literal | ||||
296 | /// calibrated-range ::= float-literal `:` float-literal | ||||
297 | static Type parseCalibratedType(DialectAsmParser &parser) { | ||||
298 | FloatType expressedType; | ||||
299 | double min; | ||||
300 | double max; | ||||
301 | |||||
302 | // Type specification. | ||||
303 | if (parser.parseLess()) | ||||
304 | return nullptr; | ||||
305 | |||||
306 | // Expressed type. | ||||
307 | expressedType = parseExpressedTypeAndRange(parser, min, max); | ||||
308 | if (!expressedType) { | ||||
309 | return nullptr; | ||||
310 | } | ||||
311 | |||||
312 | if (parser.parseGreater()) { | ||||
313 | return nullptr; | ||||
314 | } | ||||
315 | |||||
316 | return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); | ||||
317 | } | ||||
318 | |||||
319 | /// Parse a type registered to this dialect. | ||||
320 | Type QuantizationDialect::parseType(DialectAsmParser &parser) const { | ||||
321 | // All types start with an identifier that we switch on. | ||||
322 | StringRef typeNameSpelling; | ||||
323 | if (failed(parser.parseKeyword(&typeNameSpelling))) | ||||
| |||||
324 | return nullptr; | ||||
325 | |||||
326 | if (typeNameSpelling == "uniform") | ||||
327 | return parseUniformType(parser); | ||||
328 | if (typeNameSpelling == "any") | ||||
329 | return parseAnyType(parser); | ||||
330 | if (typeNameSpelling == "calibrated") | ||||
331 | return parseCalibratedType(parser); | ||||
332 | |||||
333 | parser.emitError(parser.getNameLoc(), | ||||
334 | "unknown quantized type " + typeNameSpelling); | ||||
335 | return nullptr; | ||||
336 | } | ||||
337 | |||||
338 | static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { | ||||
339 | // storage type | ||||
340 | unsigned storageWidth = type.getStorageTypeIntegralWidth(); | ||||
341 | bool isSigned = type.isSigned(); | ||||
342 | if (isSigned) { | ||||
343 | out << "i" << storageWidth; | ||||
344 | } else { | ||||
345 | out << "u" << storageWidth; | ||||
346 | } | ||||
347 | |||||
348 | // storageTypeMin and storageTypeMax if not default. | ||||
349 | int64_t defaultIntegerMin = | ||||
350 | QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth); | ||||
351 | int64_t defaultIntegerMax = | ||||
352 | QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth); | ||||
353 | if (defaultIntegerMin != type.getStorageTypeMin() || | ||||
354 | defaultIntegerMax != type.getStorageTypeMax()) { | ||||
355 | out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() | ||||
356 | << ">"; | ||||
357 | } | ||||
358 | } | ||||
359 | |||||
360 | static void printQuantParams(double scale, int64_t zeroPoint, | ||||
361 | DialectAsmPrinter &out) { | ||||
362 | out << scale; | ||||
363 | if (zeroPoint != 0) { | ||||
364 | out << ":" << zeroPoint; | ||||
365 | } | ||||
366 | } | ||||
367 | |||||
368 | /// Helper that prints a AnyQuantizedType. | ||||
369 | static void printAnyQuantizedType(AnyQuantizedType type, | ||||
370 | DialectAsmPrinter &out) { | ||||
371 | out << "any<"; | ||||
372 | printStorageType(type, out); | ||||
373 | if (Type expressedType = type.getExpressedType()) { | ||||
374 | out << ":" << expressedType; | ||||
375 | } | ||||
376 | out << ">"; | ||||
377 | } | ||||
378 | |||||
379 | /// Helper that prints a UniformQuantizedType. | ||||
380 | static void printUniformQuantizedType(UniformQuantizedType type, | ||||
381 | DialectAsmPrinter &out) { | ||||
382 | out << "uniform<"; | ||||
383 | printStorageType(type, out); | ||||
384 | out << ":" << type.getExpressedType() << ", "; | ||||
385 | |||||
386 | // scheme specific parameters | ||||
387 | printQuantParams(type.getScale(), type.getZeroPoint(), out); | ||||
388 | out << ">"; | ||||
389 | } | ||||
390 | |||||
391 | /// Helper that prints a UniformQuantizedPerAxisType. | ||||
392 | static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, | ||||
393 | DialectAsmPrinter &out) { | ||||
394 | out << "uniform<"; | ||||
395 | printStorageType(type, out); | ||||
396 | out << ":" << type.getExpressedType() << ":"; | ||||
397 | out << type.getQuantizedDimension(); | ||||
398 | out << ", "; | ||||
399 | |||||
400 | // scheme specific parameters | ||||
401 | ArrayRef<double> scales = type.getScales(); | ||||
402 | ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); | ||||
403 | out << "{"; | ||||
404 | llvm::interleave( | ||||
405 | llvm::seq<size_t>(0, scales.size()), out, | ||||
406 | [&](size_t index) { | ||||
407 | printQuantParams(scales[index], zeroPoints[index], out); | ||||
408 | }, | ||||
409 | ","); | ||||
410 | out << "}>"; | ||||
411 | } | ||||
412 | |||||
413 | /// Helper that prints a CalibratedQuantizedType. | ||||
414 | static void printCalibratedQuantizedType(CalibratedQuantizedType type, | ||||
415 | DialectAsmPrinter &out) { | ||||
416 | out << "calibrated<" << type.getExpressedType(); | ||||
417 | out << "<" << type.getMin() << ":" << type.getMax() << ">"; | ||||
418 | out << ">"; | ||||
419 | } | ||||
420 | |||||
421 | /// Print a type registered to this dialect. | ||||
422 | void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { | ||||
423 | if (auto anyType = type.dyn_cast<AnyQuantizedType>()) | ||||
424 | printAnyQuantizedType(anyType, os); | ||||
425 | else if (auto uniformType = type.dyn_cast<UniformQuantizedType>()) | ||||
426 | printUniformQuantizedType(uniformType, os); | ||||
427 | else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>()) | ||||
428 | printUniformQuantizedPerAxisType(perAxisType, os); | ||||
429 | else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>()) | ||||
430 | printCalibratedQuantizedType(calibratedType, os); | ||||
431 | else | ||||
432 | llvm_unreachable("Unhandled quantized type")::llvm::llvm_unreachable_internal("Unhandled quantized type", "mlir/lib/Dialect/Quant/IR/TypeParser.cpp", 432); | ||||
433 | } |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | #include <optional> |
22 | |
23 | namespace mlir { |
24 | class AsmParsedResourceEntry; |
25 | class AsmResourceBuilder; |
26 | class Builder; |
27 | |
28 | //===----------------------------------------------------------------------===// |
29 | // AsmDialectResourceHandle |
30 | //===----------------------------------------------------------------------===// |
31 | |
32 | /// This class represents an opaque handle to a dialect resource entry. |
33 | class AsmDialectResourceHandle { |
34 | public: |
35 | AsmDialectResourceHandle() = default; |
36 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
37 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
38 | bool operator==(const AsmDialectResourceHandle &other) const { |
39 | return resource == other.resource; |
40 | } |
41 | |
42 | /// Return an opaque pointer to the referenced resource. |
43 | void *getResource() const { return resource; } |
44 | |
45 | /// Return the type ID of the resource. |
46 | TypeID getTypeID() const { return opaqueID; } |
47 | |
48 | /// Return the dialect that owns the resource. |
49 | Dialect *getDialect() const { return dialect; } |
50 | |
51 | private: |
52 | /// The opaque handle to the dialect resource. |
53 | void *resource = nullptr; |
54 | /// The type of the resource referenced. |
55 | TypeID opaqueID; |
56 | /// The dialect owning the given resource. |
57 | Dialect *dialect; |
58 | }; |
59 | |
60 | /// This class represents a CRTP base class for dialect resource handles. It |
61 | /// abstracts away various utilities necessary for defined derived resource |
62 | /// handles. |
63 | template <typename DerivedT, typename ResourceT, typename DialectT> |
64 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { |
65 | public: |
66 | using Dialect = DialectT; |
67 | |
68 | /// Construct a handle from a pointer to the resource. The given pointer |
69 | /// should be guaranteed to live beyond the life of this handle. |
70 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) |
71 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} |
72 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) |
73 | : AsmDialectResourceHandle(handle) { |
74 | assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get< DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()" , "mlir/include/mlir/IR/OpImplementation.h", 74, __extension__ __PRETTY_FUNCTION__)); |
75 | } |
76 | |
77 | /// Return the resource referenced by this handle. |
78 | ResourceT *getResource() { |
79 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); |
80 | } |
81 | const ResourceT *getResource() const { |
82 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); |
83 | } |
84 | |
85 | /// Return the dialect that owns the resource. |
86 | DialectT *getDialect() const { |
87 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); |
88 | } |
89 | |
90 | /// Support llvm style casting. |
91 | static bool classof(const AsmDialectResourceHandle *handle) { |
92 | return handle->getTypeID() == TypeID::get<DerivedT>(); |
93 | } |
94 | }; |
95 | |
96 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
97 | return llvm::hash_value(param.getResource()); |
98 | } |
99 | |
100 | //===----------------------------------------------------------------------===// |
101 | // AsmPrinter |
102 | //===----------------------------------------------------------------------===// |
103 | |
104 | /// This base class exposes generic asm printer hooks, usable across the various |
105 | /// derived printers. |
106 | class AsmPrinter { |
107 | public: |
108 | /// This class contains the internal default implementation of the base |
109 | /// printer methods. |
110 | class Impl; |
111 | |
112 | /// Initialize the printer with the given internal implementation. |
113 | AsmPrinter(Impl &impl) : impl(&impl) {} |
114 | virtual ~AsmPrinter(); |
115 | |
116 | /// Return the raw output stream used by this printer. |
117 | virtual raw_ostream &getStream() const; |
118 | |
119 | /// Print the given floating point value in a stabilized form that can be |
120 | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
121 | /// hook on the AsmParser. |
122 | virtual void printFloat(const APFloat &value); |
123 | |
124 | virtual void printType(Type type); |
125 | virtual void printAttribute(Attribute attr); |
126 | |
127 | /// Trait to check if `AttrType` provides a `print` method. |
128 | template <typename AttrOrType> |
129 | using has_print_method = |
130 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
131 | template <typename AttrOrType> |
132 | using detect_has_print_method = |
133 | llvm::is_detected<has_print_method, AttrOrType>; |
134 | |
135 | /// Print the provided attribute in the context of an operation custom |
136 | /// printer/parser: this will invoke directly the print method on the |
137 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. |
138 | template <typename AttrOrType, |
139 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
140 | *sfinae = nullptr> |
141 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
142 | if (succeeded(printAlias(attrOrType))) |
143 | return; |
144 | attrOrType.print(*this); |
145 | } |
146 | |
147 | /// Print the provided array of attributes or types in the context of an |
148 | /// operation custom printer/parser: this will invoke directly the print |
149 | /// method on the attribute class and skip the `#dialect.mnemonic` prefix in |
150 | /// most cases. |
151 | template <typename AttrOrType, |
152 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
153 | *sfinae = nullptr> |
154 | void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) { |
155 | llvm::interleaveComma( |
156 | attrOrTypes, getStream(), |
157 | [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); |
158 | } |
159 | |
160 | /// SFINAE for printing the provided attribute in the context of an operation |
161 | /// custom printer in the case where the attribute does not define a print |
162 | /// method. |
163 | template <typename AttrOrType, |
164 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
165 | *sfinae = nullptr> |
166 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
167 | *this << attrOrType; |
168 | } |
169 | |
170 | /// Print the given attribute without its type. The corresponding parser must |
171 | /// provide a valid type for the attribute. |
172 | virtual void printAttributeWithoutType(Attribute attr); |
173 | |
174 | /// Print the given string as a keyword, or a quoted and escaped string if it |
175 | /// has any special or non-printable characters in it. |
176 | virtual void printKeywordOrString(StringRef keyword); |
177 | |
178 | /// Print the given string as a symbol reference, i.e. a form representable by |
179 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed |
180 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any |
181 | /// special or non-printable characters in it. |
182 | virtual void printSymbolName(StringRef symbolRef); |
183 | |
184 | /// Print a handle to the given dialect resource. |
185 | virtual void printResourceHandle(const AsmDialectResourceHandle &resource); |
186 | |
187 | /// Print an optional arrow followed by a type list. |
188 | template <typename TypeRange> |
189 | void printOptionalArrowTypeList(TypeRange &&types) { |
190 | if (types.begin() != types.end()) |
191 | printArrowTypeList(types); |
192 | } |
193 | template <typename TypeRange> |
194 | void printArrowTypeList(TypeRange &&types) { |
195 | auto &os = getStream() << " -> "; |
196 | |
197 | bool wrapped = !llvm::hasSingleElement(types) || |
198 | (*types.begin()).template isa<FunctionType>(); |
199 | if (wrapped) |
200 | os << '('; |
201 | llvm::interleaveComma(types, *this); |
202 | if (wrapped) |
203 | os << ')'; |
204 | } |
205 | |
206 | /// Print the two given type ranges in a functional form. |
207 | template <typename InputRangeT, typename ResultRangeT> |
208 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
209 | auto &os = getStream(); |
210 | os << '('; |
211 | llvm::interleaveComma(inputs, *this); |
212 | os << ')'; |
213 | printArrowTypeList(results); |
214 | } |
215 | |
216 | protected: |
217 | /// Initialize the printer with no internal implementation. In this case, all |
218 | /// virtual methods of this class must be overriden. |
219 | AsmPrinter() = default; |
220 | |
221 | private: |
222 | AsmPrinter(const AsmPrinter &) = delete; |
223 | void operator=(const AsmPrinter &) = delete; |
224 | |
225 | /// Print the alias for the given attribute, return failure if no alias could |
226 | /// be printed. |
227 | virtual LogicalResult printAlias(Attribute attr); |
228 | |
229 | /// Print the alias for the given type, return failure if no alias could |
230 | /// be printed. |
231 | virtual LogicalResult printAlias(Type type); |
232 | |
233 | /// The internal implementation of the printer. |
234 | Impl *impl{nullptr}; |
235 | }; |
236 | |
237 | template <typename AsmPrinterT> |
238 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
239 | AsmPrinterT &> |
240 | operator<<(AsmPrinterT &p, Type type) { |
241 | p.printType(type); |
242 | return p; |
243 | } |
244 | |
245 | template <typename AsmPrinterT> |
246 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
247 | AsmPrinterT &> |
248 | operator<<(AsmPrinterT &p, Attribute attr) { |
249 | p.printAttribute(attr); |
250 | return p; |
251 | } |
252 | |
253 | template <typename AsmPrinterT> |
254 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
255 | AsmPrinterT &> |
256 | operator<<(AsmPrinterT &p, const APFloat &value) { |
257 | p.printFloat(value); |
258 | return p; |
259 | } |
260 | template <typename AsmPrinterT> |
261 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
262 | AsmPrinterT &> |
263 | operator<<(AsmPrinterT &p, float value) { |
264 | return p << APFloat(value); |
265 | } |
266 | template <typename AsmPrinterT> |
267 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
268 | AsmPrinterT &> |
269 | operator<<(AsmPrinterT &p, double value) { |
270 | return p << APFloat(value); |
271 | } |
272 | |
273 | // Support printing anything that isn't convertible to one of the other |
274 | // streamable types, even if it isn't exactly one of them. For example, we want |
275 | // to print FunctionType with the Type version above, not have it match this. |
276 | template <typename AsmPrinterT, typename T, |
277 | std::enable_if_t<!std::is_convertible<T &, Value &>::value && |
278 | !std::is_convertible<T &, Type &>::value && |
279 | !std::is_convertible<T &, Attribute &>::value && |
280 | !std::is_convertible<T &, ValueRange>::value && |
281 | !std::is_convertible<T &, APFloat &>::value && |
282 | !llvm::is_one_of<T, bool, float, double>::value, |
283 | T> * = nullptr> |
284 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
285 | AsmPrinterT &> |
286 | operator<<(AsmPrinterT &p, const T &other) { |
287 | p.getStream() << other; |
288 | return p; |
289 | } |
290 | |
291 | template <typename AsmPrinterT> |
292 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
293 | AsmPrinterT &> |
294 | operator<<(AsmPrinterT &p, bool value) { |
295 | return p << (value ? StringRef("true") : "false"); |
296 | } |
297 | |
298 | template <typename AsmPrinterT, typename ValueRangeT> |
299 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
300 | AsmPrinterT &> |
301 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
302 | llvm::interleaveComma(types, p); |
303 | return p; |
304 | } |
305 | template <typename AsmPrinterT> |
306 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
307 | AsmPrinterT &> |
308 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
309 | llvm::interleaveComma(types, p); |
310 | return p; |
311 | } |
312 | template <typename AsmPrinterT, typename ElementT> |
313 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
314 | AsmPrinterT &> |
315 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
316 | llvm::interleaveComma(types, p); |
317 | return p; |
318 | } |
319 | |
320 | //===----------------------------------------------------------------------===// |
321 | // OpAsmPrinter |
322 | //===----------------------------------------------------------------------===// |
323 | |
324 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
325 | /// necessary to implement a custom print() method. |
326 | class OpAsmPrinter : public AsmPrinter { |
327 | public: |
328 | using AsmPrinter::AsmPrinter; |
329 | ~OpAsmPrinter() override; |
330 | |
331 | /// Print a loc(...) specifier if printing debug info is enabled. |
332 | virtual void printOptionalLocationSpecifier(Location loc) = 0; |
333 | |
334 | /// Print a newline and indent the printer to the start of the current |
335 | /// operation. |
336 | virtual void printNewline() = 0; |
337 | |
338 | /// Increase indentation. |
339 | virtual void increaseIndent() = 0; |
340 | |
341 | /// Decrease indentation. |
342 | virtual void decreaseIndent() = 0; |
343 | |
344 | /// Print a block argument in the usual format of: |
345 | /// %ssaName : type {attr1=42} loc("here") |
346 | /// where location printing is controlled by the standard internal option. |
347 | /// You may pass omitType=true to not print a type, and pass an empty |
348 | /// attribute list if you don't care for attributes. |
349 | virtual void printRegionArgument(BlockArgument arg, |
350 | ArrayRef<NamedAttribute> argAttrs = {}, |
351 | bool omitType = false) = 0; |
352 | |
353 | /// Print implementations for various things an operation contains. |
354 | virtual void printOperand(Value value) = 0; |
355 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
356 | |
357 | /// Print a comma separated list of operands. |
358 | template <typename ContainerType> |
359 | void printOperands(const ContainerType &container) { |
360 | printOperands(container.begin(), container.end()); |
361 | } |
362 | |
363 | /// Print a comma separated list of operands. |
364 | template <typename IteratorType> |
365 | void printOperands(IteratorType it, IteratorType end) { |
366 | llvm::interleaveComma(llvm::make_range(it, end), getStream(), |
367 | [this](Value value) { printOperand(value); }); |
368 | } |
369 | |
370 | /// Print the given successor. |
371 | virtual void printSuccessor(Block *successor) = 0; |
372 | |
373 | /// Print the successor and its operands. |
374 | virtual void printSuccessorAndUseList(Block *successor, |
375 | ValueRange succOperands) = 0; |
376 | |
377 | /// If the specified operation has attributes, print out an attribute |
378 | /// dictionary with their values. elidedAttrs allows the client to ignore |
379 | /// specific well known attributes, commonly used if the attribute value is |
380 | /// printed some other way (like as a fixed operand). |
381 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
382 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
383 | |
384 | /// If the specified operation has attributes, print out an attribute |
385 | /// dictionary prefixed with 'attributes'. |
386 | virtual void |
387 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
388 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
389 | |
390 | /// Prints the entire operation with the custom assembly form, if available, |
391 | /// or the generic assembly form, otherwise. |
392 | virtual void printCustomOrGenericOp(Operation *op) = 0; |
393 | |
394 | /// Print the entire operation with the default generic assembly form. |
395 | /// If `printOpName` is true, then the operation name is printed (the default) |
396 | /// otherwise it is omitted and the print will start with the operand list. |
397 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
398 | |
399 | /// Prints a region. |
400 | /// If 'printEntryBlockArgs' is false, the arguments of the |
401 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
402 | /// operation of the block is not printed. If printEmptyBlock is true, then |
403 | /// the block header is printed even if the block is empty. |
404 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
405 | bool printBlockTerminators = true, |
406 | bool printEmptyBlock = false) = 0; |
407 | |
408 | /// Renumber the arguments for the specified region to the same names as the |
409 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
410 | /// operations. If any entry in namesToUse is null, the corresponding |
411 | /// argument name is left alone. |
412 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
413 | |
414 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
415 | /// of dims/symbols. |
416 | /// Operand values must come from single-result sources, and be valid |
417 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
418 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
419 | ValueRange operands) = 0; |
420 | |
421 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
422 | /// dims and symbols. |
423 | /// Operand values must come from single-result sources, and be valid |
424 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
425 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
426 | ValueRange symOperands) = 0; |
427 | |
428 | /// Print the complete type of an operation in functional form. |
429 | void printFunctionalType(Operation *op); |
430 | using AsmPrinter::printFunctionalType; |
431 | }; |
432 | |
433 | // Make the implementations convenient to use. |
434 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
435 | p.printOperand(value); |
436 | return p; |
437 | } |
438 | |
439 | template <typename T, |
440 | std::enable_if_t<std::is_convertible<T &, ValueRange>::value && |
441 | !std::is_convertible<T &, Value &>::value, |
442 | T> * = nullptr> |
443 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
444 | p.printOperands(values); |
445 | return p; |
446 | } |
447 | |
448 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
449 | p.printSuccessor(value); |
450 | return p; |
451 | } |
452 | |
453 | //===----------------------------------------------------------------------===// |
454 | // AsmParser |
455 | //===----------------------------------------------------------------------===// |
456 | |
457 | /// This base class exposes generic asm parser hooks, usable across the various |
458 | /// derived parsers. |
459 | class AsmParser { |
460 | public: |
461 | AsmParser() = default; |
462 | virtual ~AsmParser(); |
463 | |
464 | MLIRContext *getContext() const; |
465 | |
466 | /// Return the location of the original name token. |
467 | virtual SMLoc getNameLoc() const = 0; |
468 | |
469 | //===--------------------------------------------------------------------===// |
470 | // Utilities |
471 | //===--------------------------------------------------------------------===// |
472 | |
473 | /// Emit a diagnostic at the specified location and return failure. |
474 | virtual InFlightDiagnostic emitError(SMLoc loc, |
475 | const Twine &message = {}) = 0; |
476 | |
477 | /// Return a builder which provides useful access to MLIRContext, global |
478 | /// objects like types and attributes. |
479 | virtual Builder &getBuilder() const = 0; |
480 | |
481 | /// Get the location of the next token and store it into the argument. This |
482 | /// always succeeds. |
483 | virtual SMLoc getCurrentLocation() = 0; |
484 | ParseResult getCurrentLocation(SMLoc *loc) { |
485 | *loc = getCurrentLocation(); |
486 | return success(); |
487 | } |
488 | |
489 | /// Re-encode the given source location as an MLIR location and return it. |
490 | /// Note: This method should only be used when a `Location` is necessary, as |
491 | /// the encoding process is not efficient. |
492 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
493 | |
494 | //===--------------------------------------------------------------------===// |
495 | // Token Parsing |
496 | //===--------------------------------------------------------------------===// |
497 | |
498 | /// Parse a '->' token. |
499 | virtual ParseResult parseArrow() = 0; |
500 | |
501 | /// Parse a '->' token if present. |
502 | virtual ParseResult parseOptionalArrow() = 0; |
503 | |
504 | /// Parse a `{` token. |
505 | virtual ParseResult parseLBrace() = 0; |
506 | |
507 | /// Parse a `{` token if present. |
508 | virtual ParseResult parseOptionalLBrace() = 0; |
509 | |
510 | /// Parse a `}` token. |
511 | virtual ParseResult parseRBrace() = 0; |
512 | |
513 | /// Parse a `}` token if present. |
514 | virtual ParseResult parseOptionalRBrace() = 0; |
515 | |
516 | /// Parse a `:` token. |
517 | virtual ParseResult parseColon() = 0; |
518 | |
519 | /// Parse a `:` token if present. |
520 | virtual ParseResult parseOptionalColon() = 0; |
521 | |
522 | /// Parse a `,` token. |
523 | virtual ParseResult parseComma() = 0; |
524 | |
525 | /// Parse a `,` token if present. |
526 | virtual ParseResult parseOptionalComma() = 0; |
527 | |
528 | /// Parse a `=` token. |
529 | virtual ParseResult parseEqual() = 0; |
530 | |
531 | /// Parse a `=` token if present. |
532 | virtual ParseResult parseOptionalEqual() = 0; |
533 | |
534 | /// Parse a '<' token. |
535 | virtual ParseResult parseLess() = 0; |
536 | |
537 | /// Parse a '<' token if present. |
538 | virtual ParseResult parseOptionalLess() = 0; |
539 | |
540 | /// Parse a '>' token. |
541 | virtual ParseResult parseGreater() = 0; |
542 | |
543 | /// Parse a '>' token if present. |
544 | virtual ParseResult parseOptionalGreater() = 0; |
545 | |
546 | /// Parse a '?' token. |
547 | virtual ParseResult parseQuestion() = 0; |
548 | |
549 | /// Parse a '?' token if present. |
550 | virtual ParseResult parseOptionalQuestion() = 0; |
551 | |
552 | /// Parse a '+' token. |
553 | virtual ParseResult parsePlus() = 0; |
554 | |
555 | /// Parse a '+' token if present. |
556 | virtual ParseResult parseOptionalPlus() = 0; |
557 | |
558 | /// Parse a '*' token. |
559 | virtual ParseResult parseStar() = 0; |
560 | |
561 | /// Parse a '*' token if present. |
562 | virtual ParseResult parseOptionalStar() = 0; |
563 | |
564 | /// Parse a '|' token. |
565 | virtual ParseResult parseVerticalBar() = 0; |
566 | |
567 | /// Parse a '|' token if present. |
568 | virtual ParseResult parseOptionalVerticalBar() = 0; |
569 | |
570 | /// Parse a quoted string token. |
571 | ParseResult parseString(std::string *string) { |
572 | auto loc = getCurrentLocation(); |
573 | if (parseOptionalString(string)) |
574 | return emitError(loc, "expected string"); |
575 | return success(); |
576 | } |
577 | |
578 | /// Parse a quoted string token if present. |
579 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
580 | |
581 | /// Parses a Base64 encoded string of bytes. |
582 | virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; |
583 | |
584 | /// Parse a `(` token. |
585 | virtual ParseResult parseLParen() = 0; |
586 | |
587 | /// Parse a `(` token if present. |
588 | virtual ParseResult parseOptionalLParen() = 0; |
589 | |
590 | /// Parse a `)` token. |
591 | virtual ParseResult parseRParen() = 0; |
592 | |
593 | /// Parse a `)` token if present. |
594 | virtual ParseResult parseOptionalRParen() = 0; |
595 | |
596 | /// Parse a `[` token. |
597 | virtual ParseResult parseLSquare() = 0; |
598 | |
599 | /// Parse a `[` token if present. |
600 | virtual ParseResult parseOptionalLSquare() = 0; |
601 | |
602 | /// Parse a `]` token. |
603 | virtual ParseResult parseRSquare() = 0; |
604 | |
605 | /// Parse a `]` token if present. |
606 | virtual ParseResult parseOptionalRSquare() = 0; |
607 | |
608 | /// Parse a `...` token. |
609 | virtual ParseResult parseEllipsis() = 0; |
610 | |
611 | /// Parse a `...` token if present. |
612 | virtual ParseResult parseOptionalEllipsis() = 0; |
613 | |
614 | /// Parse a floating point value from the stream. |
615 | virtual ParseResult parseFloat(double &result) = 0; |
616 | |
617 | /// Parse an integer value from the stream. |
618 | template <typename IntT> |
619 | ParseResult parseInteger(IntT &result) { |
620 | auto loc = getCurrentLocation(); |
621 | OptionalParseResult parseResult = parseOptionalInteger(result); |
622 | if (!parseResult.has_value()) |
623 | return emitError(loc, "expected integer value"); |
624 | return *parseResult; |
625 | } |
626 | |
627 | /// Parse an optional integer value from the stream. |
628 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
629 | |
630 | template <typename IntT> |
631 | OptionalParseResult parseOptionalInteger(IntT &result) { |
632 | auto loc = getCurrentLocation(); |
633 | |
634 | // Parse the unsigned variant. |
635 | APInt uintResult; |
636 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
637 | if (!parseResult.has_value() || failed(*parseResult)) |
638 | return parseResult; |
639 | |
640 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
641 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
642 | // zero for non-negated integers. |
643 | result = |
644 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); |
645 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
646 | return emitError(loc, "integer value too large"); |
647 | return success(); |
648 | } |
649 | |
/// The delimiters that may surround operand lists and region argument lists;
/// used by parseOperandList.
enum class Delimiter {
  /// No delimiters; zero or more operands.
  None,
  /// `(` `)` around zero or more operands.
  Paren,
  /// `[` `]` around zero or more operands.
  Square,
  /// `<` `>` around zero or more operands.
  LessGreater,
  /// `{` `}` around zero or more operands.
  Braces,
  /// `(` `)` around zero or more operands, or nothing at all.
  OptionalParen,
  /// `[` `]` around zero or more operands, or nothing at all.
  OptionalSquare,
  /// `<` `>` around zero or more operands, or nothing at all.
  OptionalLessGreater,
  /// `{` `}` around zero or more operands, or nothing at all.
  OptionalBraces,
};
672 | |
673 | /// Parse a list of comma-separated items with an optional delimiter. If a |
674 | /// delimiter is provided, then an empty list is allowed. If not, then at |
675 | /// least one element will be parsed. |
676 | /// |
677 | /// contextMessage is an optional message appended to "expected '('" sorts of |
678 | /// diagnostics when parsing the delimiters. |
679 | virtual ParseResult |
680 | parseCommaSeparatedList(Delimiter delimiter, |
681 | function_ref<ParseResult()> parseElementFn, |
682 | StringRef contextMessage = StringRef()) = 0; |
683 | |
684 | /// Parse a comma separated list of elements that must have at least one entry |
685 | /// in it. |
686 | ParseResult |
687 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
688 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); |
689 | } |
690 | |
691 | //===--------------------------------------------------------------------===// |
692 | // Keyword Parsing |
693 | //===--------------------------------------------------------------------===// |
694 | |
695 | /// This class represents a StringSwitch like class that is useful for parsing |
696 | /// expected keywords. On construction, it invokes `parseKeyword` and |
697 | /// processes each of the provided cases statements until a match is hit. The |
698 | /// provided `ResultT` must be assignable from `failure()`. |
699 | template <typename ResultT = ParseResult> |
700 | class KeywordSwitch { |
701 | public: |
702 | KeywordSwitch(AsmParser &parser) |
703 | : parser(parser), loc(parser.getCurrentLocation()) { |
704 | if (failed(parser.parseKeywordOrCompletion(&keyword))) |
705 | result = failure(); |
706 | } |
707 | |
708 | /// Case that uses the provided value when true. |
709 | KeywordSwitch &Case(StringLiteral str, ResultT value) { |
710 | return Case(str, [&](StringRef, SMLoc) { return std::move(value); }); |
711 | } |
712 | KeywordSwitch &Default(ResultT value) { |
713 | return Default([&](StringRef, SMLoc) { return std::move(value); }); |
714 | } |
715 | /// Case that invokes the provided functor when true. The parameters passed |
716 | /// to the functor are the keyword, and the location of the keyword (in case |
717 | /// any errors need to be emitted). |
718 | template <typename FnT> |
719 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
720 | Case(StringLiteral str, FnT &&fn) { |
721 | if (result) |
722 | return *this; |
723 | |
724 | // If the word was empty, record this as a completion. |
725 | if (keyword.empty()) |
726 | parser.codeCompleteExpectedTokens(str); |
727 | else if (keyword == str) |
728 | result.emplace(std::move(fn(keyword, loc))); |
729 | return *this; |
730 | } |
731 | template <typename FnT> |
732 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
733 | Default(FnT &&fn) { |
734 | if (!result) |
735 | result.emplace(fn(keyword, loc)); |
736 | return *this; |
737 | } |
738 | |
739 | /// Returns true if this switch has a value yet. |
740 | bool hasValue() const { return result.has_value(); } |
741 | |
742 | /// Return the result of the switch. |
743 | [[nodiscard]] operator ResultT() { |
744 | if (!result) |
745 | return parser.emitError(loc, "unexpected keyword: ") << keyword; |
746 | return std::move(*result); |
747 | } |
748 | |
749 | private: |
750 | /// The parser used to construct this switch. |
751 | AsmParser &parser; |
752 | |
753 | /// The location of the keyword, used to emit errors as necessary. |
754 | SMLoc loc; |
755 | |
756 | /// The parsed keyword itself. |
757 | StringRef keyword; |
758 | |
759 | /// The result of the switch statement or none if currently unknown. |
760 | std::optional<ResultT> result; |
761 | }; |
762 | |
763 | /// Parse a given keyword. |
764 | ParseResult parseKeyword(StringRef keyword) { |
765 | return parseKeyword(keyword, ""); |
766 | } |
767 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; |
768 | |
769 | /// Parse a keyword into 'keyword'. |
770 | ParseResult parseKeyword(StringRef *keyword) { |
771 | auto loc = getCurrentLocation(); |
772 | if (parseOptionalKeyword(keyword)) |
773 | return emitError(loc, "expected valid keyword"); |
774 | return success(); |
775 | } |
776 | |
777 | /// Parse the given keyword if present. |
778 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
779 | |
780 | /// Parse a keyword, if present, into 'keyword'. |
781 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
782 | |
783 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
784 | /// into 'keyword' |
785 | virtual ParseResult |
786 | parseOptionalKeyword(StringRef *keyword, |
787 | ArrayRef<StringRef> allowedValues) = 0; |
788 | |
789 | /// Parse a keyword or a quoted string. |
790 | ParseResult parseKeywordOrString(std::string *result) { |
791 | if (failed(parseOptionalKeywordOrString(result))) |
792 | return emitError(getCurrentLocation()) |
793 | << "expected valid keyword or string"; |
794 | return success(); |
795 | } |
796 | |
797 | /// Parse an optional keyword or string. |
798 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
799 | |
800 | //===--------------------------------------------------------------------===// |
801 | // Attribute/Type Parsing |
802 | //===--------------------------------------------------------------------===// |
803 | |
804 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
805 | /// the provided location to emit errors in the case of failure. Note that |
806 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
807 | /// context parameter. |
808 | template <typename T, typename... ParamsT> |
809 | auto getChecked(SMLoc loc, ParamsT &&...params) { |
810 | return T::getChecked([&] { return emitError(loc); }, |
811 | std::forward<ParamsT>(params)...); |
812 | } |
813 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
814 | /// errors. |
815 | template <typename T, typename... ParamsT> |
816 | auto getChecked(ParamsT &&...params) { |
817 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
818 | std::forward<ParamsT>(params)...); |
819 | } |
820 | |
821 | //===--------------------------------------------------------------------===// |
822 | // Attribute Parsing |
823 | //===--------------------------------------------------------------------===// |
824 | |
825 | /// Parse an arbitrary attribute of a given type and return it in result. |
826 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
827 | |
828 | /// Parse a custom attribute with the provided callback, unless the next |
829 | /// token is `#`, in which case the generic parser is invoked. |
830 | virtual ParseResult parseCustomAttributeWithFallback( |
831 | Attribute &result, Type type, |
832 | function_ref<ParseResult(Attribute &result, Type type)> |
833 | parseAttribute) = 0; |
834 | |
835 | /// Parse an attribute of a specific kind and type. |
836 | template <typename AttrType> |
837 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
838 | SMLoc loc = getCurrentLocation(); |
839 | |
840 | // Parse any kind of attribute. |
841 | Attribute attr; |
842 | if (parseAttribute(attr, type)) |
843 | return failure(); |
844 | |
845 | // Check for the right kind of attribute. |
846 | if (!(result = attr.dyn_cast<AttrType>())) |
847 | return emitError(loc, "invalid kind of attribute specified"); |
848 | |
849 | return success(); |
850 | } |
851 | |
852 | /// Parse an arbitrary attribute and return it in result. This also adds the |
853 | /// attribute to the specified attribute list with the specified name. |
854 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
855 | NamedAttrList &attrs) { |
856 | return parseAttribute(result, Type(), attrName, attrs); |
857 | } |
858 | |
859 | /// Parse an attribute of a specific kind and type. |
860 | template <typename AttrType> |
861 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
862 | NamedAttrList &attrs) { |
863 | return parseAttribute(result, Type(), attrName, attrs); |
864 | } |
865 | |
866 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
867 | /// This also adds the attribute to the specified attribute list with the |
868 | /// specified name. |
869 | template <typename AttrType> |
870 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
871 | NamedAttrList &attrs) { |
872 | SMLoc loc = getCurrentLocation(); |
873 | |
874 | // Parse any kind of attribute. |
875 | Attribute attr; |
876 | if (parseAttribute(attr, type)) |
877 | return failure(); |
878 | |
879 | // Check for the right kind of attribute. |
880 | result = attr.dyn_cast<AttrType>(); |
881 | if (!result) |
882 | return emitError(loc, "invalid kind of attribute specified"); |
883 | |
884 | attrs.append(attrName, result); |
885 | return success(); |
886 | } |
887 | |
888 | /// Trait to check if `AttrType` provides a `parse` method. |
889 | template <typename AttrType> |
890 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
891 | std::declval<Type>())); |
892 | template <typename AttrType> |
893 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
894 | |
895 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
896 | /// which case the generic parser is invoked. The parsed attribute is |
897 | /// populated in `result` and also added to the specified attribute list with |
898 | /// the specified name. |
899 | template <typename AttrType> |
900 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
901 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
902 | StringRef attrName, NamedAttrList &attrs) { |
903 | SMLoc loc = getCurrentLocation(); |
904 | |
905 | // Parse any kind of attribute. |
906 | Attribute attr; |
907 | if (parseCustomAttributeWithFallback( |
908 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
909 | result = AttrType::parse(*this, type); |
910 | if (!result) |
911 | return failure(); |
912 | return success(); |
913 | })) |
914 | return failure(); |
915 | |
916 | // Check for the right kind of attribute. |
917 | result = attr.dyn_cast<AttrType>(); |
918 | if (!result) |
919 | return emitError(loc, "invalid kind of attribute specified"); |
920 | |
921 | attrs.append(attrName, result); |
922 | return success(); |
923 | } |
924 | |
925 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
926 | template <typename AttrType> |
927 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
928 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
929 | StringRef attrName, NamedAttrList &attrs) { |
930 | return parseAttribute(result, type, attrName, attrs); |
931 | } |
932 | |
933 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
934 | /// which case the generic parser is invoked. The parsed attribute is |
935 | /// populated in `result`. |
936 | template <typename AttrType> |
937 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
938 | parseCustomAttributeWithFallback(AttrType &result) { |
939 | SMLoc loc = getCurrentLocation(); |
940 | |
941 | // Parse any kind of attribute. |
942 | Attribute attr; |
943 | if (parseCustomAttributeWithFallback( |
944 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
945 | result = AttrType::parse(*this, type); |
946 | return success(!!result); |
947 | })) |
948 | return failure(); |
949 | |
950 | // Check for the right kind of attribute. |
951 | result = attr.dyn_cast<AttrType>(); |
952 | if (!result) |
953 | return emitError(loc, "invalid kind of attribute specified"); |
954 | return success(); |
955 | } |
956 | |
957 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
958 | template <typename AttrType> |
959 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
960 | parseCustomAttributeWithFallback(AttrType &result) { |
961 | return parseAttribute(result); |
962 | } |
963 | |
964 | /// Parse an arbitrary optional attribute of a given type and return it in |
965 | /// result. |
966 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
967 | Type type = {}) = 0; |
968 | |
969 | /// Parse an optional array attribute and return it in result. |
970 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
971 | Type type = {}) = 0; |
972 | |
973 | /// Parse an optional string attribute and return it in result. |
974 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
975 | Type type = {}) = 0; |
976 | |
977 | /// Parse an optional symbol ref attribute and return it in result. |
978 | virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result, |
979 | Type type = {}) = 0; |
980 | |
981 | /// Parse an optional attribute of a specific type and add it to the list with |
982 | /// the specified name. |
983 | template <typename AttrType> |
984 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
985 | StringRef attrName, |
986 | NamedAttrList &attrs) { |
987 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
988 | } |
989 | |
990 | /// Parse an optional attribute of a specific type and add it to the list with |
991 | /// the specified name. |
992 | template <typename AttrType> |
993 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
994 | StringRef attrName, |
995 | NamedAttrList &attrs) { |
996 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
997 | if (parseResult.has_value() && succeeded(*parseResult)) |
998 | attrs.append(attrName, result); |
999 | return parseResult; |
1000 | } |
1001 | |
1002 | /// Parse a named dictionary into 'result' if it is present. |
1003 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
1004 | |
1005 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
1006 | /// present. |
1007 | virtual ParseResult |
1008 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
1009 | |
1010 | /// Parse an affine map instance into 'map'. |
1011 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
1012 | |
1013 | /// Parse an integer set instance into 'set'. NOTE: despite the `print` prefix in its name, this hook returns a ParseResult. |
1014 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
1015 | |
1016 | //===--------------------------------------------------------------------===// |
1017 | // Identifier Parsing |
1018 | //===--------------------------------------------------------------------===// |
1019 | |
1020 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1021 | /// attribute. |
1022 | ParseResult parseSymbolName(StringAttr &result) { |
1023 | if (failed(parseOptionalSymbolName(result))) |
1024 | return emitError(getCurrentLocation()) |
1025 | << "expected valid '@'-identifier for symbol name"; |
1026 | return success(); |
1027 | } |
1028 | |
1029 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1030 | /// attribute named 'attrName'. |
1031 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1032 | NamedAttrList &attrs) { |
1033 | if (parseSymbolName(result)) |
1034 | return failure(); |
1035 | attrs.append(attrName, result); |
1036 | return success(); |
1037 | } |
1038 | |
1039 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1040 | /// string attribute. |
1041 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; |
1042 | |
1043 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1044 | /// string attribute named 'attrName'. |
1045 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
1046 | NamedAttrList &attrs) { |
1047 | if (succeeded(parseOptionalSymbolName(result))) { |
1048 | attrs.append(attrName, result); |
1049 | return success(); |
1050 | } |
1051 | return failure(); |
1052 | } |
1053 | |
1054 | //===--------------------------------------------------------------------===// |
1055 | // Resource Parsing |
1056 | //===--------------------------------------------------------------------===// |
1057 | |
1058 | /// Parse a handle to a resource within the assembly format. |
1059 | template <typename ResourceT> |
1060 | FailureOr<ResourceT> parseResourceHandle() { |
1061 | SMLoc handleLoc = getCurrentLocation(); |
1062 | |
1063 | // Try to load the dialect that owns the handle. |
1064 | auto *dialect = |
1065 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); |
1066 | if (!dialect) { |
1067 | return emitError(handleLoc) |
1068 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() |
1069 | << "' is unknown"; |
1070 | } |
1071 | |
1072 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); |
1073 | if (failed(handle)) |
1074 | return failure(); |
1075 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
1076 | return std::move(*result); |
1077 | return emitError(handleLoc) << "provided resource handle differs from the " |
1078 | "expected resource type"; |
1079 | } |
1080 | |
1081 | //===--------------------------------------------------------------------===// |
1082 | // Type Parsing |
1083 | //===--------------------------------------------------------------------===// |
1084 | |
1085 | /// Parse a type. |
1086 | virtual ParseResult parseType(Type &result) = 0; |
1087 | |
1088 | /// Parse a custom type with the provided callback, unless the next |
1089 | /// token is `#`, in which case the generic parser is invoked. |
1090 | virtual ParseResult parseCustomTypeWithFallback( |
1091 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
1092 | |
1093 | /// Parse an optional type. |
1094 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
1095 | |
1096 | /// Parse a type of a specific type. |
1097 | template <typename TypeT> |
1098 | ParseResult parseType(TypeT &result) { |
1099 | SMLoc loc = getCurrentLocation(); |
1100 | |
1101 | // Parse any kind of type. |
1102 | Type type; |
1103 | if (parseType(type)) |
1104 | return failure(); |
1105 | |
1106 | // Check for the right kind of type. |
1107 | result = type.dyn_cast<TypeT>(); |
1108 | if (!result) |
1109 | return emitError(loc, "invalid kind of type specified"); |
1110 | |
1111 | return success(); |
1112 | } |
1113 | |
1114 | /// Trait to check if `TypeT` provides a `parse` method. |
1115 | template <typename TypeT> |
1116 | using type_has_parse_method = |
1117 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
1118 | template <typename TypeT> |
1119 | using detect_type_has_parse_method = |
1120 | llvm::is_detected<type_has_parse_method, TypeT>; |
1121 | |
1122 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1123 | /// which case the generic parser is invoked. The parsed Type is |
1124 | /// populated in `result`. |
1125 | template <typename TypeT> |
1126 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1127 | parseCustomTypeWithFallback(TypeT &result) { |
1128 | SMLoc loc = getCurrentLocation(); |
1129 | |
1130 | // Parse any kind of Type. |
1131 | Type type; |
1132 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1133 | result = TypeT::parse(*this); |
1134 | return success(!!result); |
1135 | })) |
1136 | return failure(); |
1137 | |
1138 | // Check for the right kind of Type. |
1139 | result = type.dyn_cast<TypeT>(); |
1140 | if (!result) |
1141 | return emitError(loc, "invalid kind of Type specified"); |
1142 | return success(); |
1143 | } |
1144 | |
1145 | /// SFINAE parsing method for Type that don't implement a parse method. |
1146 | template <typename TypeT> |
1147 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
1148 | parseCustomTypeWithFallback(TypeT &result) { |
1149 | return parseType(result); |
1150 | } |
1151 | |
1152 | /// Parse a type list. |
1153 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
1154 | return parseCommaSeparatedList( |
1155 | [&]() { return parseType(result.emplace_back()); }); |
1156 | } |
1157 | |
1158 | /// Parse an arrow followed by a type list. |
1159 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1160 | |
1161 | /// Parse an optional arrow followed by a type list. |
1162 | virtual ParseResult |
1163 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1164 | |
1165 | /// Parse a colon followed by a type. |
1166 | virtual ParseResult parseColonType(Type &result) = 0; |
1167 | |
1168 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1169 | template <typename TypeType> |
1170 | ParseResult parseColonType(TypeType &result) { |
1171 | SMLoc loc = getCurrentLocation(); |
1172 | |
1173 | // Parse any kind of type. |
1174 | Type type; |
1175 | if (parseColonType(type)) |
1176 | return failure(); |
1177 | |
1178 | // Check for the right kind of type. |
1179 | result = type.dyn_cast<TypeType>(); |
1180 | if (!result) |
1181 | return emitError(loc, "invalid kind of type specified"); |
1182 | |
1183 | return success(); |
1184 | } |
1185 | |
1186 | /// Parse a colon followed by a type list, which must have at least one type. |
1187 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1188 | |
1189 | /// Parse an optional colon followed by a type list, which if present must |
1190 | /// have at least one type. |
1191 | virtual ParseResult |
1192 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1193 | |
1194 | /// Parse a keyword followed by a type. |
1195 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1196 | return failure(parseKeyword(keyword) || parseType(result)); |
1197 | } |
1198 | |
1199 | /// Add the specified type to the end of the specified type list and return |
1200 | /// success. This is a helper designed to allow parse methods to be simple |
1201 | /// and chain through || operators. |
1202 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
1203 | result.push_back(type); |
1204 | return success(); |
1205 | } |
1206 | |
1207 | /// Add the specified types to the end of the specified type list and return |
1208 | /// success. This is a helper designed to allow parse methods to be simple |
1209 | /// and chain through || operators. |
1210 | ParseResult addTypesToList(ArrayRef<Type> types, |
1211 | SmallVectorImpl<Type> &result) { |
1212 | result.append(types.begin(), types.end()); |
1213 | return success(); |
1214 | } |
1215 | |
1216 | /// Parse a dimension list of a tensor or memref type. This populates the |
1217 | /// dimension list, using ShapedType::kDynamic for the `?` dimensions if |
1218 | /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the |
1219 | /// trailing `x` is configurable. |
1220 | /// |
1221 | /// dimension-list ::= eps | dimension (`x` dimension)* |
1222 | /// dimension-list-with-trailing-x ::= (dimension `x`)* |
1223 | /// dimension ::= `?` | decimal-literal |
1224 | /// |
1225 | /// When `allowDynamic` is not set, this is used to parse: |
1226 | /// |
1227 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* |
1228 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* |
1229 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1230 | bool allowDynamic = true, |
1231 | bool withTrailingX = true) = 0; |
1232 | |
1233 | /// Parse an 'x' token in a dimension list, handling the case where the x is |
1234 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
1235 | /// next token. |
1236 | virtual ParseResult parseXInDimensionList() = 0; |
1237 | |
1238 | protected: |
1239 | /// Parse a handle to a resource within the assembly format for the given |
1240 | /// dialect. |
1241 | virtual FailureOr<AsmDialectResourceHandle> |
1242 | parseResourceHandle(Dialect *dialect) = 0; |
1243 | |
1244 | //===--------------------------------------------------------------------===// |
1245 | // Code Completion |
1246 | //===--------------------------------------------------------------------===// |
1247 | |
1248 | /// Parse a keyword, or an empty string if the current location signals a code |
1249 | /// completion. |
1250 | virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; |
1251 | |
1252 | /// Signal the code completion of a set of expected tokens. |
1253 | virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0; |
1254 | |
1255 | private: |
1256 | AsmParser(const AsmParser &) = delete; |
1257 | void operator=(const AsmParser &) = delete; |
1258 | }; |
1259 | |
1260 | //===----------------------------------------------------------------------===// |
1261 | // OpAsmParser |
1262 | //===----------------------------------------------------------------------===// |
1263 | |
1264 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1265 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1266 | /// that is designed to reduce/constrain syntax innovation in individual |
1267 | /// operations. |
1268 | /// |
1269 | /// For example, consider an op like this: |
1270 | /// |
1271 | /// %x = load %p[%1, %2] : memref<...> |
1272 | /// |
1273 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
1274 | /// custom op parser. This can be supported by calling `parseOperandList` to |
1275 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to |
1276 | /// parse the indices, then calling `parseColonTypeList` to parse the result |
1277 | /// type. |
1278 | /// |
1279 | class OpAsmParser : public AsmParser { |
1280 | public: |
1281 | using AsmParser::AsmParser; |
1282 | ~OpAsmParser() override; |
1283 | |
1284 | /// Parse a loc(...) specifier if present, filling in result if so. |
1285 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1286 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1287 | /// completes. |
1288 | virtual ParseResult |
1289 | parseOptionalLocationSpecifier(std::optional<Location> &result) = 0; |
1290 | |
1291 | /// Return the name of the specified result in the specified syntax, as well |
1292 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1293 | /// invalid result numbers. For example, in this operation: |
1294 | /// |
1295 | /// %x, %y:2, %z = foo.op |
1296 | /// |
1297 | /// getResultName(0) == {"x", 0 } |
1298 | /// getResultName(1) == {"y", 0 } |
1299 | /// getResultName(2) == {"y", 1 } |
1300 | /// getResultName(3) == {"z", 0 } |
1301 | /// getResultName(4) == {"", ~0U } |
1302 | virtual std::pair<StringRef, unsigned> |
1303 | getResultName(unsigned resultNo) const = 0; |
1304 | |
1305 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1306 | /// example in the comment for `getResultName`. |
1307 | virtual size_t getNumResults() const = 0; |
1308 | |
1309 | // These methods emit an error and return failure or success. This allows |
1310 | // these to be chained together into a linear sequence of || expressions in |
1311 | // many cases. |
1312 | |
1313 | /// Parse an operation in its generic form. |
1314 | /// The parsed operation is parsed in the current context and inserted in the |
1315 | /// provided block and insertion point. The results produced by this operation |
1316 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1317 | /// failure. |
1318 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1319 | Block::iterator insertPt) = 0; |
1320 | |
1321 | /// Parse the name of an operation, in the custom form. On success, return a |
1322 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1323 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1324 | |
1325 | //===--------------------------------------------------------------------===// |
1326 | // Operand Parsing |
1327 | //===--------------------------------------------------------------------===// |
1328 | |
1329 | /// This is the representation of an operand reference. |
1330 | struct UnresolvedOperand { |
1331 | SMLoc location; // Location of the token. |
1332 | StringRef name; // Value name, e.g. %42 or %abc |
1333 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1334 | }; |
1335 | |
1336 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1337 | /// region(s), attribute(s) and function-type, of the generic form of an |
1338 | /// operation instance and populate the input operation-state 'result' with |
1339 | /// those components. If any of the components is explicitly provided, then |
1340 | /// skip parsing that component. |
1341 | virtual ParseResult parseGenericOperationAfterOpName( |
1342 | OperationState &result, |
1343 | std::optional<ArrayRef<UnresolvedOperand>> parsedOperandType = |
1344 | std::nullopt, |
1345 | std::optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt, |
1346 | std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1347 | std::nullopt, |
1348 | std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt, |
1349 | std::optional<FunctionType> parsedFnType = std::nullopt) = 0; |
1350 | |
1351 | /// Parse a single SSA value operand name along with a result number if |
1352 | /// `allowResultNumber` is true. |
1353 | virtual ParseResult parseOperand(UnresolvedOperand &result, |
1354 | bool allowResultNumber = true) = 0; |
1355 | |
1356 | /// Parse a single operand if present. |
1357 | virtual OptionalParseResult |
1358 | parseOptionalOperand(UnresolvedOperand &result, |
1359 | bool allowResultNumber = true) = 0; |
1360 | |
1361 | /// Parse zero or more SSA comma-separated operand references with a specified |
1362 | /// surrounding delimiter, and an optional required operand count. |
1363 | virtual ParseResult |
1364 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1365 | Delimiter delimiter = Delimiter::None, |
1366 | bool allowResultNumber = true, |
1367 | int requiredOperandCount = -1) = 0; |
1368 | |
1369 | /// Parse a specified number of comma separated operands. |
1370 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1371 | int requiredOperandCount, |
1372 | Delimiter delimiter = Delimiter::None) { |
1373 | return parseOperandList(result, delimiter, |
1374 | /*allowResultNumber=*/true, requiredOperandCount); |
1375 | } |
1376 | |
1377 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1378 | /// references with a specified surrounding delimiter, and an optional |
1379 | /// required operand count. A leading comma is expected before the |
1380 | /// operands. |
1381 | ParseResult |
1382 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1383 | Delimiter delimiter = Delimiter::None) { |
1384 | if (failed(parseOptionalComma())) |
1385 | return success(); // The comma is optional. |
1386 | return parseOperandList(result, delimiter); |
1387 | } |
1388 | |
1389 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1390 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1391 | Type type, |
1392 | SmallVectorImpl<Value> &result) = 0; |
1393 | |
1394 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1395 | /// appending the results to the list on success. This method should be used |
1396 | /// when all operands have the same type. |
1397 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1398 | ParseResult resolveOperands(Operands &&operands, Type type, |
1399 | SmallVectorImpl<Value> &result) { |
1400 | for (const UnresolvedOperand &operand : operands) |
1401 | if (resolveOperand(operand, type, result)) |
1402 | return failure(); |
1403 | return success(); |
1404 | } |
1405 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1406 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1407 | SmallVectorImpl<Value> &result) { |
1408 | return resolveOperands(std::forward<Operands>(operands), type, result); |
1409 | } |
1410 | |
1411 | /// Resolve a list of operands and a list of operand types to SSA values, |
1412 | /// emitting an error and returning failure, or appending the results |
1413 | /// to the list on success. |
1414 | template <typename Operands = ArrayRef<UnresolvedOperand>, |
1415 | typename Types = ArrayRef<Type>> |
1416 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1417 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1418 | SmallVectorImpl<Value> &result) { |
1419 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1420 | size_t typeSize = std::distance(types.begin(), types.end()); |
1421 | if (operandSize != typeSize) |
1422 | return emitError(loc) |
1423 | << operandSize << " operands present, but expected " << typeSize; |
1424 | |
1425 | for (auto [operand, type] : llvm::zip(operands, types)) |
1426 | if (resolveOperand(operand, type, result)) |
1427 | return failure(); |
1428 | return success(); |
1429 | } |
1430 | |
1431 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1432 | /// Operand values must come from single-result sources, and be valid |
1433 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1434 | virtual ParseResult |
1435 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1436 | Attribute &map, StringRef attrName, |
1437 | NamedAttrList &attrs, |
1438 | Delimiter delimiter = Delimiter::Square) = 0; |
1439 | |
1440 | /// Parses an affine expression where dims and symbols are SSA operands. |
1441 | /// Operand values must come from single-result sources, and be valid |
1442 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1443 | virtual ParseResult |
1444 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1445 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1446 | AffineExpr &expr) = 0; |
1447 | |
1448 | //===--------------------------------------------------------------------===// |
1449 | // Argument Parsing |
1450 | //===--------------------------------------------------------------------===// |
1451 | |
1452 | struct Argument { |
1453 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1454 | Type type; // Type. |
1455 | DictionaryAttr attrs; // Attributes if present. |
1456 | std::optional<Location> sourceLoc; // Source location specifier if present. |
1457 | }; |
1458 | |
1459 | /// Parse a single argument with the following syntax: |
1460 | /// |
1461 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` |
1462 | /// |
1463 | /// If `allowType` is false or `allowAttrs` are false then the respective |
1464 | /// parts of the grammar are not parsed. |
1465 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, |
1466 | bool allowAttrs = false) = 0; |
1467 | |
1468 | /// Parse a single argument if present. |
1469 | virtual OptionalParseResult |
1470 | parseOptionalArgument(Argument &result, bool allowType = false, |
1471 | bool allowAttrs = false) = 0; |
1472 | |
1473 | /// Parse zero or more arguments with a specified surrounding delimiter. |
1474 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, |
1475 | Delimiter delimiter = Delimiter::None, |
1476 | bool allowType = false, |
1477 | bool allowAttrs = false) = 0; |
1478 | |
1479 | //===--------------------------------------------------------------------===// |
1480 | // Region Parsing |
1481 | //===--------------------------------------------------------------------===// |
1482 | |
1483 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1484 | /// moved to the op regions after the op is created. The first block of the |
1485 | /// region takes 'arguments'. |
1486 | /// |
1487 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1488 | /// shadow the names of other existing SSA values defined above the region |
1489 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1490 | /// to operations that are 'IsolatedFromAbove'. |
1491 | virtual ParseResult parseRegion(Region ®ion, |
1492 | ArrayRef<Argument> arguments = {}, |
1493 | bool enableNameShadowing = false) = 0; |
1494 | |
1495 | /// Parses a region if present. |
1496 | virtual OptionalParseResult |
1497 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, |
1498 | bool enableNameShadowing = false) = 0; |
1499 | |
1500 | /// Parses a region if present. If the region is present, a new region is |
1501 | /// allocated and placed in `region`. If no region is present or on failure, |
1502 | /// `region` remains untouched. |
1503 | virtual OptionalParseResult |
1504 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1505 | ArrayRef<Argument> arguments = {}, |
1506 | bool enableNameShadowing = false) = 0; |
1507 | |
1508 | //===--------------------------------------------------------------------===// |
1509 | // Successor Parsing |
1510 | //===--------------------------------------------------------------------===// |
1511 | |
1512 | /// Parse a single operation successor. |
1513 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1514 | |
1515 | /// Parse an optional operation successor. |
1516 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1517 | |
1518 | /// Parse a single operation successor and its operand list. |
1519 | virtual ParseResult |
1520 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1521 | |
1522 | //===--------------------------------------------------------------------===// |
1523 | // Type Parsing |
1524 | //===--------------------------------------------------------------------===// |
1525 | |
1526 | /// Parse a list of assignments of the form |
1527 | /// (%x1 = %y1, %x2 = %y2, ...) |
1528 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1529 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1530 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1531 | if (!result.has_value()) |
1532 | return emitError(getCurrentLocation(), "expected '('"); |
1533 | return result.value(); |
1534 | } |
1535 | |
1536 | virtual OptionalParseResult |
1537 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, |
1538 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1539 | }; |
1540 | |
1541 | //===--------------------------------------------------------------------===// |
1542 | // Dialect OpAsm interface. |
1543 | //===--------------------------------------------------------------------===// |
1544 | |
1545 | /// A functor used to set the name of the start of a result group of an |
1546 | /// operation. See 'getAsmResultNames' below for more details. |
1547 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1548 | |
1549 | /// A functor used to set the name of blocks in regions directly nested under |
1550 | /// an operation. |
1551 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1552 | |
1553 | class OpAsmDialectInterface |
1554 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1555 | public: |
1556 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1557 | |
1558 | //===------------------------------------------------------------------===// |
1559 | // Aliases |
1560 | //===------------------------------------------------------------------===// |
1561 | |
1562 | /// Holds the result of `getAlias` hook call. |
1563 | enum class AliasResult { |
1564 | /// The object (type or attribute) is not supported by the hook |
1565 | /// and an alias was not provided. |
1566 | NoAlias, |
1567 | /// An alias was provided, but it might be overriden by other hook. |
1568 | OverridableAlias, |
1569 | /// An alias was provided and it should be used |
1570 | /// (no other hooks will be checked). |
1571 | FinalAlias |
1572 | }; |
1573 | |
1574 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1575 | /// not necessarily a part of this dialect. The identifier is used in place of |
1576 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1577 | /// end with a numeric digit([0-9]+). |
1578 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1579 | return AliasResult::NoAlias; |
1580 | } |
1581 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1582 | return AliasResult::NoAlias; |
1583 | } |
1584 | |
1585 | //===--------------------------------------------------------------------===// |
1586 | // Resources |
1587 | //===--------------------------------------------------------------------===// |
1588 | |
1589 | /// Declare a resource with the given key, returning a handle to use for any |
1590 | /// references of this resource key within the IR during parsing. The result |
1591 | /// of `getResourceKey` on the returned handle is permitted to be different |
1592 | /// than `key`. |
1593 | virtual FailureOr<AsmDialectResourceHandle> |
1594 | declareResource(StringRef key) const { |
1595 | return failure(); |
1596 | } |
1597 | |
1598 | /// Return a key to use for the given resource. This key should uniquely |
1599 | /// identify this resource within the dialect. |
1600 | virtual std::string |
1601 | getResourceKey(const AsmDialectResourceHandle &handle) const { |
1602 | llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1603) |
1603 | "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1603); |
1604 | } |
1605 | |
1606 | /// Hook for parsing resource entries. Returns failure if the entry was not |
1607 | /// valid, or could otherwise not be processed correctly. Any necessary errors |
1608 | /// can be emitted via the provided entry. |
1609 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; |
1610 | |
1611 | /// Hook for building resources to use during printing. The given `op` may be |
1612 | /// inspected to help determine what information to include. |
1613 | /// `referencedResources` contains all of the resources detected when printing |
1614 | /// 'op'. |
1615 | virtual void |
1616 | buildResources(Operation *op, |
1617 | const SetVector<AsmDialectResourceHandle> &referencedResources, |
1618 | AsmResourceBuilder &builder) const {} |
1619 | }; |
1620 | } // namespace mlir |
1621 | |
1622 | //===--------------------------------------------------------------------===// |
1623 | // Operation OpAsm interface. |
1624 | //===--------------------------------------------------------------------===// |
1625 | |
1626 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1627 | #include "mlir/IR/OpAsmInterface.h.inc" |
1628 | |
1629 | namespace llvm { |
1630 | template <> |
1631 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { |
1632 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { |
1633 | return {DenseMapInfo<void *>::getEmptyKey(), |
1634 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; |
1635 | } |
1636 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { |
1637 | return {DenseMapInfo<void *>::getTombstoneKey(), |
1638 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; |
1639 | } |
1640 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { |
1641 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); |
1642 | } |
1643 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, |
1644 | const mlir::AsmDialectResourceHandle &rhs) { |
1645 | return lhs.getResource() == rhs.getResource(); |
1646 | } |
1647 | }; |
1648 | } // namespace llvm |
1649 | |
1650 | #endif |