File: | build/source/mlir/include/mlir/IR/OpImplementation.h |
Warning: | line 839, column 12: 6th function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- TypeParser.h - Quantization Type Parser ------------------*- C++ -*-===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | |||||
9 | #include "mlir/Dialect/Quant/QuantOps.h" | ||||
10 | #include "mlir/Dialect/Quant/QuantTypes.h" | ||||
11 | #include "mlir/IR/BuiltinTypes.h" | ||||
12 | #include "mlir/IR/DialectImplementation.h" | ||||
13 | #include "mlir/IR/Location.h" | ||||
14 | #include "mlir/IR/Types.h" | ||||
15 | #include "llvm/ADT/APFloat.h" | ||||
16 | #include "llvm/Support/Format.h" | ||||
17 | #include "llvm/Support/MathExtras.h" | ||||
18 | #include "llvm/Support/SourceMgr.h" | ||||
19 | #include "llvm/Support/raw_ostream.h" | ||||
20 | |||||
21 | using namespace mlir; | ||||
22 | using namespace quant; | ||||
23 | |||||
24 | static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { | ||||
25 | auto typeLoc = parser.getCurrentLocation(); | ||||
26 | IntegerType type; | ||||
27 | |||||
28 | // Parse storage type (alpha_ident, integer_literal). | ||||
29 | StringRef identifier; | ||||
30 | unsigned storageTypeWidth = 0; | ||||
31 | OptionalParseResult result = parser.parseOptionalType(type); | ||||
32 | if (result.has_value()) { | ||||
33 | if (!succeeded(*result)) | ||||
34 | return nullptr; | ||||
35 | isSigned = !type.isUnsigned(); | ||||
36 | storageTypeWidth = type.getWidth(); | ||||
37 | } else if (succeeded(parser.parseKeyword(&identifier))) { | ||||
38 | // Otherwise, this must be an unsigned integer (`u` integer-literal). | ||||
39 | if (!identifier.consume_front("u")) { | ||||
40 | parser.emitError(typeLoc, "illegal storage type prefix"); | ||||
41 | return nullptr; | ||||
42 | } | ||||
43 | if (identifier.getAsInteger(10, storageTypeWidth)) { | ||||
44 | parser.emitError(typeLoc, "expected storage type width"); | ||||
45 | return nullptr; | ||||
46 | } | ||||
47 | isSigned = false; | ||||
48 | type = parser.getBuilder().getIntegerType(storageTypeWidth); | ||||
49 | } else { | ||||
50 | return nullptr; | ||||
51 | } | ||||
52 | |||||
53 | if (storageTypeWidth == 0 || | ||||
54 | storageTypeWidth > QuantizedType::MaxStorageBits) { | ||||
55 | parser.emitError(typeLoc, "illegal storage type size: ") | ||||
56 | << storageTypeWidth; | ||||
57 | return nullptr; | ||||
58 | } | ||||
59 | |||||
60 | return type; | ||||
61 | } | ||||
62 | |||||
63 | static ParseResult parseStorageRange(DialectAsmParser &parser, | ||||
64 | IntegerType storageType, bool isSigned, | ||||
65 | int64_t &storageTypeMin, | ||||
66 | int64_t &storageTypeMax) { | ||||
67 | int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( | ||||
68 | isSigned, storageType.getWidth()); | ||||
69 | int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( | ||||
70 | isSigned, storageType.getWidth()); | ||||
71 | if (failed(parser.parseOptionalLess())) { | ||||
72 | storageTypeMin = defaultIntegerMin; | ||||
73 | storageTypeMax = defaultIntegerMax; | ||||
74 | return success(); | ||||
75 | } | ||||
76 | |||||
77 | // Explicit storage min and storage max. | ||||
78 | SMLoc minLoc = parser.getCurrentLocation(), maxLoc; | ||||
79 | if (parser.parseInteger(storageTypeMin) || parser.parseColon() || | ||||
80 | parser.getCurrentLocation(&maxLoc) || | ||||
81 | parser.parseInteger(storageTypeMax) || parser.parseGreater()) | ||||
82 | return failure(); | ||||
83 | if (storageTypeMin < defaultIntegerMin) { | ||||
84 | return parser.emitError(minLoc, "illegal storage type minimum: ") | ||||
85 | << storageTypeMin; | ||||
86 | } | ||||
87 | if (storageTypeMax > defaultIntegerMax) { | ||||
88 | return parser.emitError(maxLoc, "illegal storage type maximum: ") | ||||
89 | << storageTypeMax; | ||||
90 | } | ||||
91 | return success(); | ||||
92 | } | ||||
93 | |||||
94 | static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, | ||||
95 | double &min, double &max) { | ||||
96 | auto typeLoc = parser.getCurrentLocation(); | ||||
97 | FloatType type; | ||||
98 | |||||
99 | if (failed(parser.parseType(type))) { | ||||
100 | parser.emitError(typeLoc, "expecting float expressed type"); | ||||
101 | return nullptr; | ||||
102 | } | ||||
103 | |||||
104 | // Calibrated min and max values. | ||||
105 | if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() || | ||||
106 | parser.parseFloat(max) || parser.parseGreater()) { | ||||
107 | parser.emitError(typeLoc, "calibrated values must be present"); | ||||
108 | return nullptr; | ||||
109 | } | ||||
110 | return type; | ||||
111 | } | ||||
112 | |||||
113 | /// Parses an AnyQuantizedType. | ||||
114 | /// | ||||
115 | /// any ::= `any<` storage-spec (expressed-type-spec)?`>` | ||||
116 | /// storage-spec ::= storage-type (`<` storage-range `>`)? | ||||
117 | /// storage-range ::= integer-literal `:` integer-literal | ||||
118 | /// storage-type ::= (`i` | `u`) integer-literal | ||||
119 | /// expressed-type-spec ::= `:` `f` integer-literal | ||||
120 | static Type parseAnyType(DialectAsmParser &parser) { | ||||
121 | IntegerType storageType; | ||||
122 | FloatType expressedType; | ||||
123 | unsigned typeFlags = 0; | ||||
124 | int64_t storageTypeMin; | ||||
125 | int64_t storageTypeMax; | ||||
126 | |||||
127 | // Type specification. | ||||
128 | if (parser.parseLess()) | ||||
129 | return nullptr; | ||||
130 | |||||
131 | // Storage type. | ||||
132 | bool isSigned = false; | ||||
133 | storageType = parseStorageType(parser, isSigned); | ||||
134 | if (!storageType) { | ||||
135 | return nullptr; | ||||
136 | } | ||||
137 | if (isSigned
| ||||
138 | typeFlags |= QuantizationFlags::Signed; | ||||
139 | } | ||||
140 | |||||
141 | // Storage type range. | ||||
142 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, | ||||
143 | storageTypeMax)) { | ||||
144 | return nullptr; | ||||
145 | } | ||||
146 | |||||
147 | // Optional expressed type. | ||||
148 | if (succeeded(parser.parseOptionalColon())) { | ||||
149 | if (parser.parseType(expressedType)) { | ||||
150 | return nullptr; | ||||
151 | } | ||||
152 | } | ||||
153 | |||||
154 | if (parser.parseGreater()) { | ||||
155 | return nullptr; | ||||
156 | } | ||||
157 | |||||
158 | return parser.getChecked<AnyQuantizedType>( | ||||
159 | typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); | ||||
160 | } | ||||
161 | |||||
162 | static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, | ||||
163 | int64_t &zeroPoint) { | ||||
164 | // scale[:zeroPoint]? | ||||
165 | // scale. | ||||
166 | if (parser.parseFloat(scale)) | ||||
167 | return failure(); | ||||
168 | |||||
169 | // zero point. | ||||
170 | zeroPoint = 0; | ||||
171 | if (failed(parser.parseOptionalColon())) { | ||||
172 | // Default zero point. | ||||
173 | return success(); | ||||
174 | } | ||||
175 | |||||
176 | return parser.parseInteger(zeroPoint); | ||||
177 | } | ||||
178 | |||||
179 | /// Parses a UniformQuantizedType. | ||||
180 | /// | ||||
181 | /// uniform_type ::= uniform_per_layer | ||||
182 | /// | uniform_per_axis | ||||
183 | /// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec | ||||
184 | /// `,` scale-zero `>` | ||||
185 | /// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec | ||||
186 | /// axis-spec `,` scale-zero-list `>` | ||||
187 | /// storage-spec ::= storage-type (`<` storage-range `>`)? | ||||
188 | /// storage-range ::= integer-literal `:` integer-literal | ||||
189 | /// storage-type ::= (`i` | `u`) integer-literal | ||||
190 | /// expressed-type-spec ::= `:` `f` integer-literal | ||||
191 | /// axis-spec ::= `:` integer-literal | ||||
192 | /// scale-zero ::= float-literal `:` integer-literal | ||||
193 | /// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}` | ||||
194 | static Type parseUniformType(DialectAsmParser &parser) { | ||||
195 | IntegerType storageType; | ||||
196 | FloatType expressedType; | ||||
197 | unsigned typeFlags = 0; | ||||
198 | int64_t storageTypeMin; | ||||
199 | int64_t storageTypeMax; | ||||
200 | bool isPerAxis = false; | ||||
201 | int32_t quantizedDimension; | ||||
202 | SmallVector<double, 1> scales; | ||||
203 | SmallVector<int64_t, 1> zeroPoints; | ||||
204 | |||||
205 | // Type specification. | ||||
206 | if (parser.parseLess()) { | ||||
207 | return nullptr; | ||||
208 | } | ||||
209 | |||||
210 | // Storage type. | ||||
211 | bool isSigned = false; | ||||
212 | storageType = parseStorageType(parser, isSigned); | ||||
213 | if (!storageType) { | ||||
214 | return nullptr; | ||||
215 | } | ||||
216 | if (isSigned) { | ||||
217 | typeFlags |= QuantizationFlags::Signed; | ||||
218 | } | ||||
219 | |||||
220 | // Storage type range. | ||||
221 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, | ||||
222 | storageTypeMax)) { | ||||
223 | return nullptr; | ||||
224 | } | ||||
225 | |||||
226 | // Expressed type. | ||||
227 | if (parser.parseColon() || parser.parseType(expressedType)) { | ||||
228 | return nullptr; | ||||
229 | } | ||||
230 | |||||
231 | // Optionally parse quantized dimension for per-axis quantization. | ||||
232 | if (succeeded(parser.parseOptionalColon())) { | ||||
233 | if (parser.parseInteger(quantizedDimension)) | ||||
234 | return nullptr; | ||||
235 | isPerAxis = true; | ||||
236 | } | ||||
237 | |||||
238 | // Comma leading into range_spec. | ||||
239 | if (parser.parseComma()) { | ||||
240 | return nullptr; | ||||
241 | } | ||||
242 | |||||
243 | // Parameter specification. | ||||
244 | // For per-axis, ranges are in a {} delimitted list. | ||||
245 | if (isPerAxis) { | ||||
246 | if (parser.parseLBrace()) { | ||||
247 | return nullptr; | ||||
248 | } | ||||
249 | } | ||||
250 | |||||
251 | // Parse scales/zeroPoints. | ||||
252 | SMLoc scaleZPLoc = parser.getCurrentLocation(); | ||||
253 | do { | ||||
254 | scales.resize(scales.size() + 1); | ||||
255 | zeroPoints.resize(zeroPoints.size() + 1); | ||||
256 | if (parseQuantParams(parser, scales.back(), zeroPoints.back())) { | ||||
257 | return nullptr; | ||||
258 | } | ||||
259 | } while (isPerAxis && succeeded(parser.parseOptionalComma())); | ||||
260 | |||||
261 | if (isPerAxis) { | ||||
262 | if (parser.parseRBrace()) { | ||||
263 | return nullptr; | ||||
264 | } | ||||
265 | } | ||||
266 | |||||
267 | if (parser.parseGreater()) { | ||||
268 | return nullptr; | ||||
269 | } | ||||
270 | |||||
271 | if (!isPerAxis && scales.size() > 1) { | ||||
272 | return (parser.emitError(scaleZPLoc, | ||||
273 | "multiple scales/zeroPoints provided, but " | ||||
274 | "quantizedDimension wasn't specified"), | ||||
275 | nullptr); | ||||
276 | } | ||||
277 | |||||
278 | if (isPerAxis) { | ||||
279 | ArrayRef<double> scalesRef(scales.begin(), scales.end()); | ||||
280 | ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); | ||||
281 | return parser.getChecked<UniformQuantizedPerAxisType>( | ||||
282 | typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, | ||||
283 | quantizedDimension, storageTypeMin, storageTypeMax); | ||||
284 | } | ||||
285 | |||||
286 | return parser.getChecked<UniformQuantizedType>( | ||||
287 | typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), | ||||
288 | storageTypeMin, storageTypeMax); | ||||
289 | } | ||||
290 | |||||
291 | /// Parses an CalibratedQuantizedType. | ||||
292 | /// | ||||
293 | /// calibrated ::= `calibrated<` expressed-spec `>` | ||||
294 | /// expressed-spec ::= expressed-type `<` calibrated-range `>` | ||||
295 | /// expressed-type ::= `f` integer-literal | ||||
296 | /// calibrated-range ::= float-literal `:` float-literal | ||||
297 | static Type parseCalibratedType(DialectAsmParser &parser) { | ||||
298 | FloatType expressedType; | ||||
299 | double min; | ||||
300 | double max; | ||||
301 | |||||
302 | // Type specification. | ||||
303 | if (parser.parseLess()) | ||||
304 | return nullptr; | ||||
305 | |||||
306 | // Expressed type. | ||||
307 | expressedType = parseExpressedTypeAndRange(parser, min, max); | ||||
308 | if (!expressedType) { | ||||
309 | return nullptr; | ||||
310 | } | ||||
311 | |||||
312 | if (parser.parseGreater()) { | ||||
313 | return nullptr; | ||||
314 | } | ||||
315 | |||||
316 | return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); | ||||
317 | } | ||||
318 | |||||
319 | /// Parse a type registered to this dialect. | ||||
320 | Type QuantizationDialect::parseType(DialectAsmParser &parser) const { | ||||
321 | // All types start with an identifier that we switch on. | ||||
322 | StringRef typeNameSpelling; | ||||
323 | if (failed(parser.parseKeyword(&typeNameSpelling))) | ||||
| |||||
324 | return nullptr; | ||||
325 | |||||
326 | if (typeNameSpelling == "uniform") | ||||
327 | return parseUniformType(parser); | ||||
328 | if (typeNameSpelling == "any") | ||||
329 | return parseAnyType(parser); | ||||
330 | if (typeNameSpelling == "calibrated") | ||||
331 | return parseCalibratedType(parser); | ||||
332 | |||||
333 | parser.emitError(parser.getNameLoc(), | ||||
334 | "unknown quantized type " + typeNameSpelling); | ||||
335 | return nullptr; | ||||
336 | } | ||||
337 | |||||
338 | static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { | ||||
339 | // storage type | ||||
340 | unsigned storageWidth = type.getStorageTypeIntegralWidth(); | ||||
341 | bool isSigned = type.isSigned(); | ||||
342 | if (isSigned) { | ||||
343 | out << "i" << storageWidth; | ||||
344 | } else { | ||||
345 | out << "u" << storageWidth; | ||||
346 | } | ||||
347 | |||||
348 | // storageTypeMin and storageTypeMax if not default. | ||||
349 | int64_t defaultIntegerMin = | ||||
350 | QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth); | ||||
351 | int64_t defaultIntegerMax = | ||||
352 | QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth); | ||||
353 | if (defaultIntegerMin != type.getStorageTypeMin() || | ||||
354 | defaultIntegerMax != type.getStorageTypeMax()) { | ||||
355 | out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() | ||||
356 | << ">"; | ||||
357 | } | ||||
358 | } | ||||
359 | |||||
360 | static void printQuantParams(double scale, int64_t zeroPoint, | ||||
361 | DialectAsmPrinter &out) { | ||||
362 | out << scale; | ||||
363 | if (zeroPoint != 0) { | ||||
364 | out << ":" << zeroPoint; | ||||
365 | } | ||||
366 | } | ||||
367 | |||||
368 | /// Helper that prints a AnyQuantizedType. | ||||
369 | static void printAnyQuantizedType(AnyQuantizedType type, | ||||
370 | DialectAsmPrinter &out) { | ||||
371 | out << "any<"; | ||||
372 | printStorageType(type, out); | ||||
373 | if (Type expressedType = type.getExpressedType()) { | ||||
374 | out << ":" << expressedType; | ||||
375 | } | ||||
376 | out << ">"; | ||||
377 | } | ||||
378 | |||||
379 | /// Helper that prints a UniformQuantizedType. | ||||
380 | static void printUniformQuantizedType(UniformQuantizedType type, | ||||
381 | DialectAsmPrinter &out) { | ||||
382 | out << "uniform<"; | ||||
383 | printStorageType(type, out); | ||||
384 | out << ":" << type.getExpressedType() << ", "; | ||||
385 | |||||
386 | // scheme specific parameters | ||||
387 | printQuantParams(type.getScale(), type.getZeroPoint(), out); | ||||
388 | out << ">"; | ||||
389 | } | ||||
390 | |||||
391 | /// Helper that prints a UniformQuantizedPerAxisType. | ||||
392 | static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, | ||||
393 | DialectAsmPrinter &out) { | ||||
394 | out << "uniform<"; | ||||
395 | printStorageType(type, out); | ||||
396 | out << ":" << type.getExpressedType() << ":"; | ||||
397 | out << type.getQuantizedDimension(); | ||||
398 | out << ", "; | ||||
399 | |||||
400 | // scheme specific parameters | ||||
401 | ArrayRef<double> scales = type.getScales(); | ||||
402 | ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); | ||||
403 | out << "{"; | ||||
404 | llvm::interleave( | ||||
405 | llvm::seq<size_t>(0, scales.size()), out, | ||||
406 | [&](size_t index) { | ||||
407 | printQuantParams(scales[index], zeroPoints[index], out); | ||||
408 | }, | ||||
409 | ","); | ||||
410 | out << "}>"; | ||||
411 | } | ||||
412 | |||||
413 | /// Helper that prints a CalibratedQuantizedType. | ||||
414 | static void printCalibratedQuantizedType(CalibratedQuantizedType type, | ||||
415 | DialectAsmPrinter &out) { | ||||
416 | out << "calibrated<" << type.getExpressedType(); | ||||
417 | out << "<" << type.getMin() << ":" << type.getMax() << ">"; | ||||
418 | out << ">"; | ||||
419 | } | ||||
420 | |||||
421 | /// Print a type registered to this dialect. | ||||
422 | void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { | ||||
423 | if (auto anyType = type.dyn_cast<AnyQuantizedType>()) | ||||
424 | printAnyQuantizedType(anyType, os); | ||||
425 | else if (auto uniformType = type.dyn_cast<UniformQuantizedType>()) | ||||
426 | printUniformQuantizedType(uniformType, os); | ||||
427 | else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>()) | ||||
428 | printUniformQuantizedPerAxisType(perAxisType, os); | ||||
429 | else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>()) | ||||
430 | printCalibratedQuantizedType(calibratedType, os); | ||||
431 | else | ||||
432 | llvm_unreachable("Unhandled quantized type")::llvm::llvm_unreachable_internal("Unhandled quantized type", "mlir/lib/Dialect/Quant/IR/TypeParser.cpp", 432); | ||||
433 | } |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This classes used by the implementation details of Op types. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H | |||
14 | #define MLIR_IR_OPIMPLEMENTATION_H | |||
15 | ||||
16 | #include "mlir/IR/BuiltinTypes.h" | |||
17 | #include "mlir/IR/DialectInterface.h" | |||
18 | #include "mlir/IR/OpDefinition.h" | |||
19 | #include "llvm/ADT/Twine.h" | |||
20 | #include "llvm/Support/SMLoc.h" | |||
21 | #include <optional> | |||
22 | ||||
23 | namespace mlir { | |||
24 | class AsmParsedResourceEntry; | |||
25 | class AsmResourceBuilder; | |||
26 | class Builder; | |||
27 | ||||
28 | //===----------------------------------------------------------------------===// | |||
29 | // AsmDialectResourceHandle | |||
30 | //===----------------------------------------------------------------------===// | |||
31 | ||||
32 | /// This class represents an opaque handle to a dialect resource entry. | |||
33 | class AsmDialectResourceHandle { | |||
34 | public: | |||
35 | AsmDialectResourceHandle() = default; | |||
36 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) | |||
37 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} | |||
38 | bool operator==(const AsmDialectResourceHandle &other) const { | |||
39 | return resource == other.resource; | |||
40 | } | |||
41 | ||||
42 | /// Return an opaque pointer to the referenced resource. | |||
43 | void *getResource() const { return resource; } | |||
44 | ||||
45 | /// Return the type ID of the resource. | |||
46 | TypeID getTypeID() const { return opaqueID; } | |||
47 | ||||
48 | /// Return the dialect that owns the resource. | |||
49 | Dialect *getDialect() const { return dialect; } | |||
50 | ||||
51 | private: | |||
52 | /// The opaque handle to the dialect resource. | |||
53 | void *resource = nullptr; | |||
54 | /// The type of the resource referenced. | |||
55 | TypeID opaqueID; | |||
56 | /// The dialect owning the given resource. | |||
57 | Dialect *dialect; | |||
58 | }; | |||
59 | ||||
60 | /// This class represents a CRTP base class for dialect resource handles. It | |||
61 | /// abstracts away various utilities necessary for defined derived resource | |||
62 | /// handles. | |||
63 | template <typename DerivedT, typename ResourceT, typename DialectT> | |||
64 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { | |||
65 | public: | |||
66 | using Dialect = DialectT; | |||
67 | ||||
68 | /// Construct a handle from a pointer to the resource. The given pointer | |||
69 | /// should be guaranteed to live beyond the life of this handle. | |||
70 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) | |||
71 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} | |||
72 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) | |||
73 | : AsmDialectResourceHandle(handle) { | |||
74 | assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get< DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()" , "mlir/include/mlir/IR/OpImplementation.h", 74, __extension__ __PRETTY_FUNCTION__)); | |||
75 | } | |||
76 | ||||
77 | /// Return the resource referenced by this handle. | |||
78 | ResourceT *getResource() { | |||
79 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); | |||
80 | } | |||
81 | const ResourceT *getResource() const { | |||
82 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); | |||
83 | } | |||
84 | ||||
85 | /// Return the dialect that owns the resource. | |||
86 | DialectT *getDialect() const { | |||
87 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); | |||
88 | } | |||
89 | ||||
90 | /// Support llvm style casting. | |||
91 | static bool classof(const AsmDialectResourceHandle *handle) { | |||
92 | return handle->getTypeID() == TypeID::get<DerivedT>(); | |||
93 | } | |||
94 | }; | |||
95 | ||||
96 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { | |||
97 | return llvm::hash_value(param.getResource()); | |||
98 | } | |||
99 | ||||
100 | //===----------------------------------------------------------------------===// | |||
101 | // AsmPrinter | |||
102 | //===----------------------------------------------------------------------===// | |||
103 | ||||
104 | /// This base class exposes generic asm printer hooks, usable across the various | |||
105 | /// derived printers. | |||
106 | class AsmPrinter { | |||
107 | public: | |||
108 | /// This class contains the internal default implementation of the base | |||
109 | /// printer methods. | |||
110 | class Impl; | |||
111 | ||||
112 | /// Initialize the printer with the given internal implementation. | |||
113 | AsmPrinter(Impl &impl) : impl(&impl) {} | |||
114 | virtual ~AsmPrinter(); | |||
115 | ||||
116 | /// Return the raw output stream used by this printer. | |||
117 | virtual raw_ostream &getStream() const; | |||
118 | ||||
119 | /// Print the given floating point value in a stabilized form that can be | |||
120 | /// roundtripped through the IR. This is the companion to the 'parseFloat' | |||
121 | /// hook on the AsmParser. | |||
122 | virtual void printFloat(const APFloat &value); | |||
123 | ||||
124 | virtual void printType(Type type); | |||
125 | virtual void printAttribute(Attribute attr); | |||
126 | ||||
127 | /// Trait to check if `AttrType` provides a `print` method. | |||
128 | template <typename AttrOrType> | |||
129 | using has_print_method = | |||
130 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); | |||
131 | template <typename AttrOrType> | |||
132 | using detect_has_print_method = | |||
133 | llvm::is_detected<has_print_method, AttrOrType>; | |||
134 | ||||
135 | /// Print the provided attribute in the context of an operation custom | |||
136 | /// printer/parser: this will invoke directly the print method on the | |||
137 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. | |||
138 | template <typename AttrOrType, | |||
139 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> | |||
140 | *sfinae = nullptr> | |||
141 | void printStrippedAttrOrType(AttrOrType attrOrType) { | |||
142 | if (succeeded(printAlias(attrOrType))) | |||
143 | return; | |||
144 | ||||
145 | raw_ostream &os = getStream(); | |||
146 | uint64_t posPrior = os.tell(); | |||
147 | attrOrType.print(*this); | |||
148 | if (posPrior != os.tell()) | |||
149 | return; | |||
150 | ||||
151 | // Fallback to printing with prefix if the above failed to write anything | |||
152 | // to the output stream. | |||
153 | *this << attrOrType; | |||
154 | } | |||
155 | ||||
156 | /// Print the provided array of attributes or types in the context of an | |||
157 | /// operation custom printer/parser: this will invoke directly the print | |||
158 | /// method on the attribute class and skip the `#dialect.mnemonic` prefix in | |||
159 | /// most cases. | |||
160 | template <typename AttrOrType, | |||
161 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> | |||
162 | *sfinae = nullptr> | |||
163 | void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) { | |||
164 | llvm::interleaveComma( | |||
165 | attrOrTypes, getStream(), | |||
166 | [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); | |||
167 | } | |||
168 | ||||
169 | /// SFINAE for printing the provided attribute in the context of an operation | |||
170 | /// custom printer in the case where the attribute does not define a print | |||
171 | /// method. | |||
172 | template <typename AttrOrType, | |||
173 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> | |||
174 | *sfinae = nullptr> | |||
175 | void printStrippedAttrOrType(AttrOrType attrOrType) { | |||
176 | *this << attrOrType; | |||
177 | } | |||
178 | ||||
179 | /// Print the given attribute without its type. The corresponding parser must | |||
180 | /// provide a valid type for the attribute. | |||
181 | virtual void printAttributeWithoutType(Attribute attr); | |||
182 | ||||
183 | /// Print the given string as a keyword, or a quoted and escaped string if it | |||
184 | /// has any special or non-printable characters in it. | |||
185 | virtual void printKeywordOrString(StringRef keyword); | |||
186 | ||||
187 | /// Print the given string as a symbol reference, i.e. a form representable by | |||
188 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed | |||
189 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any | |||
190 | /// special or non-printable characters in it. | |||
191 | virtual void printSymbolName(StringRef symbolRef); | |||
192 | ||||
193 | /// Print a handle to the given dialect resource. | |||
194 | virtual void printResourceHandle(const AsmDialectResourceHandle &resource); | |||
195 | ||||
196 | /// Print an optional arrow followed by a type list. | |||
197 | template <typename TypeRange> | |||
198 | void printOptionalArrowTypeList(TypeRange &&types) { | |||
199 | if (types.begin() != types.end()) | |||
200 | printArrowTypeList(types); | |||
201 | } | |||
202 | template <typename TypeRange> | |||
203 | void printArrowTypeList(TypeRange &&types) { | |||
204 | auto &os = getStream() << " -> "; | |||
205 | ||||
206 | bool wrapped = !llvm::hasSingleElement(types) || | |||
207 | (*types.begin()).template isa<FunctionType>(); | |||
208 | if (wrapped) | |||
209 | os << '('; | |||
210 | llvm::interleaveComma(types, *this); | |||
211 | if (wrapped) | |||
212 | os << ')'; | |||
213 | } | |||
214 | ||||
215 | /// Print the two given type ranges in a functional form. | |||
216 | template <typename InputRangeT, typename ResultRangeT> | |||
217 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { | |||
218 | auto &os = getStream(); | |||
219 | os << '('; | |||
220 | llvm::interleaveComma(inputs, *this); | |||
221 | os << ')'; | |||
222 | printArrowTypeList(results); | |||
223 | } | |||
224 | ||||
225 | protected: | |||
226 | /// Initialize the printer with no internal implementation. In this case, all | |||
227 | /// virtual methods of this class must be overriden. | |||
228 | AsmPrinter() = default; | |||
229 | ||||
230 | private: | |||
231 | AsmPrinter(const AsmPrinter &) = delete; | |||
232 | void operator=(const AsmPrinter &) = delete; | |||
233 | ||||
234 | /// Print the alias for the given attribute, return failure if no alias could | |||
235 | /// be printed. | |||
236 | virtual LogicalResult printAlias(Attribute attr); | |||
237 | ||||
238 | /// Print the alias for the given type, return failure if no alias could | |||
239 | /// be printed. | |||
240 | virtual LogicalResult printAlias(Type type); | |||
241 | ||||
242 | /// The internal implementation of the printer. | |||
243 | Impl *impl{nullptr}; | |||
244 | }; | |||
245 | ||||
246 | template <typename AsmPrinterT> | |||
247 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
248 | AsmPrinterT &> | |||
249 | operator<<(AsmPrinterT &p, Type type) { | |||
250 | p.printType(type); | |||
251 | return p; | |||
252 | } | |||
253 | ||||
254 | template <typename AsmPrinterT> | |||
255 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
256 | AsmPrinterT &> | |||
257 | operator<<(AsmPrinterT &p, Attribute attr) { | |||
258 | p.printAttribute(attr); | |||
259 | return p; | |||
260 | } | |||
261 | ||||
262 | template <typename AsmPrinterT> | |||
263 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
264 | AsmPrinterT &> | |||
265 | operator<<(AsmPrinterT &p, const APFloat &value) { | |||
266 | p.printFloat(value); | |||
267 | return p; | |||
268 | } | |||
269 | template <typename AsmPrinterT> | |||
270 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
271 | AsmPrinterT &> | |||
272 | operator<<(AsmPrinterT &p, float value) { | |||
273 | return p << APFloat(value); | |||
274 | } | |||
275 | template <typename AsmPrinterT> | |||
276 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
277 | AsmPrinterT &> | |||
278 | operator<<(AsmPrinterT &p, double value) { | |||
279 | return p << APFloat(value); | |||
280 | } | |||
281 | ||||
282 | // Support printing anything that isn't convertible to one of the other | |||
283 | // streamable types, even if it isn't exactly one of them. For example, we want | |||
284 | // to print FunctionType with the Type version above, not have it match this. | |||
285 | template <typename AsmPrinterT, typename T, | |||
286 | std::enable_if_t<!std::is_convertible<T &, Value &>::value && | |||
287 | !std::is_convertible<T &, Type &>::value && | |||
288 | !std::is_convertible<T &, Attribute &>::value && | |||
289 | !std::is_convertible<T &, ValueRange>::value && | |||
290 | !std::is_convertible<T &, APFloat &>::value && | |||
291 | !llvm::is_one_of<T, bool, float, double>::value, | |||
292 | T> * = nullptr> | |||
293 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
294 | AsmPrinterT &> | |||
295 | operator<<(AsmPrinterT &p, const T &other) { | |||
296 | p.getStream() << other; | |||
297 | return p; | |||
298 | } | |||
299 | ||||
300 | template <typename AsmPrinterT> | |||
301 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
302 | AsmPrinterT &> | |||
303 | operator<<(AsmPrinterT &p, bool value) { | |||
304 | return p << (value ? StringRef("true") : "false"); | |||
305 | } | |||
306 | ||||
307 | template <typename AsmPrinterT, typename ValueRangeT> | |||
308 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
309 | AsmPrinterT &> | |||
310 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { | |||
311 | llvm::interleaveComma(types, p); | |||
312 | return p; | |||
313 | } | |||
314 | ||||
315 | template <typename AsmPrinterT> | |||
316 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
317 | AsmPrinterT &> | |||
318 | operator<<(AsmPrinterT &p, const TypeRange &types) { | |||
319 | llvm::interleaveComma(types, p); | |||
320 | return p; | |||
321 | } | |||
322 | ||||
// Prevent matching the TypeRange version above for ValueRange
// printing through base AsmPrinter. This is needed so that the
// ValueRange printing behaviour does not change from printing
// the SSA values to printing the types for the operands when
// using AsmPrinter instead of OpAsmPrinter.
// The overload is deleted, so streaming a ValueRange-convertible value into a
// plain AsmPrinter (is_same, not is_base_of) is a compile-time error rather
// than a silent switch to type printing.
template <typename AsmPrinterT, typename T>
inline std::enable_if_t<std::is_same<AsmPrinter, AsmPrinterT>::value &&
                            std::is_convertible<T &, ValueRange>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const T &other) = delete;
