clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name TypeParser.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-15/lib/clang/15.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Quant/IR -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/Dialect/Quant/IR -I include -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/llvm/include -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-15/lib/clang/15.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-04-20-140412-16051-1 -x c++ /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/Dialect/Quant/IR/TypeParser.cpp
1 | |
2 | |
3 | |
4 | |
5 | |
6 | |
7 | |
8 | |
9 | #include "mlir/Dialect/Quant/QuantOps.h" |
10 | #include "mlir/Dialect/Quant/QuantTypes.h" |
11 | #include "mlir/IR/BuiltinTypes.h" |
12 | #include "mlir/IR/DialectImplementation.h" |
13 | #include "mlir/IR/Location.h" |
14 | #include "mlir/IR/Types.h" |
15 | #include "llvm/ADT/APFloat.h" |
16 | #include "llvm/ADT/StringSwitch.h" |
17 | #include "llvm/Support/Format.h" |
18 | #include "llvm/Support/MathExtras.h" |
19 | #include "llvm/Support/SourceMgr.h" |
20 | #include "llvm/Support/raw_ostream.h" |
21 | |
22 | using namespace mlir; |
23 | using namespace quant; |
24 | |
25 | static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) { |
26 | auto typeLoc = parser.getCurrentLocation(); |
27 | IntegerType type; |
28 | |
29 | |
30 | StringRef identifier; |
31 | unsigned storageTypeWidth = 0; |
32 | OptionalParseResult result = parser.parseOptionalType(type); |
33 | if (result.hasValue()) { |
34 | if (!succeeded(*result)) { |
35 | parser.parseType(type); |
36 | return nullptr; |
37 | } |
38 | isSigned = !type.isUnsigned(); |
39 | storageTypeWidth = type.getWidth(); |
40 | } else if (succeeded(parser.parseKeyword(&identifier))) { |
41 | |
42 | if (!identifier.consume_front("u")) { |
43 | parser.emitError(typeLoc, "illegal storage type prefix"); |
44 | return nullptr; |
45 | } |
46 | if (identifier.getAsInteger(10, storageTypeWidth)) { |
47 | parser.emitError(typeLoc, "expected storage type width"); |
48 | return nullptr; |
49 | } |
50 | isSigned = false; |
51 | type = parser.getBuilder().getIntegerType(storageTypeWidth); |
52 | } else { |
53 | return nullptr; |
54 | } |
55 | |
56 | if (storageTypeWidth == 0 || |
57 | storageTypeWidth > QuantizedType::MaxStorageBits) { |
58 | parser.emitError(typeLoc, "illegal storage type size: ") |
59 | << storageTypeWidth; |
60 | return nullptr; |
61 | } |
62 | |
63 | return type; |
64 | } |
65 | |
66 | static ParseResult parseStorageRange(DialectAsmParser &parser, |
67 | IntegerType storageType, bool isSigned, |
68 | int64_t &storageTypeMin, |
69 | int64_t &storageTypeMax) { |
70 | int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger( |
71 | isSigned, storageType.getWidth()); |
72 | int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger( |
73 | isSigned, storageType.getWidth()); |
74 | if (failed(parser.parseOptionalLess())) { |
| |
75 | storageTypeMin = defaultIntegerMin; |
76 | storageTypeMax = defaultIntegerMax; |
77 | return success(); |
78 | } |
79 | |
80 | |
81 | SMLoc minLoc = parser.getCurrentLocation(), maxLoc; |
82 | if (parser.parseInteger(storageTypeMin) || parser.parseColon() || |
| 12 | | Calling 'AsmParser::parseInteger' | |
|
| 20 | | Returning from 'AsmParser::parseInteger' | |
|
| |
83 | parser.getCurrentLocation(&maxLoc) || |
84 | parser.parseInteger(storageTypeMax) || parser.parseGreater()) |
85 | return failure(); |
86 | if (storageTypeMin < defaultIntegerMin) { |
| 22 | | The left operand of '<' is a garbage value |
|
87 | return parser.emitError(minLoc, "illegal storage type minimum: ") |
88 | << storageTypeMin; |
89 | } |
90 | if (storageTypeMax > defaultIntegerMax) { |
91 | return parser.emitError(maxLoc, "illegal storage type maximum: ") |
92 | << storageTypeMax; |
93 | } |
94 | return success(); |
95 | } |
96 | |
97 | static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser, |
98 | double &min, double &max) { |
99 | auto typeLoc = parser.getCurrentLocation(); |
100 | FloatType type; |
101 | |
102 | if (failed(parser.parseType(type))) { |
103 | parser.emitError(typeLoc, "expecting float expressed type"); |
104 | return nullptr; |
105 | } |
106 | |
107 | |
108 | if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() || |
109 | parser.parseFloat(max) || parser.parseGreater()) { |
110 | parser.emitError(typeLoc, "calibrated values must be present"); |
111 | return nullptr; |
112 | } |
113 | return type; |
114 | } |
115 | |
116 | |
117 | |
118 | |
119 | |
120 | |
121 | |
122 | |
123 | static Type parseAnyType(DialectAsmParser &parser) { |
124 | IntegerType storageType; |
125 | FloatType expressedType; |
126 | unsigned typeFlags = 0; |
127 | int64_t storageTypeMin; |
128 | int64_t storageTypeMax; |
129 | |
130 | |
131 | if (parser.parseLess()) |
132 | return nullptr; |
133 | |
134 | |
135 | bool isSigned = false; |
136 | storageType = parseStorageType(parser, isSigned); |
137 | if (!storageType) { |
138 | return nullptr; |
139 | } |
140 | if (isSigned) { |
141 | typeFlags |= QuantizationFlags::Signed; |
142 | } |
143 | |
144 | |
145 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
146 | storageTypeMax)) { |
147 | return nullptr; |
148 | } |
149 | |
150 | |
151 | if (succeeded(parser.parseOptionalColon())) { |
152 | if (parser.parseType(expressedType)) { |
153 | return nullptr; |
154 | } |
155 | } |
156 | |
157 | if (parser.parseGreater()) { |
158 | return nullptr; |
159 | } |
160 | |
161 | return parser.getChecked<AnyQuantizedType>( |
162 | typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax); |
163 | } |
164 | |
165 | static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale, |
166 | int64_t &zeroPoint) { |
167 | |
168 | |
169 | if (parser.parseFloat(scale)) |
170 | return failure(); |
171 | |
172 | |
173 | zeroPoint = 0; |
174 | if (failed(parser.parseOptionalColon())) { |
175 | |
176 | return success(); |
177 | } |
178 | |
179 | return parser.parseInteger(zeroPoint); |
180 | } |
181 | |
182 | |
183 | |
184 | |
185 | |
186 | |
187 | |
188 | |
189 | |
190 | |
191 | |
192 | |
193 | |
194 | |
195 | |
196 | |
197 | static Type parseUniformType(DialectAsmParser &parser) { |
198 | IntegerType storageType; |
199 | FloatType expressedType; |
200 | unsigned typeFlags = 0; |
201 | int64_t storageTypeMin; |
| 5 | | 'storageTypeMin' declared without an initial value | |
|
202 | int64_t storageTypeMax; |
203 | bool isPerAxis = false; |
204 | int32_t quantizedDimension; |
205 | SmallVector<double, 1> scales; |
206 | SmallVector<int64_t, 1> zeroPoints; |
207 | |
208 | |
209 | if (parser.parseLess()) { |
| |
210 | return nullptr; |
211 | } |
212 | |
213 | |
214 | bool isSigned = false; |
215 | storageType = parseStorageType(parser, isSigned); |
216 | if (!storageType) { |
| |
217 | return nullptr; |
218 | } |
219 | if (isSigned) { |
| |
220 | typeFlags |= QuantizationFlags::Signed; |
221 | } |
222 | |
223 | |
224 | if (parseStorageRange(parser, storageType, isSigned, storageTypeMin, |
| 9 | | Passing value via 4th parameter 'storageTypeMin' | |
|
| 10 | | Calling 'parseStorageRange' | |
|
225 | storageTypeMax)) { |
226 | return nullptr; |
227 | } |
228 | |
229 | |
230 | if (parser.parseColon() || parser.parseType(expressedType)) { |
231 | return nullptr; |
232 | } |
233 | |
234 | |
235 | if (succeeded(parser.parseOptionalColon())) { |
236 | if (parser.parseInteger(quantizedDimension)) |
237 | return nullptr; |
238 | isPerAxis = true; |
239 | } |
240 | |
241 | |
242 | if (parser.parseComma()) { |
243 | return nullptr; |
244 | } |
245 | |
246 | |
247 | |
248 | if (isPerAxis) { |
249 | if (parser.parseLBrace()) { |
250 | return nullptr; |
251 | } |
252 | } |
253 | |
254 | |
255 | SMLoc scaleZPLoc = parser.getCurrentLocation(); |
256 | do { |
257 | scales.resize(scales.size() + 1); |
258 | zeroPoints.resize(zeroPoints.size() + 1); |
259 | if (parseQuantParams(parser, scales.back(), zeroPoints.back())) { |
260 | return nullptr; |
261 | } |
262 | } while (isPerAxis && succeeded(parser.parseOptionalComma())); |
263 | |
264 | if (isPerAxis) { |
265 | if (parser.parseRBrace()) { |
266 | return nullptr; |
267 | } |
268 | } |
269 | |
270 | if (parser.parseGreater()) { |
271 | return nullptr; |
272 | } |
273 | |
274 | if (!isPerAxis && scales.size() > 1) { |
275 | return (parser.emitError(scaleZPLoc, |
276 | "multiple scales/zeroPoints provided, but " |
277 | "quantizedDimension wasn't specified"), |
278 | nullptr); |
279 | } |
280 | |
281 | if (isPerAxis) { |
282 | ArrayRef<double> scalesRef(scales.begin(), scales.end()); |
283 | ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end()); |
284 | return parser.getChecked<UniformQuantizedPerAxisType>( |
285 | typeFlags, storageType, expressedType, scalesRef, zeroPointsRef, |
286 | quantizedDimension, storageTypeMin, storageTypeMax); |
287 | } |
288 | |
289 | return parser.getChecked<UniformQuantizedType>( |
290 | typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(), |
291 | storageTypeMin, storageTypeMax); |
292 | } |
293 | |
294 | |
295 | |
296 | |
297 | |
298 | |
299 | |
300 | static Type parseCalibratedType(DialectAsmParser &parser) { |
301 | FloatType expressedType; |
302 | double min; |
303 | double max; |
304 | |
305 | |
306 | if (parser.parseLess()) |
307 | return nullptr; |
308 | |
309 | |
310 | expressedType = parseExpressedTypeAndRange(parser, min, max); |
311 | if (!expressedType) { |
312 | return nullptr; |
313 | } |
314 | |
315 | if (parser.parseGreater()) { |
316 | return nullptr; |
317 | } |
318 | |
319 | return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max); |
320 | } |
321 | |
322 | |
323 | Type QuantizationDialect::parseType(DialectAsmParser &parser) const { |
324 | |
325 | StringRef typeNameSpelling; |
326 | if (failed(parser.parseKeyword(&typeNameSpelling))) |
| |
327 | return nullptr; |
328 | |
329 | if (typeNameSpelling == "uniform") |
| 2 | | Assuming the condition is true | |
|
| |
330 | return parseUniformType(parser); |
| 4 | | Calling 'parseUniformType' | |
|
331 | if (typeNameSpelling == "any") |
332 | return parseAnyType(parser); |
333 | if (typeNameSpelling == "calibrated") |
334 | return parseCalibratedType(parser); |
335 | |
336 | parser.emitError(parser.getNameLoc(), |
337 | "unknown quantized type " + typeNameSpelling); |
338 | return nullptr; |
339 | } |
340 | |
341 | static void printStorageType(QuantizedType type, DialectAsmPrinter &out) { |
342 | |
343 | unsigned storageWidth = type.getStorageTypeIntegralWidth(); |
344 | bool isSigned = type.isSigned(); |
345 | if (isSigned) { |
346 | out << "i" << storageWidth; |
347 | } else { |
348 | out << "u" << storageWidth; |
349 | } |
350 | |
351 | |
352 | int64_t defaultIntegerMin = |
353 | QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth); |
354 | int64_t defaultIntegerMax = |
355 | QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth); |
356 | if (defaultIntegerMin != type.getStorageTypeMin() || |
357 | defaultIntegerMax != type.getStorageTypeMax()) { |
358 | out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax() |
359 | << ">"; |
360 | } |
361 | } |
362 | |
363 | static void printQuantParams(double scale, int64_t zeroPoint, |
364 | DialectAsmPrinter &out) { |
365 | out << scale; |
366 | if (zeroPoint != 0) { |
367 | out << ":" << zeroPoint; |
368 | } |
369 | } |
370 | |
371 | |
372 | static void printAnyQuantizedType(AnyQuantizedType type, |
373 | DialectAsmPrinter &out) { |
374 | out << "any<"; |
375 | printStorageType(type, out); |
376 | if (Type expressedType = type.getExpressedType()) { |
377 | out << ":" << expressedType; |
378 | } |
379 | out << ">"; |
380 | } |
381 | |
382 | |
383 | static void printUniformQuantizedType(UniformQuantizedType type, |
384 | DialectAsmPrinter &out) { |
385 | out << "uniform<"; |
386 | printStorageType(type, out); |
387 | out << ":" << type.getExpressedType() << ", "; |
388 | |
389 | |
390 | printQuantParams(type.getScale(), type.getZeroPoint(), out); |
391 | out << ">"; |
392 | } |
393 | |
394 | |
395 | static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type, |
396 | DialectAsmPrinter &out) { |
397 | out << "uniform<"; |
398 | printStorageType(type, out); |
399 | out << ":" << type.getExpressedType() << ":"; |
400 | out << type.getQuantizedDimension(); |
401 | out << ", "; |
402 | |
403 | |
404 | ArrayRef<double> scales = type.getScales(); |
405 | ArrayRef<int64_t> zeroPoints = type.getZeroPoints(); |
406 | out << "{"; |
407 | llvm::interleave( |
408 | llvm::seq<size_t>(0, scales.size()), out, |
409 | [&](size_t index) { |
410 | printQuantParams(scales[index], zeroPoints[index], out); |
411 | }, |
412 | ","); |
413 | out << "}>"; |
414 | } |
415 | |
416 | |
417 | static void printCalibratedQuantizedType(CalibratedQuantizedType type, |
418 | DialectAsmPrinter &out) { |
419 | out << "calibrated<" << type.getExpressedType(); |
420 | out << "<" << type.getMin() << ":" << type.getMax() << ">"; |
421 | out << ">"; |
422 | } |
423 | |
424 | |
425 | void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const { |
426 | if (auto anyType = type.dyn_cast<AnyQuantizedType>()) |
427 | printAnyQuantizedType(anyType, os); |
428 | else if (auto uniformType = type.dyn_cast<UniformQuantizedType>()) |
429 | printUniformQuantizedType(uniformType, os); |
430 | else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>()) |
431 | printUniformQuantizedPerAxisType(perAxisType, os); |
432 | else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>()) |
433 | printCalibratedQuantizedType(calibratedType, os); |
434 | else |
435 | llvm_unreachable("Unhandled quantized type"); |
436 | } |
1 | |
2 | |
3 | |
4 | |
5 | |
6 | |
7 | |
8 | |
9 | |
10 | |
11 | |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | |
24 | class Builder; |
25 | |
26 | |
27 | |
28 | |
29 | |
30 | |
31 | |
32 | class AsmPrinter { |
33 | public: |
34 | |
35 | |
36 | class Impl; |
37 | |
38 | |
39 | AsmPrinter(Impl &impl) : impl(&impl) {} |
40 | virtual ~AsmPrinter(); |
41 | |
42 | |
43 | virtual raw_ostream &getStream() const; |
44 | |
45 | |
46 | |
47 | |
48 | virtual void printFloat(const APFloat &value); |
49 | |
50 | virtual void printType(Type type); |
51 | virtual void printAttribute(Attribute attr); |
52 | |
53 | |
54 | template <typename AttrOrType> |
55 | using has_print_method = |
56 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
57 | template <typename AttrOrType> |
58 | using detect_has_print_method = |
59 | llvm::is_detected<has_print_method, AttrOrType>; |
60 | |
61 | |
62 | |
63 | |
64 | template <typename AttrOrType, |
65 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
66 | *sfinae = nullptr> |
67 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
68 | if (succeeded(printAlias(attrOrType))) |
69 | return; |
70 | attrOrType.print(*this); |
71 | } |
72 | |
73 | |
74 | |
75 | |
76 | template <typename AttrOrType, |
77 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
78 | *sfinae = nullptr> |
79 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
80 | *this << attrOrType; |
81 | } |
82 | |
83 | |
84 | |
85 | virtual void printAttributeWithoutType(Attribute attr); |
86 | |
87 | |
88 | |
89 | virtual void printKeywordOrString(StringRef keyword); |
90 | |
91 | |
92 | |
93 | |
94 | |
95 | virtual void printSymbolName(StringRef symbolRef); |
96 | |
97 | |
98 | template <typename TypeRange> |
99 | void printOptionalArrowTypeList(TypeRange &&types) { |
100 | if (types.begin() != types.end()) |
101 | printArrowTypeList(types); |
102 | } |
103 | template <typename TypeRange> |
104 | void printArrowTypeList(TypeRange &&types) { |
105 | auto &os = getStream() << " -> "; |
106 | |
107 | bool wrapped = !llvm::hasSingleElement(types) || |
108 | (*types.begin()).template isa<FunctionType>(); |
109 | if (wrapped) |
110 | os << '('; |
111 | llvm::interleaveComma(types, *this); |
112 | if (wrapped) |
113 | os << ')'; |
114 | } |
115 | |
116 | |
117 | template <typename InputRangeT, typename ResultRangeT> |
118 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
119 | auto &os = getStream(); |
120 | os << '('; |
121 | llvm::interleaveComma(inputs, *this); |
122 | os << ')'; |
123 | printArrowTypeList(results); |
124 | } |
125 | |
126 | protected: |
127 | |
128 | |
129 | AsmPrinter() {} |
130 | |
131 | private: |
132 | AsmPrinter(const AsmPrinter &) = delete; |
133 | void operator=(const AsmPrinter &) = delete; |
134 | |
135 | |
136 | |
137 | virtual LogicalResult printAlias(Attribute attr); |
138 | |
139 | |
140 | |
141 | virtual LogicalResult printAlias(Type type); |
142 | |
143 | |
144 | Impl *impl{nullptr}; |
145 | }; |
146 | |
147 | template <typename AsmPrinterT> |
148 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
149 | AsmPrinterT &> |
150 | operator<<(AsmPrinterT &p, Type type) { |
151 | p.printType(type); |
152 | return p; |
153 | } |
154 | |
155 | template <typename AsmPrinterT> |
156 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
157 | AsmPrinterT &> |
158 | operator<<(AsmPrinterT &p, Attribute attr) { |
159 | p.printAttribute(attr); |
160 | return p; |
161 | } |
162 | |
163 | template <typename AsmPrinterT> |
164 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
165 | AsmPrinterT &> |
166 | operator<<(AsmPrinterT &p, const APFloat &value) { |
167 | p.printFloat(value); |
168 | return p; |
169 | } |
170 | template <typename AsmPrinterT> |
171 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
172 | AsmPrinterT &> |
173 | operator<<(AsmPrinterT &p, float value) { |
174 | return p << APFloat(value); |
175 | } |
176 | template <typename AsmPrinterT> |
177 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
178 | AsmPrinterT &> |
179 | operator<<(AsmPrinterT &p, double value) { |
180 | return p << APFloat(value); |
181 | } |
182 | |
183 | |
184 | |
185 | |
186 | template < |
187 | typename AsmPrinterT, typename T, |
188 | typename std::enable_if<!std::is_convertible<T &, Value &>::value && |
189 | !std::is_convertible<T &, Type &>::value && |
190 | !std::is_convertible<T &, Attribute &>::value && |
191 | !std::is_convertible<T &, ValueRange>::value && |
192 | !std::is_convertible<T &, APFloat &>::value && |
193 | !llvm::is_one_of<T, bool, float, double>::value, |
194 | T>::type * = nullptr> |
195 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
196 | AsmPrinterT &> |
197 | operator<<(AsmPrinterT &p, const T &other) { |
198 | p.getStream() << other; |
199 | return p; |
200 | } |
201 | |
202 | template <typename AsmPrinterT> |
203 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
204 | AsmPrinterT &> |
205 | operator<<(AsmPrinterT &p, bool value) { |
206 | return p << (value ? StringRef("true") : "false"); |
207 | } |
208 | |
209 | template <typename AsmPrinterT, typename ValueRangeT> |
210 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
211 | AsmPrinterT &> |
212 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
213 | llvm::interleaveComma(types, p); |
214 | return p; |
215 | } |
216 | template <typename AsmPrinterT> |
217 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
218 | AsmPrinterT &> |
219 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
220 | llvm::interleaveComma(types, p); |
221 | return p; |
222 | } |
223 | template <typename AsmPrinterT, typename ElementT> |
224 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
225 | AsmPrinterT &> |
226 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
227 | llvm::interleaveComma(types, p); |
228 | return p; |
229 | } |
230 | |
231 | |
232 | |
233 | |
234 | |
235 | |
236 | |
237 | class OpAsmPrinter : public AsmPrinter { |
238 | public: |
239 | using AsmPrinter::AsmPrinter; |
240 | ~OpAsmPrinter() override; |
241 | |
242 | |
243 | |
244 | virtual void printNewline() = 0; |
245 | |
246 | |
247 | |
248 | |
249 | |
250 | |
251 | virtual void printRegionArgument(BlockArgument arg, |
252 | ArrayRef<NamedAttribute> argAttrs = {}, |
253 | bool omitType = false) = 0; |
254 | |
255 | |
256 | virtual void printOperand(Value value) = 0; |
257 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
258 | |
259 | |
260 | template <typename ContainerType> |
261 | void printOperands(const ContainerType &container) { |
262 | printOperands(container.begin(), container.end()); |
263 | } |
264 | |
265 | |
266 | template <typename IteratorType> |
267 | void printOperands(IteratorType it, IteratorType end) { |
268 | if (it == end) |
269 | return; |
270 | printOperand(*it); |
271 | for (++it; it != end; ++it) { |
272 | getStream() << ", "; |
273 | printOperand(*it); |
274 | } |
275 | } |
276 | |
277 | |
278 | virtual void printSuccessor(Block *successor) = 0; |
279 | |
280 | |
281 | virtual void printSuccessorAndUseList(Block *successor, |
282 | ValueRange succOperands) = 0; |
283 | |
284 | |
285 | |
286 | |
287 | |
288 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
289 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
290 | |
291 | |
292 | |
293 | virtual void |
294 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
295 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
296 | |
297 | |
298 | |
299 | |
300 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
301 | |
302 | |
303 | |
304 | |
305 | |
306 | |
307 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
308 | bool printBlockTerminators = true, |
309 | bool printEmptyBlock = false) = 0; |
310 | |
311 | |
312 | |
313 | |
314 | |
315 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
316 | |
317 | |
318 | |
319 | |
320 | |
321 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
322 | ValueRange operands) = 0; |
323 | |
324 | |
325 | |
326 | |
327 | |
328 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
329 | ValueRange symOperands) = 0; |
330 | |
331 | |
332 | void printFunctionalType(Operation *op); |
333 | using AsmPrinter::printFunctionalType; |
334 | }; |
335 | |
336 | |
337 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
338 | p.printOperand(value); |
339 | return p; |
340 | } |
341 | |
342 | template <typename T, |
343 | typename std::enable_if<std::is_convertible<T &, ValueRange>::value && |
344 | !std::is_convertible<T &, Value &>::value, |
345 | T>::type * = nullptr> |
346 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
347 | p.printOperands(values); |
348 | return p; |
349 | } |
350 | |
351 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
352 | p.printSuccessor(value); |
353 | return p; |
354 | } |
355 | |
356 | |
357 | |
358 | |
359 | |
360 | |
361 | |
362 | class AsmParser { |
363 | public: |
364 | AsmParser() = default; |
365 | virtual ~AsmParser(); |
366 | |
367 | MLIRContext *getContext() const; |
368 | |
369 | |
370 | virtual SMLoc getNameLoc() const = 0; |
371 | |
372 | |
373 | |
374 | |
375 | |
376 | |
377 | virtual InFlightDiagnostic emitError(SMLoc loc, |
378 | const Twine &message = {}) = 0; |
379 | |
380 | |
381 | |
382 | virtual Builder &getBuilder() const = 0; |
383 | |
384 | |
385 | |
386 | virtual SMLoc getCurrentLocation() = 0; |
387 | ParseResult getCurrentLocation(SMLoc *loc) { |
388 | *loc = getCurrentLocation(); |
389 | return success(); |
390 | } |
391 | |
392 | |
393 | |
394 | |
395 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
396 | |
397 | |
398 | |
399 | |
400 | |
401 | |
402 | virtual ParseResult parseArrow() = 0; |
403 | |
404 | |
405 | virtual ParseResult parseOptionalArrow() = 0; |
406 | |
407 | |
408 | virtual ParseResult parseLBrace() = 0; |
409 | |
410 | |
411 | virtual ParseResult parseOptionalLBrace() = 0; |
412 | |
413 | |
414 | virtual ParseResult parseRBrace() = 0; |
415 | |
416 | |
417 | virtual ParseResult parseOptionalRBrace() = 0; |
418 | |
419 | |
420 | virtual ParseResult parseColon() = 0; |
421 | |
422 | |
423 | virtual ParseResult parseOptionalColon() = 0; |
424 | |
425 | |
426 | virtual ParseResult parseComma() = 0; |
427 | |
428 | |
429 | virtual ParseResult parseOptionalComma() = 0; |
430 | |
431 | |
432 | virtual ParseResult parseEqual() = 0; |
433 | |
434 | |
435 | virtual ParseResult parseOptionalEqual() = 0; |
436 | |
437 | |
438 | virtual ParseResult parseLess() = 0; |
439 | |
440 | |
441 | virtual ParseResult parseOptionalLess() = 0; |
442 | |
443 | |
444 | virtual ParseResult parseGreater() = 0; |
445 | |
446 | |
447 | virtual ParseResult parseOptionalGreater() = 0; |
448 | |
449 | |
450 | virtual ParseResult parseQuestion() = 0; |
451 | |
452 | |
453 | virtual ParseResult parseOptionalQuestion() = 0; |
454 | |
455 | |
456 | virtual ParseResult parsePlus() = 0; |
457 | |
458 | |
459 | virtual ParseResult parseOptionalPlus() = 0; |
460 | |
461 | |
462 | virtual ParseResult parseStar() = 0; |
463 | |
464 | |
465 | virtual ParseResult parseOptionalStar() = 0; |
466 | |
467 | |
468 | ParseResult parseString(std::string *string) { |
469 | auto loc = getCurrentLocation(); |
470 | if (parseOptionalString(string)) |
471 | return emitError(loc, "expected string"); |
472 | return success(); |
473 | } |
474 | |
475 | |
476 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
477 | |
478 | |
479 | ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { |
480 | auto loc = getCurrentLocation(); |
481 | if (parseOptionalKeyword(keyword)) |
482 | return emitError(loc, "expected '") << keyword << "'" << msg; |
483 | return success(); |
484 | } |
485 | |
486 | |
487 | ParseResult parseKeyword(StringRef *keyword) { |
488 | auto loc = getCurrentLocation(); |
489 | if (parseOptionalKeyword(keyword)) |
490 | return emitError(loc, "expected valid keyword"); |
491 | return success(); |
492 | } |
493 | |
494 | |
495 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
496 | |
497 | |
498 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
499 | |
500 | |
501 | |
502 | virtual ParseResult |
503 | parseOptionalKeyword(StringRef *keyword, |
504 | ArrayRef<StringRef> allowedValues) = 0; |
505 | |
506 | |
507 | ParseResult parseKeywordOrString(std::string *result) { |
508 | if (failed(parseOptionalKeywordOrString(result))) |
509 | return emitError(getCurrentLocation()) |
510 | << "expected valid keyword or string"; |
511 | return success(); |
512 | } |
513 | |
514 | |
515 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
516 | |
517 | |
518 | virtual ParseResult parseLParen() = 0; |
519 | |
520 | |
521 | virtual ParseResult parseOptionalLParen() = 0; |
522 | |
523 | |
524 | virtual ParseResult parseRParen() = 0; |
525 | |
526 | |
527 | virtual ParseResult parseOptionalRParen() = 0; |
528 | |
529 | |
530 | virtual ParseResult parseLSquare() = 0; |
531 | |
532 | |
533 | virtual ParseResult parseOptionalLSquare() = 0; |
534 | |
535 | |
536 | virtual ParseResult parseRSquare() = 0; |
537 | |
538 | |
539 | virtual ParseResult parseOptionalRSquare() = 0; |
540 | |
541 | |
542 | virtual ParseResult parseOptionalEllipsis() = 0; |
543 | |
544 | |
545 | virtual ParseResult parseFloat(double &result) = 0; |
546 | |
547 | |
548 | template <typename IntT> |
549 | ParseResult parseInteger(IntT &result) { |
550 | auto loc = getCurrentLocation(); |
551 | OptionalParseResult parseResult = parseOptionalInteger(result); |
| 13 | | Calling 'AsmParser::parseOptionalInteger' | |
|
| 17 | | Returning from 'AsmParser::parseOptionalInteger' | |
|
552 | if (!parseResult.hasValue()) |
| |
553 | return emitError(loc, "expected integer value"); |
554 | return *parseResult; |
| 19 | | Returning without writing to 'result' | |
|
555 | } |
556 | |
557 | |
558 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
559 | |
560 | template <typename IntT> |
561 | OptionalParseResult parseOptionalInteger(IntT &result) { |
562 | auto loc = getCurrentLocation(); |
563 | |
564 | |
565 | APInt uintResult; |
566 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
567 | if (!parseResult.hasValue() || failed(*parseResult)) |
| 14 | | Assuming the condition is false | |
|
| |
568 | return parseResult; |
| 16 | | Returning without writing to 'result' | |
|
569 | |
570 | |
571 | |
572 | |
573 | result = |
574 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue(); |
575 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
576 | return emitError(loc, "integer value too large"); |
577 | return success(); |
578 | } |
579 | |
580 | |
581 | |
/// Delimiter kinds accepted by parseCommaSeparatedList and by the operand /
/// region-argument list parsers. The single-argument
/// parseCommaSeparatedList overload passes `None`.
/// NOTE(review): by name, Paren/Square/LessGreater/Braces presumably select
/// `()`, `[]`, `<>` and `{}`, and the `Optional*` forms presumably allow the
/// delimiters to be absent -- confirm against the parser implementation.
582 | enum class Delimiter { |
583 | |
584 | None, |
585 | |
586 | Paren, |
587 | |
588 | Square, |
589 | |
590 | LessGreater, |
591 | |
592 | Braces, |
593 | |
594 | OptionalParen, |
595 | |
596 | OptionalSquare, |
597 | |
598 | OptionalLessGreater, |
599 | |
600 | OptionalBraces, |
601 | }; |
602 | |
603 | |
604 | |
605 | |
606 | |
607 | |
608 | |
/// Pure-virtual list-parsing hook: takes the delimiter kind, a callback
/// returning ParseResult, and an optional `contextMessage` (empty by
/// default).
/// NOTE(review): presumably `parseElementFn` is invoked once per list element
/// and `contextMessage` is appended to diagnostics -- confirm.
609 | virtual ParseResult |
610 | parseCommaSeparatedList(Delimiter delimiter, |
611 | function_ref<ParseResult()> parseElementFn, |
612 | StringRef contextMessage = StringRef()) = 0; |
613 | |
614 | |
615 | |
616 | ParseResult |
617 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
618 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); |
619 | } |
620 | |
621 | |
622 | |
623 | |
624 | |
625 | |
626 | |
627 | |
628 | |
629 | template <typename T, typename... ParamsT> |
630 | T getChecked(SMLoc loc, ParamsT &&... params) { |
631 | return T::getChecked([&] { return emitError(loc); }, |
632 | std::forward<ParamsT>(params)...); |
633 | } |
634 | |
635 | |
636 | template <typename T, typename... ParamsT> |
637 | T getChecked(ParamsT &&... params) { |
638 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
639 | std::forward<ParamsT>(params)...); |
640 | } |
641 | |
642 | |
643 | |
644 | |
645 | |
646 | |
/// Pure-virtual attribute hook; `type` defaults to a value-initialized Type.
/// The templated parseAttribute overloads call this and then dyn_cast the
/// result to the requested attribute class.
647 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
648 | |
649 | |
650 | |
/// Pure-virtual hook that receives a callback with signature
/// `ParseResult(Attribute &, Type)`. The templated overloads pass a callback
/// that invokes `AttrType::parse`.
/// NOTE(review): when the callback versus a generic fallback is used is
/// decided by the implementation -- not visible here.
651 | virtual ParseResult parseCustomAttributeWithFallback( |
652 | Attribute &result, Type type, |
653 | function_ref<ParseResult(Attribute &result, Type type)> |
654 | parseAttribute) = 0; |
655 | |
656 | |
657 | template <typename AttrType> |
658 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
659 | SMLoc loc = getCurrentLocation(); |
660 | |
661 | |
662 | Attribute attr; |
663 | if (parseAttribute(attr, type)) |
664 | return failure(); |
665 | |
666 | |
667 | if (!(result = attr.dyn_cast<AttrType>())) |
668 | return emitError(loc, "invalid kind of attribute specified"); |
669 | |
670 | return success(); |
671 | } |
672 | |
673 | |
674 | |
675 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
676 | NamedAttrList &attrs) { |
677 | return parseAttribute(result, Type(), attrName, attrs); |
678 | } |
679 | |
680 | |
681 | template <typename AttrType> |
682 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
683 | NamedAttrList &attrs) { |
684 | return parseAttribute(result, Type(), attrName, attrs); |
685 | } |
686 | |
687 | |
688 | |
689 | |
690 | template <typename AttrType> |
691 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
692 | NamedAttrList &attrs) { |
693 | SMLoc loc = getCurrentLocation(); |
694 | |
695 | |
696 | Attribute attr; |
697 | if (parseAttribute(attr, type)) |
698 | return failure(); |
699 | |
700 | |
701 | result = attr.dyn_cast<AttrType>(); |
702 | if (!result) |
703 | return emitError(loc, "invalid kind of attribute specified"); |
704 | |
705 | attrs.append(attrName, result); |
706 | return success(); |
707 | } |
708 | |
709 | |
/// Detection idiom: `has_parse_method<AttrType>` is well-formed only when the
/// expression `AttrType::parse(AsmParser &, Type)` is well-formed.
710 | template <typename AttrType> |
711 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
712 | std::declval<Type>())); |
/// `detect_has_parse_method<AttrType>::value` is used below to pick between
/// the custom-parse and the generic parseAttribute overloads via enable_if.
713 | template <typename AttrType> |
714 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
715 | |
716 | |
717 | |
718 | |
719 | |
720 | template <typename AttrType> |
721 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
722 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
723 | StringRef attrName, NamedAttrList &attrs) { |
724 | SMLoc loc = getCurrentLocation(); |
725 | |
726 | |
727 | Attribute attr; |
728 | if (parseCustomAttributeWithFallback( |
729 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
730 | result = AttrType::parse(*this, type); |
731 | if (!result) |
732 | return failure(); |
733 | return success(); |
734 | })) |
735 | return failure(); |
736 | |
737 | |
738 | result = attr.dyn_cast<AttrType>(); |
739 | if (!result) |
740 | return emitError(loc, "invalid kind of attribute specified"); |
741 | |
742 | attrs.append(attrName, result); |
743 | return success(); |
744 | } |
745 | |
746 | |
747 | template <typename AttrType> |
748 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
749 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
750 | StringRef attrName, NamedAttrList &attrs) { |
751 | return parseAttribute(result, type, attrName, attrs); |
752 | } |
753 | |
754 | |
755 | |
756 | |
757 | template <typename AttrType> |
758 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
759 | parseCustomAttributeWithFallback(AttrType &result) { |
760 | SMLoc loc = getCurrentLocation(); |
761 | |
762 | |
763 | Attribute attr; |
764 | if (parseCustomAttributeWithFallback( |
765 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
766 | result = AttrType::parse(*this, type); |
767 | return success(!!result); |
768 | })) |
769 | return failure(); |
770 | |
771 | |
772 | result = attr.dyn_cast<AttrType>(); |
773 | if (!result) |
774 | return emitError(loc, "invalid kind of attribute specified"); |
775 | return success(); |
776 | } |
777 | |
778 | |
779 | template <typename AttrType> |
780 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
781 | parseCustomAttributeWithFallback(AttrType &result) { |
782 | return parseAttribute(result); |
783 | } |
784 | |
785 | |
786 | |
/// Pure-virtual optional-attribute hooks, one each for a generic Attribute,
/// an ArrayAttr and a StringAttr. `type` defaults to a value-initialized
/// Type. The templated overloads below call these and, when a value is
/// present and succeeded, append it to a NamedAttrList.
787 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
788 | Type type = {}) = 0; |
789 | |
790 | |
791 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
792 | Type type = {}) = 0; |
793 | |
794 | |
795 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
796 | Type type = {}) = 0; |
797 | |
798 | |
799 | |
800 | template <typename AttrType> |
801 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
802 | StringRef attrName, |
803 | NamedAttrList &attrs) { |
804 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
805 | } |
806 | |
807 | |
808 | |
809 | template <typename AttrType> |
810 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
811 | StringRef attrName, |
812 | NamedAttrList &attrs) { |
813 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
814 | if (parseResult.hasValue() && succeeded(*parseResult)) |
815 | attrs.append(attrName, result); |
816 | return parseResult; |
817 | } |
818 | |
819 | |
/// Pure-virtual hooks that fill a NamedAttrList.
/// NOTE(review): presumably both parse an optional attribute dictionary, the
/// second one introduced by a keyword -- confirm against the implementation.
820 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
821 | |
822 | |
823 | |
824 | virtual ParseResult |
825 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
826 | |
827 | |
/// Pure-virtual hook that writes into the AffineMap out-param.
828 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
829 | |
830 | |
/// Pure-virtual hook that writes into the IntegerSet out-param.
/// NOTE(review): the `print` prefix looks like a misnomer for a method on a
/// parser that returns ParseResult (cf. parseAffineMap above). Renaming
/// would break every overrider, so it is only flagged here.
831 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
832 | |
833 | |
834 | |
835 | |
836 | |
837 | |
838 | |
839 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
840 | NamedAttrList &attrs) { |
841 | if (failed(parseOptionalSymbolName(result, attrName, attrs))) |
842 | return emitError(getCurrentLocation()) |
843 | << "expected valid '@'-identifier for symbol name"; |
844 | return success(); |
845 | } |
846 | |
847 | |
848 | |
/// Pure-virtual hook used by parseSymbolName, which reports
/// "expected valid '@'-identifier for symbol name" when this returns failure.
/// NOTE(review): presumably stores the symbol in `result` and appends it to
/// `attrs` under `attrName` -- confirm against the implementation.
849 | virtual ParseResult parseOptionalSymbolName(StringAttr &result, |
850 | StringRef attrName, |
851 | NamedAttrList &attrs) = 0; |
852 | |
853 | |
854 | |
855 | |
856 | |
857 | |
/// Pure-virtual generic type hook. The templated parseType and parseTypeList
/// helpers are built on it.
858 | virtual ParseResult parseType(Type &result) = 0; |
859 | |
860 | |
861 | |
/// Pure-virtual hook that receives a callback with signature
/// `ParseResult(Type &)`. The templated overload passes a callback that
/// invokes `TypeT::parse`.
/// NOTE(review): when the callback versus a generic fallback is used is
/// decided by the implementation -- not visible here.
862 | virtual ParseResult parseCustomTypeWithFallback( |
863 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
864 | |
865 | |
/// Pure-virtual optional type hook returning an OptionalParseResult.
866 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
867 | |
868 | |
869 | template <typename TypeT> |
870 | ParseResult parseType(TypeT &result) { |
871 | SMLoc loc = getCurrentLocation(); |
872 | |
873 | |
874 | Type type; |
875 | if (parseType(type)) |
876 | return failure(); |
877 | |
878 | |
879 | result = type.dyn_cast<TypeT>(); |
880 | if (!result) |
881 | return emitError(loc, "invalid kind of type specified"); |
882 | |
883 | return success(); |
884 | } |
885 | |
886 | |
/// Detection idiom: `type_has_parse_method<TypeT>` is well-formed only when
/// the expression `TypeT::parse(AsmParser &)` is well-formed.
887 | template <typename TypeT> |
888 | using type_has_parse_method = |
889 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
/// `detect_type_has_parse_method<TypeT>::value` is used below to pick between
/// the custom-parse and the generic parseType overloads via enable_if.
890 | template <typename TypeT> |
891 | using detect_type_has_parse_method = |
892 | llvm::is_detected<type_has_parse_method, TypeT>; |
893 | |
894 | |
895 | |
896 | |
897 | template <typename TypeT> |
898 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
899 | parseCustomTypeWithFallback(TypeT &result) { |
900 | SMLoc loc = getCurrentLocation(); |
901 | |
902 | |
903 | Type type; |
904 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
905 | result = TypeT::parse(*this); |
906 | return success(!!result); |
907 | })) |
908 | return failure(); |
909 | |
910 | |
911 | result = type.dyn_cast<TypeT>(); |
912 | if (!result) |
913 | return emitError(loc, "invalid kind of Type specified"); |
914 | return success(); |
915 | } |
916 | |
917 | |
918 | template <typename TypeT> |
919 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
920 | parseCustomTypeWithFallback(TypeT &result) { |
921 | return parseType(result); |
922 | } |
923 | |
924 | |
925 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
926 | do { |
927 | Type type; |
928 | if (parseType(type)) |
929 | return failure(); |
930 | result.push_back(type); |
931 | } while (succeeded(parseOptionalComma())); |
932 | return success(); |
933 | } |
934 | |
935 | |
/// Pure-virtual hooks that append parsed types to `result`.
/// NOTE(review): presumably they parse an arrow followed by a type list, the
/// second form allowing the arrow to be absent -- confirm.
936 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
937 | |
938 | |
939 | virtual ParseResult |
940 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
941 | |
942 | |
/// Pure-virtual hook used by the templated parseColonType, which dyn_casts
/// its result to a specific type class.
943 | virtual ParseResult parseColonType(Type &result) = 0; |
944 | |
945 | |
946 | template <typename TypeType> |
947 | ParseResult parseColonType(TypeType &result) { |
948 | SMLoc loc = getCurrentLocation(); |
949 | |
950 | |
951 | Type type; |
952 | if (parseColonType(type)) |
953 | return failure(); |
954 | |
955 | |
956 | result = type.dyn_cast<TypeType>(); |
957 | if (!result) |
958 | return emitError(loc, "invalid kind of type specified"); |
959 | |
960 | return success(); |
961 | } |
962 | |
963 | |
/// Pure-virtual hooks that append parsed types to `result`.
/// NOTE(review): presumably they parse a colon followed by a type list, the
/// second form allowing the colon to be absent -- confirm.
964 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
965 | |
966 | |
967 | |
968 | virtual ParseResult |
969 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
970 | |
971 | |
972 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
973 | return failure(parseKeyword(keyword) || parseType(result)); |
974 | } |
975 | |
976 | |
977 | |
978 | |
979 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
980 | result.push_back(type); |
981 | return success(); |
982 | } |
983 | |
984 | |
985 | |
986 | |
987 | ParseResult addTypesToList(ArrayRef<Type> types, |
988 | SmallVectorImpl<Type> &result) { |
989 | result.append(types.begin(), types.end()); |
990 | return success(); |
991 | } |
992 | |
993 | |
994 | |
995 | |
996 | |
997 | |
998 | |
999 | |
1000 | |
1001 | |
1002 | |
/// Pure-virtual hook that fills `dimensions` with int64_t values;
/// `allowDynamic` defaults to true.
/// NOTE(review): how a dynamic dimension is encoded in `dimensions` is not
/// visible here -- confirm against the implementation.
1003 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1004 | bool allowDynamic = true) = 0; |
1005 | |
1006 | |
1007 | |
1008 | |
/// Pure-virtual hook taking no arguments.
/// NOTE(review): presumably consumes the `x` separator inside a dimension
/// list -- confirm.
1009 | virtual ParseResult parseXInDimensionList() = 0; |
1010 | |
1011 | private: |
/// AsmParser is non-copyable: copy construction and copy assignment are
/// deleted.
1012 | AsmParser(const AsmParser &) = delete; |
1013 | void operator=(const AsmParser &) = delete; |
1014 | }; |
1015 | |
1016 | |
1017 | |
1018 | |
1019 | |
1020 | |
1021 | |
1022 | |
1023 | |
1024 | |
1025 | |
1026 | |
1027 | |
1028 | |
1029 | |
1030 | |
1031 | |
1032 | |
1033 | |
1034 | |
1035 | class OpAsmParser : public AsmParser { |
1036 | public: |
1037 | using AsmParser::AsmParser; |
1038 | ~OpAsmParser() override; |
1039 | |
1040 | |
1041 | |
1042 | |
1043 | |
1044 | virtual ParseResult |
1045 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1046 | |
1047 | |
1048 | |
1049 | |
1050 | |
1051 | |
1052 | |
1053 | |
1054 | |
1055 | |
1056 | |
1057 | |
1058 | virtual std::pair<StringRef, unsigned> |
1059 | getResultName(unsigned resultNo) const = 0; |
1060 | |
1061 | |
1062 | |
1063 | virtual size_t getNumResults() const = 0; |
1064 | |
1065 | |
1066 | |
1067 | |
1068 | |
1069 | |
1070 | |
1071 | |
1072 | |
1073 | |
1074 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1075 | Block::iterator insertPt) = 0; |
1076 | |
1077 | |
1078 | |
1079 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1080 | |
1081 | |
1082 | |
1083 | |
1084 | |
1085 | |
1086 | struct UnresolvedOperand { |
1087 | SMLoc location; |
1088 | StringRef name; |
1089 | unsigned number; |
1090 | }; |
1091 | |
1092 | |
1093 | |
1094 | |
1095 | |
1096 | |
1097 | virtual ParseResult parseGenericOperationAfterOpName( |
1098 | OperationState &result, |
1099 | Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None, |
1100 | Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None, |
1101 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1102 | llvm::None, |
1103 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None, |
1104 | Optional<FunctionType> parsedFnType = llvm::None) = 0; |
1105 | |
1106 | |
1107 | virtual ParseResult parseOperand(UnresolvedOperand &result) = 0; |
1108 | |
1109 | |
1110 | virtual OptionalParseResult |
1111 | parseOptionalOperand(UnresolvedOperand &result) = 0; |
1112 | |
1113 | |
1114 | |
1115 | virtual ParseResult |
1116 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1117 | int requiredOperandCount = -1, |
1118 | Delimiter delimiter = Delimiter::None) = 0; |
1119 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1120 | Delimiter delimiter) { |
1121 | return parseOperandList(result, -1, delimiter); |
1122 | } |
1123 | |
1124 | |
1125 | |
1126 | |
1127 | virtual ParseResult |
1128 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1129 | int requiredOperandCount = -1, |
1130 | Delimiter delimiter = Delimiter::None) = 0; |
1131 | ParseResult |
1132 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1133 | Delimiter delimiter) { |
1134 | return parseTrailingOperandList(result, -1, |
1135 | delimiter); |
1136 | } |
1137 | |
1138 | |
1139 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1140 | Type type, |
1141 | SmallVectorImpl<Value> &result) = 0; |
1142 | |
1143 | |
1144 | |
1145 | |
1146 | ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type, |
1147 | SmallVectorImpl<Value> &result) { |
1148 | for (auto elt : operands) |
1149 | if (resolveOperand(elt, type, result)) |
1150 | return failure(); |
1151 | return success(); |
1152 | } |
1153 | |
1154 | |
1155 | |
1156 | |
1157 | ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, |
1158 | ArrayRef<Type> types, SMLoc loc, |
1159 | SmallVectorImpl<Value> &result) { |
1160 | if (operands.size() != types.size()) |
1161 | return emitError(loc) |
1162 | << operands.size() << " operands present, but expected " |
1163 | << types.size(); |
1164 | |
1165 | for (unsigned i = 0, e = operands.size(); i != e; ++i) |
1166 | if (resolveOperand(operands[i], types[i], result)) |
1167 | return failure(); |
1168 | return success(); |
1169 | } |
1170 | template <typename Operands> |
1171 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1172 | SmallVectorImpl<Value> &result) { |
1173 | return resolveOperands(std::forward<Operands>(operands), |
1174 | ArrayRef<Type>(type), loc, result); |
1175 | } |
1176 | template <typename Operands, typename Types> |
1177 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1178 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1179 | SmallVectorImpl<Value> &result) { |
1180 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1181 | size_t typeSize = std::distance(types.begin(), types.end()); |
1182 | if (operandSize != typeSize) |
1183 | return emitError(loc) |
1184 | << operandSize << " operands present, but expected " << typeSize; |
1185 | |
1186 | for (auto it : llvm::zip(operands, types)) |
1187 | if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) |
1188 | return failure(); |
1189 | return success(); |
1190 | } |
1191 | |
1192 | |
1193 | |
1194 | |
1195 | virtual ParseResult |
1196 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1197 | Attribute &map, StringRef attrName, |
1198 | NamedAttrList &attrs, |
1199 | Delimiter delimiter = Delimiter::Square) = 0; |
1200 | |
1201 | |
1202 | |
1203 | |
1204 | virtual ParseResult |
1205 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1206 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1207 | AffineExpr &expr) = 0; |
1208 | |
1209 | |
1210 | |
1211 | |
1212 | |
1213 | |
1214 | |
1215 | |
1216 | |
1217 | |
1218 | |
1219 | |
1220 | |
1221 | virtual ParseResult parseRegion(Region ®ion, |
1222 | ArrayRef<UnresolvedOperand> arguments = {}, |
1223 | ArrayRef<Type> argTypes = {}, |
1224 | ArrayRef<Location> argLocations = {}, |
1225 | bool enableNameShadowing = false) = 0; |
1226 | |
1227 | |
1228 | virtual OptionalParseResult parseOptionalRegion( |
1229 | Region ®ion, ArrayRef<UnresolvedOperand> arguments = {}, |
1230 | ArrayRef<Type> argTypes = {}, ArrayRef<Location> argLocations = {}, |
1231 | bool enableNameShadowing = false) = 0; |
1232 | |
1233 | |
1234 | |
1235 | |
1236 | virtual OptionalParseResult |
1237 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1238 | ArrayRef<UnresolvedOperand> arguments = {}, |
1239 | ArrayRef<Type> argTypes = {}, |
1240 | bool enableNameShadowing = false) = 0; |
1241 | |
1242 | |
1243 | |
1244 | virtual ParseResult parseRegionArgument(UnresolvedOperand &argument) = 0; |
1245 | |
1246 | |
1247 | |
1248 | |
1249 | |
1250 | virtual ParseResult |
1251 | parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result, |
1252 | int requiredOperandCount = -1, |
1253 | Delimiter delimiter = Delimiter::None) = 0; |
1254 | virtual ParseResult |
1255 | parseRegionArgumentList(SmallVectorImpl<UnresolvedOperand> &result, |
1256 | Delimiter delimiter) { |
1257 | return parseRegionArgumentList(result, -1, |
1258 | delimiter); |
1259 | } |
1260 | |
1261 | |
1262 | virtual ParseResult |
1263 | parseOptionalRegionArgument(UnresolvedOperand &argument) = 0; |
1264 | |
1265 | |
1266 | |
1267 | |
1268 | |
1269 | |
1270 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1271 | |
1272 | |
1273 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1274 | |
1275 | |
1276 | virtual ParseResult |
1277 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1278 | |
1279 | |
1280 | |
1281 | |
1282 | |
1283 | |
1284 | |
1285 | ParseResult parseAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs, |
1286 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1287 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1288 | if (!result.hasValue()) |
1289 | return emitError(getCurrentLocation(), "expected '('"); |
1290 | return result.getValue(); |
1291 | } |
1292 | |
1293 | virtual OptionalParseResult |
1294 | parseOptionalAssignmentList(SmallVectorImpl<UnresolvedOperand> &lhs, |
1295 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1296 | |
1297 | |
1298 | |
1299 | ParseResult |
1300 | parseAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs, |
1301 | SmallVectorImpl<UnresolvedOperand> &rhs, |
1302 | SmallVectorImpl<Type> &types) { |
1303 | OptionalParseResult result = |
1304 | parseOptionalAssignmentListWithTypes(lhs, rhs, types); |
1305 | if (!result.hasValue()) |
1306 | return emitError(getCurrentLocation(), "expected '('"); |
1307 | return result.getValue(); |
1308 | } |
1309 | |
1310 | virtual OptionalParseResult |
1311 | parseOptionalAssignmentListWithTypes(SmallVectorImpl<UnresolvedOperand> &lhs, |
1312 | SmallVectorImpl<UnresolvedOperand> &rhs, |
1313 | SmallVectorImpl<Type> &types) = 0; |
1314 | |
1315 | private: |
1316 | |
1317 | |
1318 | ParseResult |
1319 | parseOperandOrRegionArgList(SmallVectorImpl<UnresolvedOperand> &result, |
1320 | bool isOperandList, int requiredOperandCount, |
1321 | Delimiter delimiter); |
1322 | }; |
1323 | |
1324 | |
1325 | |
1326 | |
1327 | |
1328 | |
1329 | |
/// Callback type taking a Value and a name string and returning nothing.
1330 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1331 | |
1332 | |
1333 | |
/// Callback type taking a Block pointer and a name string and returning
/// nothing.
1334 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1335 | |
1336 | class OpAsmDialectInterface |
1337 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1338 | public: |
1339 | |
1340 | enum class AliasResult { |
1341 | |
1342 | |
1343 | NoAlias, |
1344 | |
1345 | OverridableAlias, |
1346 | |
1347 | |
1348 | FinalAlias |
1349 | }; |
1350 | |
1351 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1352 | |
1353 | |
1354 | |
1355 | |
1356 | |
1357 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1358 | return AliasResult::NoAlias; |
1359 | } |
1360 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1361 | return AliasResult::NoAlias; |
1362 | } |
1363 | |
1364 | }; |
1365 | } |
1366 | |
1367 | |
1368 | |
1369 | |
1370 | |
1371 | |
1372 | #include "mlir/IR/OpAsmInterface.h.inc" |
1373 | |
1374 | #endif |