333 | ||||
334 | template <typename AsmPrinterT, typename ElementT> | |||
335 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, | |||
336 | AsmPrinterT &> | |||
337 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { | |||
338 | llvm::interleaveComma(types, p); | |||
339 | return p; | |||
340 | } | |||
341 | ||||
342 | //===----------------------------------------------------------------------===// | |||
343 | // OpAsmPrinter | |||
344 | //===----------------------------------------------------------------------===// | |||
345 | ||||
346 | /// This is a pure-virtual base class that exposes the asmprinter hooks | |||
347 | /// necessary to implement a custom print() method. | |||
348 | class OpAsmPrinter : public AsmPrinter { | |||
349 | public: | |||
350 | using AsmPrinter::AsmPrinter; | |||
351 | ~OpAsmPrinter() override; | |||
352 | ||||
353 | /// Print a loc(...) specifier if printing debug info is enabled. | |||
354 | virtual void printOptionalLocationSpecifier(Location loc) = 0; | |||
355 | ||||
356 | /// Print a newline and indent the printer to the start of the current | |||
357 | /// operation. | |||
358 | virtual void printNewline() = 0; | |||
359 | ||||
360 | /// Increase indentation. | |||
361 | virtual void increaseIndent() = 0; | |||
362 | ||||
363 | /// Decrease indentation. | |||
364 | virtual void decreaseIndent() = 0; | |||
365 | ||||
366 | /// Print a block argument in the usual format of: | |||
367 | /// %ssaName : type {attr1=42} loc("here") | |||
368 | /// where location printing is controlled by the standard internal option. | |||
369 | /// You may pass omitType=true to not print a type, and pass an empty | |||
370 | /// attribute list if you don't care for attributes. | |||
371 | virtual void printRegionArgument(BlockArgument arg, | |||
372 | ArrayRef<NamedAttribute> argAttrs = {}, | |||
373 | bool omitType = false) = 0; | |||
374 | ||||
375 | /// Print implementations for various things an operation contains. | |||
376 | virtual void printOperand(Value value) = 0; | |||
377 | virtual void printOperand(Value value, raw_ostream &os) = 0; | |||
378 | ||||
379 | /// Print a comma separated list of operands. | |||
380 | template <typename ContainerType> | |||
381 | void printOperands(const ContainerType &container) { | |||
382 | printOperands(container.begin(), container.end()); | |||
383 | } | |||
384 | ||||
385 | /// Print a comma separated list of operands. | |||
386 | template <typename IteratorType> | |||
387 | void printOperands(IteratorType it, IteratorType end) { | |||
388 | llvm::interleaveComma(llvm::make_range(it, end), getStream(), | |||
389 | [this](Value value) { printOperand(value); }); | |||
390 | } | |||
391 | ||||
392 | /// Print the given successor. | |||
393 | virtual void printSuccessor(Block *successor) = 0; | |||
394 | ||||
395 | /// Print the successor and its operands. | |||
396 | virtual void printSuccessorAndUseList(Block *successor, | |||
397 | ValueRange succOperands) = 0; | |||
398 | ||||
399 | /// If the specified operation has attributes, print out an attribute | |||
400 | /// dictionary with their values. elidedAttrs allows the client to ignore | |||
401 | /// specific well known attributes, commonly used if the attribute value is | |||
402 | /// printed some other way (like as a fixed operand). | |||
403 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, | |||
404 | ArrayRef<StringRef> elidedAttrs = {}) = 0; | |||
405 | ||||
406 | /// If the specified operation has attributes, print out an attribute | |||
407 | /// dictionary prefixed with 'attributes'. | |||
408 | virtual void | |||
409 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, | |||
410 | ArrayRef<StringRef> elidedAttrs = {}) = 0; | |||
411 | ||||
412 | /// Prints the entire operation with the custom assembly form, if available, | |||
413 | /// or the generic assembly form, otherwise. | |||
414 | virtual void printCustomOrGenericOp(Operation *op) = 0; | |||
415 | ||||
416 | /// Print the entire operation with the default generic assembly form. | |||
417 | /// If `printOpName` is true, then the operation name is printed (the default) | |||
418 | /// otherwise it is omitted and the print will start with the operand list. | |||
419 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; | |||
420 | ||||
421 | /// Prints a region. | |||
422 | /// If 'printEntryBlockArgs' is false, the arguments of the | |||
423 | /// block are not printed. If 'printBlockTerminator' is false, the terminator | |||
424 | /// operation of the block is not printed. If printEmptyBlock is true, then | |||
425 | /// the block header is printed even if the block is empty. | |||
426 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, | |||
427 | bool printBlockTerminators = true, | |||
428 | bool printEmptyBlock = false) = 0; | |||
429 | ||||
430 | /// Renumber the arguments for the specified region to the same names as the | |||
431 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove | |||
432 | /// operations. If any entry in namesToUse is null, the corresponding | |||
433 | /// argument name is left alone. | |||
434 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; | |||
435 | ||||
436 | /// Prints an affine map of SSA ids, where SSA id names are used in place | |||
437 | /// of dims/symbols. | |||
438 | /// Operand values must come from single-result sources, and be valid | |||
439 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. | |||
440 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, | |||
441 | ValueRange operands) = 0; | |||
442 | ||||
443 | /// Prints an affine expression of SSA ids with SSA id names used instead of | |||
444 | /// dims and symbols. | |||
445 | /// Operand values must come from single-result sources, and be valid | |||
446 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. | |||
447 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, | |||
448 | ValueRange symOperands) = 0; | |||
449 | ||||
450 | /// Print the complete type of an operation in functional form. | |||
451 | void printFunctionalType(Operation *op); | |||
452 | using AsmPrinter::printFunctionalType; | |||
453 | }; | |||
454 | ||||
455 | // Make the implementations convenient to use. | |||
456 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { | |||
457 | p.printOperand(value); | |||
458 | return p; | |||
459 | } | |||
460 | ||||
461 | template <typename T, | |||
462 | std::enable_if_t<std::is_convertible<T &, ValueRange>::value && | |||
463 | !std::is_convertible<T &, Value &>::value, | |||
464 | T> * = nullptr> | |||
465 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { | |||
466 | p.printOperands(values); | |||
467 | return p; | |||
468 | } | |||
469 | ||||
470 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { | |||
471 | p.printSuccessor(value); | |||
472 | return p; | |||
473 | } | |||
474 | ||||
475 | //===----------------------------------------------------------------------===// | |||
476 | // AsmParser | |||
477 | //===----------------------------------------------------------------------===// | |||
478 | ||||
479 | /// This base class exposes generic asm parser hooks, usable across the various | |||
480 | /// derived parsers. | |||
481 | class AsmParser { | |||
482 | public: | |||
AsmParser() = default;
/// Virtual destructor (defined out of line); AsmParser is an interface made
/// of pure-virtual hooks that derived parsers implement.
virtual ~AsmParser();

/// Return the MLIRContext associated with this parser (defined out of line).
MLIRContext *getContext() const;

/// Return the location of the original name token.
virtual SMLoc getNameLoc() const = 0;

//===--------------------------------------------------------------------===//
// Utilities
//===--------------------------------------------------------------------===//

/// Emit a diagnostic at the specified location and return failure.
/// The returned InFlightDiagnostic can be streamed into to extend the
/// message before it is converted to a failing ParseResult.
virtual InFlightDiagnostic emitError(SMLoc loc,
                                     const Twine &message = {}) = 0;

/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
virtual Builder &getBuilder() const = 0;
502 | ||||
503 | /// Get the location of the next token and store it into the argument. This | |||
504 | /// always succeeds. | |||
505 | virtual SMLoc getCurrentLocation() = 0; | |||
506 | ParseResult getCurrentLocation(SMLoc *loc) { | |||
507 | *loc = getCurrentLocation(); | |||
508 | return success(); | |||
509 | } | |||
510 | ||||
/// Re-encode the given source location as an MLIR location and return it.
/// Note: This method should only be used when a `Location` is necessary, as
/// the encoding process is not efficient.
virtual Location getEncodedSourceLoc(SMLoc loc) = 0;

//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
// Each `parseX` hook below parses the named punctuation token; the matching
// `parseOptionalX` hook parses it only if it is present. All are implemented
// by derived parsers.

/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;

/// Parse a '->' token if present.
virtual ParseResult parseOptionalArrow() = 0;

/// Parse a `{` token.
virtual ParseResult parseLBrace() = 0;

/// Parse a `{` token if present.
virtual ParseResult parseOptionalLBrace() = 0;

/// Parse a `}` token.
virtual ParseResult parseRBrace() = 0;

/// Parse a `}` token if present.
virtual ParseResult parseOptionalRBrace() = 0;

/// Parse a `:` token.
virtual ParseResult parseColon() = 0;

/// Parse a `:` token if present.
virtual ParseResult parseOptionalColon() = 0;

/// Parse a `,` token.
virtual ParseResult parseComma() = 0;

/// Parse a `,` token if present.
virtual ParseResult parseOptionalComma() = 0;

/// Parse a `=` token.
virtual ParseResult parseEqual() = 0;

/// Parse a `=` token if present.
virtual ParseResult parseOptionalEqual() = 0;

/// Parse a '<' token.
virtual ParseResult parseLess() = 0;

/// Parse a '<' token if present.
virtual ParseResult parseOptionalLess() = 0;

/// Parse a '>' token.
virtual ParseResult parseGreater() = 0;

/// Parse a '>' token if present.
virtual ParseResult parseOptionalGreater() = 0;

/// Parse a '?' token.
virtual ParseResult parseQuestion() = 0;

/// Parse a '?' token if present.
virtual ParseResult parseOptionalQuestion() = 0;

/// Parse a '+' token.
virtual ParseResult parsePlus() = 0;

/// Parse a '+' token if present.
virtual ParseResult parseOptionalPlus() = 0;

/// Parse a '*' token.
virtual ParseResult parseStar() = 0;

/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;

/// Parse a '|' token.
virtual ParseResult parseVerticalBar() = 0;

/// Parse a '|' token if present.
virtual ParseResult parseOptionalVerticalBar() = 0;
591 | ||||
592 | /// Parse a quoted string token. | |||
593 | ParseResult parseString(std::string *string) { | |||
594 | auto loc = getCurrentLocation(); | |||
595 | if (parseOptionalString(string)) | |||
596 | return emitError(loc, "expected string"); | |||
597 | return success(); | |||
598 | } | |||
599 | ||||
/// Parse a quoted string token if present.
virtual ParseResult parseOptionalString(std::string *string) = 0;

/// Parses a Base64 encoded string of bytes.
virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0;

/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;

/// Parse a `(` token if present.
virtual ParseResult parseOptionalLParen() = 0;

/// Parse a `)` token.
virtual ParseResult parseRParen() = 0;

/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;

/// Parse a `[` token.
virtual ParseResult parseLSquare() = 0;

/// Parse a `[` token if present.
virtual ParseResult parseOptionalLSquare() = 0;

/// Parse a `]` token.
virtual ParseResult parseRSquare() = 0;

/// Parse a `]` token if present.
virtual ParseResult parseOptionalRSquare() = 0;

/// Parse a `...` token.
virtual ParseResult parseEllipsis() = 0;

/// Parse a `...` token if present.
virtual ParseResult parseOptionalEllipsis() = 0;

/// Parse a floating point value from the stream into `result`.
virtual ParseResult parseFloat(double &result) = 0;
638 | ||||
639 | /// Parse an integer value from the stream. | |||
640 | template <typename IntT> | |||
641 | ParseResult parseInteger(IntT &result) { | |||
642 | auto loc = getCurrentLocation(); | |||
643 | OptionalParseResult parseResult = parseOptionalInteger(result); | |||
644 | if (!parseResult.has_value()) | |||
645 | return emitError(loc, "expected integer value"); | |||
646 | return *parseResult; | |||
647 | } | |||
648 | ||||
649 | /// Parse an optional integer value from the stream. | |||
650 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; | |||
651 | ||||
652 | template <typename IntT> | |||
653 | OptionalParseResult parseOptionalInteger(IntT &result) { | |||
654 | auto loc = getCurrentLocation(); | |||
655 | ||||
656 | // Parse the unsigned variant. | |||
657 | APInt uintResult; | |||
658 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); | |||
659 | if (!parseResult.has_value() || failed(*parseResult)) | |||
660 | return parseResult; | |||
661 | ||||
662 | // Try to convert to the provided integer type. sextOrTrunc is correct even | |||
663 | // for unsigned types because parseOptionalInteger ensures the sign bit is | |||
664 | // zero for non-negated integers. | |||
665 | result = | |||
666 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); | |||
667 | if (APInt(uintResult.getBitWidth(), result) != uintResult) | |||
668 | return emitError(loc, "integer value too large"); | |||
669 | return success(); | |||
670 | } | |||
671 | ||||
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
/// The `Optional*` variants also accept the complete absence of the
/// delimited list.
enum class Delimiter {
  /// Zero or more operands with no delimiters.
  None,
  /// Parens surrounding zero or more operands.
  Paren,
  /// Square brackets surrounding zero or more operands.
  Square,
  /// <> brackets surrounding zero or more operands.
  LessGreater,
  /// {} brackets surrounding zero or more operands.
  Braces,
  /// Parens supporting zero or more operands, or nothing.
  OptionalParen,
  /// Square brackets supporting zero or more ops, or nothing.
  OptionalSquare,
  /// <> brackets supporting zero or more ops, or nothing.
  OptionalLessGreater,
  /// {} brackets surrounding zero or more operands, or nothing.
  OptionalBraces,
};

/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
///
/// contextMessage is an optional message appended to "expected '('" sorts of
/// diagnostics when parsing the delimiters.
virtual ParseResult
parseCommaSeparatedList(Delimiter delimiter,
                        function_ref<ParseResult()> parseElementFn,
                        StringRef contextMessage = StringRef()) = 0;
705 | ||||
706 | /// Parse a comma separated list of elements that must have at least one entry | |||
707 | /// in it. | |||
708 | ParseResult | |||
709 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { | |||
710 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); | |||
711 | } | |||
712 | ||||
713 | //===--------------------------------------------------------------------===// | |||
714 | // Keyword Parsing | |||
715 | //===--------------------------------------------------------------------===// | |||
716 | ||||
/// This class represents a StringSwitch like class that is useful for parsing
/// expected keywords. On construction, it invokes `parseKeywordOrCompletion`
/// and processes each of the provided cases statements until a match is hit.
/// The provided `ResultT` must be assignable from `failure()`.
///
/// Once a result has been recorded (a matching Case, a Default, or a failure
/// to read a keyword), later Case/Default calls are no-ops. Converting the
/// switch to `ResultT` yields that result; if nothing was recorded it emits
/// "unexpected keyword: <keyword>" at the keyword's location. The conversion
/// moves the stored result out, so perform it only once.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
  KeywordSwitch(AsmParser &parser)
      : parser(parser), loc(parser.getCurrentLocation()) {
    // Latch failure as the result if no keyword could be read, so every
    // subsequent Case/Default is skipped.
    if (failed(parser.parseKeywordOrCompletion(&keyword)))
      result = failure();
  }

  /// Case that uses the provided value when true.
  /// `value` is captured by reference; this is safe because the functor
  /// overload only ever invokes the lambda before returning.
  KeywordSwitch &Case(StringLiteral str, ResultT value) {
    return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Default that uses the provided value if no result was recorded yet.
  KeywordSwitch &Default(ResultT value) {
    return Default([&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Case that invokes the provided functor when true. The parameters passed
  /// to the functor are the keyword, and the location of the keyword (in case
  /// any errors need to be emitted).
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Case(StringLiteral str, FnT &&fn) {
    if (result)
      return *this;

    // If the word was empty, record this as a completion.
    // NOTE(review): presumably parseKeywordOrCompletion leaves `keyword`
    // empty at a code-completion point; each case then offers `str` as an
    // expected token.
    if (keyword.empty())
      parser.codeCompleteExpectedTokens(str);
    else if (keyword == str)
      result.emplace(std::move(fn(keyword, loc)));
    return *this;
  }
  /// Default that invokes the provided functor (with the keyword and its
  /// location) if no result was recorded yet.
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Default(FnT &&fn) {
    if (!result)
      result.emplace(fn(keyword, loc));
    return *this;
  }

  /// Returns true if this switch has a value yet.
  bool hasValue() const { return result.has_value(); }

  /// Return the result of the switch.
  [[nodiscard]] operator ResultT() {
    if (!result)
      return parser.emitError(loc, "unexpected keyword: ") << keyword;
    return std::move(*result);
  }

private:
  /// The parser used to construct this switch.
  AsmParser &parser;

  /// The location of the keyword, used to emit errors as necessary.
  SMLoc loc;

  /// The parsed keyword itself.
  StringRef keyword;

  /// The result of the switch statement or std::nullopt if currently unknown.
  std::optional<ResultT> result;
};
784 | ||||
785 | /// Parse a given keyword. | |||
786 | ParseResult parseKeyword(StringRef keyword) { | |||
787 | return parseKeyword(keyword, ""); | |||
788 | } | |||
789 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; | |||
790 | ||||
791 | /// Parse a keyword into 'keyword'. | |||
792 | ParseResult parseKeyword(StringRef *keyword) { | |||
793 | auto loc = getCurrentLocation(); | |||
794 | if (parseOptionalKeyword(keyword)) | |||
795 | return emitError(loc, "expected valid keyword"); | |||
796 | return success(); | |||
797 | } | |||
798 | ||||
/// Parse the given keyword if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;

/// Parse a keyword, if present, into 'keyword'.
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;

/// Parse a keyword, if present, and if one of the 'allowedValues',
/// into 'keyword'.
virtual ParseResult
parseOptionalKeyword(StringRef *keyword,
                     ArrayRef<StringRef> allowedValues) = 0;
810 | ||||
811 | /// Parse a keyword or a quoted string. | |||
812 | ParseResult parseKeywordOrString(std::string *result) { | |||
813 | if (failed(parseOptionalKeywordOrString(result))) | |||
814 | return emitError(getCurrentLocation()) | |||
815 | << "expected valid keyword or string"; | |||
816 | return success(); | |||
817 | } | |||
818 | ||||
819 | /// Parse an optional keyword or string. | |||
820 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; | |||
821 | ||||
822 | //===--------------------------------------------------------------------===// | |||
823 | // Attribute/Type Parsing | |||
824 | //===--------------------------------------------------------------------===// | |||
825 | ||||
826 | /// Invoke the `getChecked` method of the given Attribute or Type class, using | |||
827 | /// the provided location to emit errors in the case of failure. Note that | |||
828 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a | |||
829 | /// context parameter. | |||
830 | template <typename T, typename... ParamsT> | |||
831 | auto getChecked(SMLoc loc, ParamsT &&...params) { | |||
832 | return T::getChecked([&] { return emitError(loc); }, | |||
833 | std::forward<ParamsT>(params)...); | |||
834 | } | |||
835 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit | |||
836 | /// errors. | |||
837 | template <typename T, typename... ParamsT> | |||
838 | auto getChecked(ParamsT &&...params) { | |||
839 | return T::getChecked([&] { return emitError(getNameLoc()); }, | |||
| ||||
840 | std::forward<ParamsT>(params)...); | |||
841 | } | |||
842 | ||||
843 | //===--------------------------------------------------------------------===// | |||
844 | // Attribute Parsing | |||
845 | //===--------------------------------------------------------------------===// | |||
846 | ||||
/// Parse an arbitrary attribute of a given type and return it in result.
/// A null `type` (the default) means no expected type is supplied.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;

/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
/// `parseAttribute` receives the output attribute and `type`, and returns
/// failure if it could not parse the custom form.
virtual ParseResult parseCustomAttributeWithFallback(
    Attribute &result, Type type,
    function_ref<ParseResult(Attribute &result, Type type)>
        parseAttribute) = 0;
856 | ||||
857 | /// Parse an attribute of a specific kind and type. | |||
858 | template <typename AttrType> | |||
859 | ParseResult parseAttribute(AttrType &result, Type type = {}) { | |||
860 | SMLoc loc = getCurrentLocation(); | |||
861 | ||||
862 | // Parse any kind of attribute. | |||
863 | Attribute attr; | |||
864 | if (parseAttribute(attr, type)) | |||
865 | return failure(); | |||
866 | ||||
867 | // Check for the right kind of attribute. | |||
868 | if (!(result = attr.dyn_cast<AttrType>())) | |||
869 | return emitError(loc, "invalid kind of attribute specified"); | |||
870 | ||||
871 | return success(); | |||
872 | } | |||
873 | ||||
874 | /// Parse an arbitrary attribute and return it in result. This also adds the | |||
875 | /// attribute to the specified attribute list with the specified name. | |||
876 | ParseResult parseAttribute(Attribute &result, StringRef attrName, | |||
877 | NamedAttrList &attrs) { | |||
878 | return parseAttribute(result, Type(), attrName, attrs); | |||
879 | } | |||
880 | ||||
881 | /// Parse an attribute of a specific kind and type. | |||
882 | template <typename AttrType> | |||
883 | ParseResult parseAttribute(AttrType &result, StringRef attrName, | |||
884 | NamedAttrList &attrs) { | |||
885 | return parseAttribute(result, Type(), attrName, attrs); | |||
886 | } | |||
887 | ||||
888 | /// Parse an arbitrary attribute of a given type and populate it in `result`. | |||
889 | /// This also adds the attribute to the specified attribute list with the | |||
890 | /// specified name. | |||
891 | template <typename AttrType> | |||
892 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, | |||
893 | NamedAttrList &attrs) { | |||
894 | SMLoc loc = getCurrentLocation(); | |||
895 | ||||
896 | // Parse any kind of attribute. | |||
897 | Attribute attr; | |||
898 | if (parseAttribute(attr, type)) | |||
899 | return failure(); | |||
900 | ||||
901 | // Check for the right kind of attribute. | |||
902 | result = attr.dyn_cast<AttrType>(); | |||
903 | if (!result) | |||
904 | return emitError(loc, "invalid kind of attribute specified"); | |||
905 | ||||
906 | attrs.append(attrName, result); | |||
907 | return success(); | |||
908 | } | |||
909 | ||||
/// Trait to check if `AttrType` provides a `parse` method.
/// `has_parse_method<AttrType>` is the type of the expression
/// `AttrType::parse(AsmParser &, Type)`; it is ill-formed when no such call
/// is valid, which `detect_has_parse_method` turns into a compile-time
/// boolean (`::value`).
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
                                                  std::declval<Type>()));
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
916 | ||||
917 | /// Parse a custom attribute of a given type unless the next token is `#`, in | |||
918 | /// which case the generic parser is invoked. The parsed attribute is | |||
919 | /// populated in `result` and also added to the specified attribute list with | |||
920 | /// the specified name. | |||
921 | template <typename AttrType> | |||
922 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> | |||
923 | parseCustomAttributeWithFallback(AttrType &result, Type type, | |||
924 | StringRef attrName, NamedAttrList &attrs) { | |||
925 | SMLoc loc = getCurrentLocation(); | |||
926 | ||||
927 | // Parse any kind of attribute. | |||
928 | Attribute attr; | |||
929 | if (parseCustomAttributeWithFallback( | |||
930 | attr, type, [&](Attribute &result, Type type) -> ParseResult { | |||
931 | result = AttrType::parse(*this, type); | |||
932 | if (!result) | |||
933 | return failure(); | |||
934 | return success(); | |||
935 | })) | |||
936 | return failure(); | |||
937 | ||||
938 | // Check for the right kind of attribute. | |||
939 | result = attr.dyn_cast<AttrType>(); | |||
940 | if (!result) | |||
941 | return emitError(loc, "invalid kind of attribute specified"); | |||
942 | ||||
943 | attrs.append(attrName, result); | |||
944 | return success(); | |||
945 | } | |||
946 | ||||
947 | /// SFINAE parsing method for Attribute that don't implement a parse method. | |||
948 | template <typename AttrType> | |||
949 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> | |||
950 | parseCustomAttributeWithFallback(AttrType &result, Type type, | |||
951 | StringRef attrName, NamedAttrList &attrs) { | |||
952 | return parseAttribute(result, type, attrName, attrs); | |||
953 | } | |||
954 | ||||
955 | /// Parse a custom attribute of a given type unless the next token is `#`, in | |||
956 | /// which case the generic parser is invoked. The parsed attribute is | |||
957 | /// populated in `result`. | |||
958 | template <typename AttrType> | |||
959 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> | |||
960 | parseCustomAttributeWithFallback(AttrType &result, Type type = {}) { | |||
961 | SMLoc loc = getCurrentLocation(); | |||
962 | ||||
963 | // Parse any kind of attribute. | |||
964 | Attribute attr; | |||
965 | if (parseCustomAttributeWithFallback( | |||
966 | attr, type, [&](Attribute &result, Type type) -> ParseResult { | |||
967 | result = AttrType::parse(*this, type); | |||
968 | return success(!!result); | |||
969 | })) | |||
970 | return failure(); | |||
971 | ||||
972 | // Check for the right kind of attribute. | |||
973 | result = attr.dyn_cast<AttrType>(); | |||
974 | if (!result) | |||
975 | return emitError(loc, "invalid kind of attribute specified"); | |||
976 | return success(); | |||
977 | } | |||
978 | ||||
979 | /// SFINAE parsing method for Attribute that don't implement a parse method. | |||
980 | template <typename AttrType> | |||
981 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> | |||
982 | parseCustomAttributeWithFallback(AttrType &result, Type type = {}) { | |||
983 | return parseAttribute(result, type); | |||
984 | } | |||
985 | ||||
/// Parse an arbitrary optional attribute of a given type and return it in
/// result. The returned OptionalParseResult is empty when no attribute is
/// present.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
                                                   Type type = {}) = 0;

/// Parse an optional array attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
                                                   Type type = {}) = 0;

/// Parse an optional string attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
                                                   Type type = {}) = 0;

/// Parse an optional symbol ref attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
                                                   Type type = {}) = 0;
1002 | ||||
1003 | /// Parse an optional attribute of a specific type and add it to the list with | |||
1004 | /// the specified name. | |||
1005 | template <typename AttrType> | |||
1006 | OptionalParseResult parseOptionalAttribute(AttrType &result, | |||
1007 | StringRef attrName, | |||
1008 | NamedAttrList &attrs) { | |||
1009 | return parseOptionalAttribute(result, Type(), attrName, attrs); | |||
1010 | } | |||
1011 | ||||
1012 | /// Parse an optional attribute of a specific type and add it to the list with | |||
1013 | /// the specified name. | |||
1014 | template <typename AttrType> | |||
1015 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, | |||
1016 | StringRef attrName, | |||
1017 | NamedAttrList &attrs) { | |||
1018 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); | |||
1019 | if (parseResult.has_value() && succeeded(*parseResult)) | |||
1020 | attrs.append(attrName, result); | |||
1021 | return parseResult; | |||
1022 | } | |||
1023 | ||||
1024 | /// Parse a named dictionary into 'result' if it is present. | |||
1025 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; | |||
1026 | ||||
1027 | /// Parse a named dictionary into 'result' if the `attributes` keyword is | |||
1028 | /// present. | |||
1029 | virtual ParseResult | |||
1030 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; | |||
1031 | ||||
1032 | /// Parse an affine map instance into 'map'. | |||
1033 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; | |||
1034 | ||||
1035 | /// Parse an integer set instance into 'set'. | |||
1036 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; | |||
1037 | ||||
1038 | //===--------------------------------------------------------------------===// | |||
1039 | // Identifier Parsing | |||
1040 | //===--------------------------------------------------------------------===// | |||
1041 | ||||
1042 | /// Parse an @-identifier and store it (without the '@' symbol) in a string | |||
1043 | /// attribute. | |||
1044 | ParseResult parseSymbolName(StringAttr &result) { | |||
1045 | if (failed(parseOptionalSymbolName(result))) | |||
1046 | return emitError(getCurrentLocation()) | |||
1047 | << "expected valid '@'-identifier for symbol name"; | |||
1048 | return success(); | |||
1049 | } | |||
1050 | ||||
1051 | /// Parse an @-identifier and store it (without the '@' symbol) in a string | |||
1052 | /// attribute named 'attrName'. | |||
1053 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, | |||
1054 | NamedAttrList &attrs) { | |||
1055 | if (parseSymbolName(result)) | |||
1056 | return failure(); | |||
1057 | attrs.append(attrName, result); | |||
1058 | return success(); | |||
1059 | } | |||
1060 | ||||
1061 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a | |||
1062 | /// string attribute. | |||
1063 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; | |||
1064 | ||||
1065 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a | |||
1066 | /// string attribute named 'attrName'. | |||
1067 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, | |||
1068 | NamedAttrList &attrs) { | |||
1069 | if (succeeded(parseOptionalSymbolName(result))) { | |||
1070 | attrs.append(attrName, result); | |||
1071 | return success(); | |||
1072 | } | |||
1073 | return failure(); | |||
1074 | } | |||
1075 | ||||
1076 | //===--------------------------------------------------------------------===// | |||
1077 | // Resource Parsing | |||
1078 | //===--------------------------------------------------------------------===// | |||
1079 | ||||
1080 | /// Parse a handle to a resource within the assembly format. | |||
1081 | template <typename ResourceT> | |||
1082 | FailureOr<ResourceT> parseResourceHandle() { | |||
1083 | SMLoc handleLoc = getCurrentLocation(); | |||
1084 | ||||
1085 | // Try to load the dialect that owns the handle. | |||
1086 | auto *dialect = | |||
1087 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); | |||
1088 | if (!dialect) { | |||
1089 | return emitError(handleLoc) | |||
1090 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() | |||
1091 | << "' is unknown"; | |||
1092 | } | |||
1093 | ||||
1094 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); | |||
1095 | if (failed(handle)) | |||
1096 | return failure(); | |||
1097 | if (auto *result = dyn_cast<ResourceT>(&*handle)) | |||
1098 | return std::move(*result); | |||
1099 | return emitError(handleLoc) << "provided resource handle differs from the " | |||
1100 | "expected resource type"; | |||
1101 | } | |||
1102 | ||||
1103 | //===--------------------------------------------------------------------===// | |||
1104 | // Type Parsing | |||
1105 | //===--------------------------------------------------------------------===// | |||
1106 | ||||
1107 | /// Parse a type. | |||
1108 | virtual ParseResult parseType(Type &result) = 0; | |||
1109 | ||||
1110 | /// Parse a custom type with the provided callback, unless the next | |||
1111 | /// token is `#`, in which case the generic parser is invoked. | |||
1112 | virtual ParseResult parseCustomTypeWithFallback( | |||
1113 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; | |||
1114 | ||||
1115 | /// Parse an optional type. | |||
1116 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; | |||
1117 | ||||
1118 | /// Parse a type of a specific type. | |||
1119 | template <typename TypeT> | |||
1120 | ParseResult parseType(TypeT &result) { | |||
1121 | SMLoc loc = getCurrentLocation(); | |||
1122 | ||||
1123 | // Parse any kind of type. | |||
1124 | Type type; | |||
1125 | if (parseType(type)) | |||
1126 | return failure(); | |||
1127 | ||||
1128 | // Check for the right kind of type. | |||
1129 | result = type.dyn_cast<TypeT>(); | |||
1130 | if (!result) | |||
1131 | return emitError(loc, "invalid kind of type specified"); | |||
1132 | ||||
1133 | return success(); | |||
1134 | } | |||
1135 | ||||
1136 | /// Trait to check if `TypeT` provides a `parse` method. | |||
1137 | template <typename TypeT> | |||
1138 | using type_has_parse_method = | |||
1139 | decltype(TypeT::parse(std::declval<AsmParser &>())); | |||
1140 | template <typename TypeT> | |||
1141 | using detect_type_has_parse_method = | |||
1142 | llvm::is_detected<type_has_parse_method, TypeT>; | |||
1143 | ||||
1144 | /// Parse a custom Type of a given type unless the next token is `#`, in | |||
1145 | /// which case the generic parser is invoked. The parsed Type is | |||
1146 | /// populated in `result`. | |||
1147 | template <typename TypeT> | |||
1148 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> | |||
1149 | parseCustomTypeWithFallback(TypeT &result) { | |||
1150 | SMLoc loc = getCurrentLocation(); | |||
1151 | ||||
1152 | // Parse any kind of Type. | |||
1153 | Type type; | |||
1154 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { | |||
1155 | result = TypeT::parse(*this); | |||
1156 | return success(!!result); | |||
1157 | })) | |||
1158 | return failure(); | |||
1159 | ||||
1160 | // Check for the right kind of Type. | |||
1161 | result = type.dyn_cast<TypeT>(); | |||
1162 | if (!result) | |||
1163 | return emitError(loc, "invalid kind of Type specified"); | |||
1164 | return success(); | |||
1165 | } | |||
1166 | ||||
1167 | /// SFINAE parsing method for Type that don't implement a parse method. | |||
1168 | template <typename TypeT> | |||
1169 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> | |||
1170 | parseCustomTypeWithFallback(TypeT &result) { | |||
1171 | return parseType(result); | |||
1172 | } | |||
1173 | ||||
1174 | /// Parse a type list. | |||
1175 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { | |||
1176 | return parseCommaSeparatedList( | |||
1177 | [&]() { return parseType(result.emplace_back()); }); | |||
1178 | } | |||
1179 | ||||
1180 | /// Parse an arrow followed by a type list. | |||
1181 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; | |||
1182 | ||||
1183 | /// Parse an optional arrow followed by a type list. | |||
1184 | virtual ParseResult | |||
1185 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; | |||
1186 | ||||
1187 | /// Parse a colon followed by a type. | |||
1188 | virtual ParseResult parseColonType(Type &result) = 0; | |||
1189 | ||||
1190 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. | |||
1191 | template <typename TypeType> | |||
1192 | ParseResult parseColonType(TypeType &result) { | |||
1193 | SMLoc loc = getCurrentLocation(); | |||
1194 | ||||
1195 | // Parse any kind of type. | |||
1196 | Type type; | |||
1197 | if (parseColonType(type)) | |||
1198 | return failure(); | |||
1199 | ||||
1200 | // Check for the right kind of type. | |||
1201 | result = type.dyn_cast<TypeType>(); | |||
1202 | if (!result) | |||
1203 | return emitError(loc, "invalid kind of type specified"); | |||
1204 | ||||
1205 | return success(); | |||
1206 | } | |||
1207 | ||||
1208 | /// Parse a colon followed by a type list, which must have at least one type. | |||
1209 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; | |||
1210 | ||||
1211 | /// Parse an optional colon followed by a type list, which if present must | |||
1212 | /// have at least one type. | |||
1213 | virtual ParseResult | |||
1214 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; | |||
1215 | ||||
1216 | /// Parse a keyword followed by a type. | |||
1217 | ParseResult parseKeywordType(const char *keyword, Type &result) { | |||
1218 | return failure(parseKeyword(keyword) || parseType(result)); | |||
1219 | } | |||
1220 | ||||
1221 | /// Add the specified type to the end of the specified type list and return | |||
1222 | /// success. This is a helper designed to allow parse methods to be simple | |||
1223 | /// and chain through || operators. | |||
1224 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { | |||
1225 | result.push_back(type); | |||
1226 | return success(); | |||
1227 | } | |||
1228 | ||||
1229 | /// Add the specified types to the end of the specified type list and return | |||
1230 | /// success. This is a helper designed to allow parse methods to be simple | |||
1231 | /// and chain through || operators. | |||
1232 | ParseResult addTypesToList(ArrayRef<Type> types, | |||
1233 | SmallVectorImpl<Type> &result) { | |||
1234 | result.append(types.begin(), types.end()); | |||
1235 | return success(); | |||
1236 | } | |||
1237 | ||||
1238 | /// Parse a dimension list of a tensor or memref type. This populates the | |||
1239 | /// dimension list, using ShapedType::kDynamic for the `?` dimensions if | |||
1240 | /// `allowDynamic` is set and errors out on `?` otherwise. Parsing the | |||
1241 | /// trailing `x` is configurable. | |||
1242 | /// | |||
1243 | /// dimension-list ::= eps | dimension (`x` dimension)* | |||
1244 | /// dimension-list-with-trailing-x ::= (dimension `x`)* | |||
1245 | /// dimension ::= `?` | decimal-literal | |||
1246 | /// | |||
1247 | /// When `allowDynamic` is not set, this is used to parse: | |||
1248 | /// | |||
1249 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* | |||
1250 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* | |||
1251 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, | |||
1252 | bool allowDynamic = true, | |||
1253 | bool withTrailingX = true) = 0; | |||
1254 | ||||
1255 | /// Parse an 'x' token in a dimension list, handling the case where the x is | |||
1256 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the | |||
1257 | /// next token. | |||
1258 | virtual ParseResult parseXInDimensionList() = 0; | |||
1259 | ||||
1260 | protected: | |||
1261 | /// Parse a handle to a resource within the assembly format for the given | |||
1262 | /// dialect. | |||
1263 | virtual FailureOr<AsmDialectResourceHandle> | |||
1264 | parseResourceHandle(Dialect *dialect) = 0; | |||
1265 | ||||
1266 | //===--------------------------------------------------------------------===// | |||
1267 | // Code Completion | |||
1268 | //===--------------------------------------------------------------------===// | |||
1269 | ||||
1270 | /// Parse a keyword, or an empty string if the current location signals a code | |||
1271 | /// completion. | |||
1272 | virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; | |||
1273 | ||||
1274 | /// Signal the code completion of a set of expected tokens. | |||
1275 | virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0; | |||
1276 | ||||
1277 | private: | |||
1278 | AsmParser(const AsmParser &) = delete; | |||
1279 | void operator=(const AsmParser &) = delete; | |||
1280 | }; | |||
1281 | ||||
1282 | //===----------------------------------------------------------------------===// | |||
1283 | // OpAsmParser | |||
1284 | //===----------------------------------------------------------------------===// | |||
1285 | ||||
1286 | /// The OpAsmParser has methods for interacting with the asm parser: parsing | |||
1287 | /// things from it, emitting errors etc. It has an intentionally high-level API | |||
1288 | /// that is designed to reduce/constrain syntax innovation in individual | |||
1289 | /// operations. | |||
1290 | /// | |||
1291 | /// For example, consider an op like this: | |||
1292 | /// | |||
1293 | /// %x = load %p[%1, %2] : memref<...> | |||
1294 | /// | |||
1295 | /// The "%x = load" tokens are already parsed and therefore invisible to the | |||
1296 | /// custom op parser. This can be supported by calling `parseOperandList` to | |||
1297 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to | |||
1298 | /// parse the indices, then calling `parseColonTypeList` to parse the result | |||
1299 | /// type. | |||
1300 | /// | |||
1301 | class OpAsmParser : public AsmParser { | |||
1302 | public: | |||
1303 | using AsmParser::AsmParser; | |||
1304 | ~OpAsmParser() override; | |||
1305 | ||||
1306 | /// Parse a loc(...) specifier if present, filling in result if so. | |||
1307 | /// Location for BlockArgument and Operation may be deferred with an alias, in | |||
1308 | /// which case an OpaqueLoc is set and will be resolved when parsing | |||
1309 | /// completes. | |||
1310 | virtual ParseResult | |||
1311 | parseOptionalLocationSpecifier(std::optional<Location> &result) = 0; | |||
1312 | ||||
1313 | /// Return the name of the specified result in the specified syntax, as well | |||
1314 | /// as the sub-element in the name. It returns an empty string and ~0U for | |||
1315 | /// invalid result numbers. For example, in this operation: | |||
1316 | /// | |||
1317 | /// %x, %y:2, %z = foo.op | |||
1318 | /// | |||
1319 | /// getResultName(0) == {"x", 0 } | |||
1320 | /// getResultName(1) == {"y", 0 } | |||
1321 | /// getResultName(2) == {"y", 1 } | |||
1322 | /// getResultName(3) == {"z", 0 } | |||
1323 | /// getResultName(4) == {"", ~0U } | |||
1324 | virtual std::pair<StringRef, unsigned> | |||
1325 | getResultName(unsigned resultNo) const = 0; | |||
1326 | ||||
1327 | /// Return the number of declared SSA results. This returns 4 for the foo.op | |||
1328 | /// example in the comment for `getResultName`. | |||
1329 | virtual size_t getNumResults() const = 0; | |||
1330 | ||||
1331 | // These methods emit an error and return failure or success. This allows | |||
1332 | // these to be chained together into a linear sequence of || expressions in | |||
1333 | // many cases. | |||
1334 | ||||
1335 | /// Parse an operation in its generic form. | |||
1336 | /// The parsed operation is parsed in the current context and inserted in the | |||
1337 | /// provided block and insertion point. The results produced by this operation | |||
1338 | /// aren't mapped to any named value in the parser. Returns nullptr on | |||
1339 | /// failure. | |||
1340 | virtual Operation *parseGenericOperation(Block *insertBlock, | |||
1341 | Block::iterator insertPt) = 0; | |||
1342 | ||||
1343 | /// Parse the name of an operation, in the custom form. On success, return a | |||
1344 | /// an object of type 'OperationName'. Otherwise, failure is returned. | |||
1345 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; | |||
1346 | ||||
1347 | //===--------------------------------------------------------------------===// | |||
1348 | // Operand Parsing | |||
1349 | //===--------------------------------------------------------------------===// | |||
1350 | ||||
1351 | /// This is the representation of an operand reference. | |||
1352 | struct UnresolvedOperand { | |||
1353 | SMLoc location; // Location of the token. | |||
1354 | StringRef name; // Value name, e.g. %42 or %abc | |||
1355 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 | |||
1356 | }; | |||
1357 | ||||
1358 | /// Parse different components, viz., use-info of operand(s), successor(s), | |||
1359 | /// region(s), attribute(s) and function-type, of the generic form of an | |||
1360 | /// operation instance and populate the input operation-state 'result' with | |||
1361 | /// those components. If any of the components is explicitly provided, then | |||
1362 | /// skip parsing that component. | |||
1363 | virtual ParseResult parseGenericOperationAfterOpName( | |||
1364 | OperationState &result, | |||
1365 | std::optional<ArrayRef<UnresolvedOperand>> parsedOperandType = | |||
1366 | std::nullopt, | |||
1367 | std::optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt, | |||
1368 | std::optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = | |||
1369 | std::nullopt, | |||
1370 | std::optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt, | |||
1371 | std::optional<Attribute> parsedPropertiesAttribute = std::nullopt, | |||
1372 | std::optional<FunctionType> parsedFnType = std::nullopt) = 0; | |||
1373 | ||||
1374 | /// Parse a single SSA value operand name along with a result number if | |||
1375 | /// `allowResultNumber` is true. | |||
1376 | virtual ParseResult parseOperand(UnresolvedOperand &result, | |||
1377 | bool allowResultNumber = true) = 0; | |||
1378 | ||||
1379 | /// Parse a single operand if present. | |||
1380 | virtual OptionalParseResult | |||
1381 | parseOptionalOperand(UnresolvedOperand &result, | |||
1382 | bool allowResultNumber = true) = 0; | |||
1383 | ||||
1384 | /// Parse zero or more SSA comma-separated operand references with a specified | |||
1385 | /// surrounding delimiter, and an optional required operand count. | |||
1386 | virtual ParseResult | |||
1387 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, | |||
1388 | Delimiter delimiter = Delimiter::None, | |||
1389 | bool allowResultNumber = true, | |||
1390 | int requiredOperandCount = -1) = 0; | |||
1391 | ||||
1392 | /// Parse a specified number of comma separated operands. | |||
1393 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, | |||
1394 | int requiredOperandCount, | |||
1395 | Delimiter delimiter = Delimiter::None) { | |||
1396 | return parseOperandList(result, delimiter, | |||
1397 | /*allowResultNumber=*/true, requiredOperandCount); | |||
1398 | } | |||
1399 | ||||
1400 | /// Parse zero or more trailing SSA comma-separated trailing operand | |||
1401 | /// references with a specified surrounding delimiter, and an optional | |||
1402 | /// required operand count. A leading comma is expected before the | |||
1403 | /// operands. | |||
1404 | ParseResult | |||
1405 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, | |||
1406 | Delimiter delimiter = Delimiter::None) { | |||
1407 | if (failed(parseOptionalComma())) | |||
1408 | return success(); // The comma is optional. | |||
1409 | return parseOperandList(result, delimiter); | |||
1410 | } | |||
1411 | ||||
1412 | /// Resolve an operand to an SSA value, emitting an error on failure. | |||
1413 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, | |||
1414 | Type type, | |||
1415 | SmallVectorImpl<Value> &result) = 0; | |||
1416 | ||||
1417 | /// Resolve a list of operands to SSA values, emitting an error on failure, or | |||
1418 | /// appending the results to the list on success. This method should be used | |||
1419 | /// when all operands have the same type. | |||
1420 | template <typename Operands = ArrayRef<UnresolvedOperand>> | |||
1421 | ParseResult resolveOperands(Operands &&operands, Type type, | |||
1422 | SmallVectorImpl<Value> &result) { | |||
1423 | for (const UnresolvedOperand &operand : operands) | |||
1424 | if (resolveOperand(operand, type, result)) | |||
1425 | return failure(); | |||
1426 | return success(); | |||
1427 | } | |||
1428 | template <typename Operands = ArrayRef<UnresolvedOperand>> | |||
1429 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, | |||
1430 | SmallVectorImpl<Value> &result) { | |||
1431 | return resolveOperands(std::forward<Operands>(operands), type, result); | |||
1432 | } | |||
1433 | ||||
1434 | /// Resolve a list of operands and a list of operand types to SSA values, | |||
1435 | /// emitting an error and returning failure, or appending the results | |||
1436 | /// to the list on success. | |||
1437 | template <typename Operands = ArrayRef<UnresolvedOperand>, | |||
1438 | typename Types = ArrayRef<Type>> | |||
1439 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> | |||
1440 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, | |||
1441 | SmallVectorImpl<Value> &result) { | |||
1442 | size_t operandSize = std::distance(operands.begin(), operands.end()); | |||
1443 | size_t typeSize = std::distance(types.begin(), types.end()); | |||
1444 | if (operandSize != typeSize) | |||
1445 | return emitError(loc) | |||
1446 | << operandSize << " operands present, but expected " << typeSize; | |||
1447 | ||||
1448 | for (auto [operand, type] : llvm::zip(operands, types)) | |||
1449 | if (resolveOperand(operand, type, result)) | |||
1450 | return failure(); | |||
1451 | return success(); | |||
1452 | } | |||
1453 | ||||
1454 | /// Parses an affine map attribute where dims and symbols are SSA operands. | |||
1455 | /// Operand values must come from single-result sources, and be valid | |||
1456 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. | |||
1457 | virtual ParseResult | |||
1458 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, | |||
1459 | Attribute &map, StringRef attrName, | |||
1460 | NamedAttrList &attrs, | |||
1461 | Delimiter delimiter = Delimiter::Square) = 0; | |||
1462 | ||||
1463 | /// Parses an affine expression where dims and symbols are SSA operands. | |||
1464 | /// Operand values must come from single-result sources, and be valid | |||
1465 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. | |||
1466 | virtual ParseResult | |||
1467 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, | |||
1468 | SmallVectorImpl<UnresolvedOperand> &symbOperands, | |||
1469 | AffineExpr &expr) = 0; | |||
1470 | ||||
1471 | //===--------------------------------------------------------------------===// | |||
1472 | // Argument Parsing | |||
1473 | //===--------------------------------------------------------------------===// | |||
1474 | ||||
1475 | struct Argument { | |||
1476 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. | |||
1477 | Type type; // Type. | |||
1478 | DictionaryAttr attrs; // Attributes if present. | |||
1479 | std::optional<Location> sourceLoc; // Source location specifier if present. | |||
1480 | }; | |||
1481 | ||||
1482 | /// Parse a single argument with the following syntax: | |||
1483 | /// | |||
1484 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` | |||
1485 | /// | |||
1486 | /// If `allowType` is false or `allowAttrs` are false then the respective | |||
1487 | /// parts of the grammar are not parsed. | |||
1488 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, | |||
1489 | bool allowAttrs = false) = 0; | |||
1490 | ||||
1491 | /// Parse a single argument if present. | |||
1492 | virtual OptionalParseResult | |||
1493 | parseOptionalArgument(Argument &result, bool allowType = false, | |||
1494 | bool allowAttrs = false) = 0; | |||
1495 | ||||
1496 | /// Parse zero or more arguments with a specified surrounding delimiter. | |||
1497 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, | |||
1498 | Delimiter delimiter = Delimiter::None, | |||
1499 | bool allowType = false, | |||
1500 | bool allowAttrs = false) = 0; | |||
1501 | ||||
1502 | //===--------------------------------------------------------------------===// | |||
1503 | // Region Parsing | |||
1504 | //===--------------------------------------------------------------------===// | |||
1505 | ||||
1506 | /// Parses a region. Any parsed blocks are appended to 'region' and must be | |||
1507 | /// moved to the op regions after the op is created. The first block of the | |||
1508 | /// region takes 'arguments'. | |||
1509 | /// | |||
1510 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to | |||
1511 | /// shadow the names of other existing SSA values defined above the region | |||
1512 | /// scope. 'enableNameShadowing' can only be set to true for regions attached | |||
1513 | /// to operations that are 'IsolatedFromAbove'. | |||
1514 | virtual ParseResult parseRegion(Region ®ion, | |||
1515 | ArrayRef<Argument> arguments = {}, | |||
1516 | bool enableNameShadowing = false) = 0; | |||
1517 | ||||
1518 | /// Parses a region if present. | |||
1519 | virtual OptionalParseResult | |||
1520 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, | |||
1521 | bool enableNameShadowing = false) = 0; | |||
1522 | ||||
1523 | /// Parses a region if present. If the region is present, a new region is | |||
1524 | /// allocated and placed in `region`. If no region is present or on failure, | |||
1525 | /// `region` remains untouched. | |||
1526 | virtual OptionalParseResult | |||
1527 | parseOptionalRegion(std::unique_ptr<Region> ®ion, | |||
1528 | ArrayRef<Argument> arguments = {}, | |||
1529 | bool enableNameShadowing = false) = 0; | |||
1530 | ||||
1531 | //===--------------------------------------------------------------------===// | |||
1532 | // Successor Parsing | |||
1533 | //===--------------------------------------------------------------------===// | |||
1534 | ||||
1535 | /// Parse a single operation successor. | |||
1536 | virtual ParseResult parseSuccessor(Block *&dest) = 0; | |||
1537 | ||||
1538 | /// Parse an optional operation successor. | |||
1539 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; | |||
1540 | ||||
1541 | /// Parse a single operation successor and its operand list. | |||
1542 | virtual ParseResult | |||
1543 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; | |||
1544 | ||||
1545 | //===--------------------------------------------------------------------===// | |||
1546 | // Type Parsing | |||
1547 | //===--------------------------------------------------------------------===// | |||
1548 | ||||
1549 | /// Parse a list of assignments of the form | |||
1550 | /// (%x1 = %y1, %x2 = %y2, ...) | |||
1551 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, | |||
1552 | SmallVectorImpl<UnresolvedOperand> &rhs) { | |||
1553 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); | |||
1554 | if (!result.has_value()) | |||
1555 | return emitError(getCurrentLocation(), "expected '('"); | |||
1556 | return result.value(); | |||
1557 | } | |||
1558 | ||||
1559 | virtual OptionalParseResult | |||
1560 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, | |||
1561 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; | |||
1562 | }; | |||
1563 | ||||
1564 | //===--------------------------------------------------------------------===// | |||
1565 | // Dialect OpAsm interface. | |||
1566 | //===--------------------------------------------------------------------===// | |||
1567 | ||||
1568 | /// A functor used to set the name of the start of a result group of an | |||
1569 | /// operation. See 'getAsmResultNames' below for more details. | |||
1570 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; | |||
1571 | ||||
1572 | /// A functor used to set the name of blocks in regions directly nested under | |||
1573 | /// an operation. | |||
1574 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; | |||
1575 | ||||
1576 | class OpAsmDialectInterface | |||
1577 | : public DialectInterface::Base<OpAsmDialectInterface> { | |||
1578 | public: | |||
1579 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} | |||
1580 | ||||
1581 | //===------------------------------------------------------------------===// | |||
1582 | // Aliases | |||
1583 | //===------------------------------------------------------------------===// | |||
1584 | ||||
1585 | /// Holds the result of `getAlias` hook call. | |||
1586 | enum class AliasResult { | |||
1587 | /// The object (type or attribute) is not supported by the hook | |||
1588 | /// and an alias was not provided. | |||
1589 | NoAlias, | |||
1590 | /// An alias was provided, but it might be overriden by other hook. | |||
1591 | OverridableAlias, | |||
1592 | /// An alias was provided and it should be used | |||
1593 | /// (no other hooks will be checked). | |||
1594 | FinalAlias | |||
1595 | }; | |||
1596 | ||||
1597 | /// Hooks for getting an alias identifier alias for a given symbol, that is | |||
1598 | /// not necessarily a part of this dialect. The identifier is used in place of | |||
1599 | /// the symbol when printing textual IR. These aliases must not contain `.` or | |||
1600 | /// end with a numeric digit([0-9]+). | |||
1601 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { | |||
1602 | return AliasResult::NoAlias; | |||
1603 | } | |||
1604 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { | |||
1605 | return AliasResult::NoAlias; | |||
1606 | } | |||
1607 | ||||
1608 | //===--------------------------------------------------------------------===// | |||
1609 | // Resources | |||
1610 | //===--------------------------------------------------------------------===// | |||
1611 | ||||
1612 | /// Declare a resource with the given key, returning a handle to use for any | |||
1613 | /// references of this resource key within the IR during parsing. The result | |||
1614 | /// of `getResourceKey` on the returned handle is permitted to be different | |||
1615 | /// than `key`. | |||
1616 | virtual FailureOr<AsmDialectResourceHandle> | |||
1617 | declareResource(StringRef key) const { | |||
1618 | return failure(); | |||
1619 | } | |||
1620 | ||||
1621 | /// Return a key to use for the given resource. This key should uniquely | |||
1622 | /// identify this resource within the dialect. | |||
1623 | virtual std::string | |||
1624 | getResourceKey(const AsmDialectResourceHandle &handle) const { | |||
1625 | llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1626) | |||
1626 | "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1626); | |||
1627 | } | |||
1628 | ||||
1629 | /// Hook for parsing resource entries. Returns failure if the entry was not | |||
1630 | /// valid, or could otherwise not be processed correctly. Any necessary errors | |||
1631 | /// can be emitted via the provided entry. | |||
1632 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; | |||
1633 | ||||
1634 | /// Hook for building resources to use during printing. The given `op` may be | |||
1635 | /// inspected to help determine what information to include. | |||
1636 | /// `referencedResources` contains all of the resources detected when printing | |||
1637 | /// 'op'. | |||
1638 | virtual void | |||
1639 | buildResources(Operation *op, | |||
1640 | const SetVector<AsmDialectResourceHandle> &referencedResources, | |||
1641 | AsmResourceBuilder &builder) const {} | |||
1642 | }; | |||
1643 | } // namespace mlir | |||
1644 | ||||
1645 | //===--------------------------------------------------------------------===// | |||
1646 | // Operation OpAsm interface. | |||
1647 | //===--------------------------------------------------------------------===// | |||
1648 | ||||
1649 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. | |||
1650 | #include "mlir/IR/OpAsmInterface.h.inc" | |||
1651 | ||||
1652 | namespace llvm { | |||
1653 | template <> | |||
1654 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { | |||
1655 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { | |||
1656 | return {DenseMapInfo<void *>::getEmptyKey(), | |||
1657 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; | |||
1658 | } | |||
1659 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { | |||
1660 | return {DenseMapInfo<void *>::getTombstoneKey(), | |||
1661 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; | |||
1662 | } | |||
1663 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { | |||
1664 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); | |||
1665 | } | |||
1666 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, | |||
1667 | const mlir::AsmDialectResourceHandle &rhs) { | |||
1668 | return lhs.getResource() == rhs.getResource(); | |||
1669 | } | |||
1670 | }; | |||
1671 | } // namespace llvm | |||
1672 | ||||
1673 | #endif |