Bug Summary

File: mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Warning: line 90, column 22
The left operand of '>' is a garbage value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name TypeParser.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Quant -I /build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/lib/Dialect/Quant -I include -I /build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/llvm/include -I /build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-01-25-232935-20746-1 -x c++ /build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/lib/Dialect/Quant/IR/TypeParser.cpp

/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/lib/Dialect/Quant/IR/TypeParser.cpp

1//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/QuantOps.h"
10#include "mlir/Dialect/Quant/QuantTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/StringSwitch.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/MathExtras.h"
19#include "llvm/Support/SourceMgr.h"
20#include "llvm/Support/raw_ostream.h"
21
22using namespace mlir;
23using namespace quant;
24
25static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26 auto typeLoc = parser.getCurrentLocation();
27 IntegerType type;
28
29 // Parse storage type (alpha_ident, integer_literal).
30 StringRef identifier;
31 unsigned storageTypeWidth = 0;
32 OptionalParseResult result = parser.parseOptionalType(type);
33 if (result.hasValue()) {
34 if (!succeeded(*result)) {
35 parser.parseType(type);
36 return nullptr;
37 }
38 isSigned = !type.isUnsigned();
39 storageTypeWidth = type.getWidth();
40 } else if (succeeded(parser.parseKeyword(&identifier))) {
41 // Otherwise, this must be an unsigned integer (`u` integer-literal).
42 if (!identifier.consume_front("u")) {
43 parser.emitError(typeLoc, "illegal storage type prefix");
44 return nullptr;
45 }
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.emitError(typeLoc, "expected storage type width");
48 return nullptr;
49 }
50 isSigned = false;
51 type = parser.getBuilder().getIntegerType(storageTypeWidth);
52 } else {
53 return nullptr;
54 }
55
56 if (storageTypeWidth == 0 ||
57 storageTypeWidth > QuantizedType::MaxStorageBits) {
58 parser.emitError(typeLoc, "illegal storage type size: ")
59 << storageTypeWidth;
60 return nullptr;
61 }
62
63 return type;
64}
65
66static ParseResult parseStorageRange(DialectAsmParser &parser,
67 IntegerType storageType, bool isSigned,
68 int64_t &storageTypeMin,
69 int64_t &storageTypeMax) {
70 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
71 isSigned, storageType.getWidth());
72 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
73 isSigned, storageType.getWidth());
74 if (failed(parser.parseOptionalLess())) {
22
Calling 'failed'
28
Returning from 'failed'
29
Taking false branch
75 storageTypeMin = defaultIntegerMin;
76 storageTypeMax = defaultIntegerMax;
77 return success();
78 }
79
80 // Explicit storage min and storage max.
81 llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
82 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
39
Taking false branch
83 parser.getCurrentLocation(&maxLoc) ||
84 parser.parseInteger(storageTypeMax) || parser.parseGreater())
30
Calling 'AsmParser::parseInteger'
38
Returning from 'AsmParser::parseInteger'
85 return failure();
86 if (storageTypeMin
39.1
'storageTypeMin' is >= 'defaultIntegerMin'
39.1
'storageTypeMin' is >= 'defaultIntegerMin'
39.1
'storageTypeMin' is >= 'defaultIntegerMin'
39.1
'storageTypeMin' is >= 'defaultIntegerMin'
39.1
'storageTypeMin' is >= 'defaultIntegerMin'
< defaultIntegerMin) {
40
Taking false branch
87 return parser.emitError(minLoc, "illegal storage type minimum: ")
88 << storageTypeMin;
89 }
90 if (storageTypeMax > defaultIntegerMax) {
41
The left operand of '>' is a garbage value
91 return parser.emitError(maxLoc, "illegal storage type maximum: ")
92 << storageTypeMax;
93 }
94 return success();
95}
96
97static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
98 double &min, double &max) {
99 auto typeLoc = parser.getCurrentLocation();
100 FloatType type;
101
102 if (failed(parser.parseType(type))) {
103 parser.emitError(typeLoc, "expecting float expressed type");
104 return nullptr;
105 }
106
107 // Calibrated min and max values.
108 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
109 parser.parseFloat(max) || parser.parseGreater()) {
110 parser.emitError(typeLoc, "calibrated values must be present");
111 return nullptr;
112 }
113 return type;
114}
115
116/// Parses an AnyQuantizedType.
117///
118/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
119/// storage-spec ::= storage-type (`<` storage-range `>`)?
120/// storage-range ::= integer-literal `:` integer-literal
121/// storage-type ::= (`i` | `u`) integer-literal
122/// expressed-type-spec ::= `:` `f` integer-literal
123static Type parseAnyType(DialectAsmParser &parser) {
124 IntegerType storageType;
125 FloatType expressedType;
126 unsigned typeFlags = 0;
127 int64_t storageTypeMin;
128 int64_t storageTypeMax;
129
130 // Type specification.
131 if (parser.parseLess())
132 return nullptr;
133
134 // Storage type.
135 bool isSigned = false;
136 storageType = parseStorageType(parser, isSigned);
137 if (!storageType) {
138 return nullptr;
139 }
140 if (isSigned) {
141 typeFlags |= QuantizationFlags::Signed;
142 }
143
144 // Storage type range.
145 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
146 storageTypeMax)) {
147 return nullptr;
148 }
149
150 // Optional expressed type.
151 if (succeeded(parser.parseOptionalColon())) {
152 if (parser.parseType(expressedType)) {
153 return nullptr;
154 }
155 }
156
157 if (parser.parseGreater()) {
158 return nullptr;
159 }
160
161 return parser.getChecked<AnyQuantizedType>(
162 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
163}
164
165static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
166 int64_t &zeroPoint) {
167 // scale[:zeroPoint]?
168 // scale.
169 if (parser.parseFloat(scale))
170 return failure();
171
172 // zero point.
173 zeroPoint = 0;
174 if (failed(parser.parseOptionalColon())) {
175 // Default zero point.
176 return success();
177 }
178
179 return parser.parseInteger(zeroPoint);
180}
181
182/// Parses a UniformQuantizedType.
183///
184/// uniform_type ::= uniform_per_layer
185/// | uniform_per_axis
186/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
187/// `,` scale-zero `>`
188/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
189/// axis-spec `,` scale-zero-list `>`
190/// storage-spec ::= storage-type (`<` storage-range `>`)?
191/// storage-range ::= integer-literal `:` integer-literal
192/// storage-type ::= (`i` | `u`) integer-literal
193/// expressed-type-spec ::= `:` `f` integer-literal
194/// axis-spec ::= `:` integer-literal
195/// scale-zero ::= float-literal `:` integer-literal
196/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
197static Type parseUniformType(DialectAsmParser &parser) {
198 IntegerType storageType;
199 FloatType expressedType;
200 unsigned typeFlags = 0;
201 int64_t storageTypeMin;
202 int64_t storageTypeMax;
5
'storageTypeMax' declared without an initial value
203 bool isPerAxis = false;
204 int32_t quantizedDimension;
205 SmallVector<double, 1> scales;
206 SmallVector<int64_t, 1> zeroPoints;
207
208 // Type specification.
209 if (parser.parseLess()) {
6
Calling 'ParseResult::operator bool'
12
Returning from 'ParseResult::operator bool'
13
Taking false branch
210 return nullptr;
211 }
212
213 // Storage type.
214 bool isSigned = false;
215 storageType = parseStorageType(parser, isSigned);
216 if (!storageType) {
14
Calling 'Type::operator!'
17
Returning from 'Type::operator!'
18
Taking false branch
217 return nullptr;
218 }
219 if (isSigned
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
) {
19
Taking false branch
220 typeFlags |= QuantizationFlags::Signed;
221 }
222
223 // Storage type range.
224 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
21
Calling 'parseStorageRange'
225 storageTypeMax)) {
20
Passing value via 5th parameter 'storageTypeMax'
226 return nullptr;
227 }
228
229 // Expressed type.
230 if (parser.parseColon() || parser.parseType(expressedType)) {
231 return nullptr;
232 }
233
234 // Optionally parse quantized dimension for per-axis quantization.
235 if (succeeded(parser.parseOptionalColon())) {
236 if (parser.parseInteger(quantizedDimension))
237 return nullptr;
238 isPerAxis = true;
239 }
240
241 // Comma leading into range_spec.
242 if (parser.parseComma()) {
243 return nullptr;
244 }
245
246 // Parameter specification.
247 // For per-axis, ranges are in a {} delimitted list.
248 if (isPerAxis) {
249 if (parser.parseLBrace()) {
250 return nullptr;
251 }
252 }
253
254 // Parse scales/zeroPoints.
255 llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
256 do {
257 scales.resize(scales.size() + 1);
258 zeroPoints.resize(zeroPoints.size() + 1);
259 if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
260 return nullptr;
261 }
262 } while (isPerAxis && succeeded(parser.parseOptionalComma()));
263
264 if (isPerAxis) {
265 if (parser.parseRBrace()) {
266 return nullptr;
267 }
268 }
269
270 if (parser.parseGreater()) {
271 return nullptr;
272 }
273
274 if (!isPerAxis && scales.size() > 1) {
275 return (parser.emitError(scaleZPLoc,
276 "multiple scales/zeroPoints provided, but "
277 "quantizedDimension wasn't specified"),
278 nullptr);
279 }
280
281 if (isPerAxis) {
282 ArrayRef<double> scalesRef(scales.begin(), scales.end());
283 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
284 return parser.getChecked<UniformQuantizedPerAxisType>(
285 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
286 quantizedDimension, storageTypeMin, storageTypeMax);
287 }
288
289 return parser.getChecked<UniformQuantizedType>(
290 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
291 storageTypeMin, storageTypeMax);
292}
293
294/// Parses an CalibratedQuantizedType.
295///
296/// calibrated ::= `calibrated<` expressed-spec `>`
297/// expressed-spec ::= expressed-type `<` calibrated-range `>`
298/// expressed-type ::= `f` integer-literal
299/// calibrated-range ::= float-literal `:` float-literal
300static Type parseCalibratedType(DialectAsmParser &parser) {
301 FloatType expressedType;
302 double min;
303 double max;
304
305 // Type specification.
306 if (parser.parseLess())
307 return nullptr;
308
309 // Expressed type.
310 expressedType = parseExpressedTypeAndRange(parser, min, max);
311 if (!expressedType) {
312 return nullptr;
313 }
314
315 if (parser.parseGreater()) {
316 return nullptr;
317 }
318
319 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
320}
321
322/// Parse a type registered to this dialect.
323Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
324 // All types start with an identifier that we switch on.
325 StringRef typeNameSpelling;
326 if (failed(parser.parseKeyword(&typeNameSpelling)))
1
Taking false branch
327 return nullptr;
328
329 if (typeNameSpelling == "uniform")
2
Assuming the condition is true
3
Taking true branch
330 return parseUniformType(parser);
4
Calling 'parseUniformType'
331 if (typeNameSpelling == "any")
332 return parseAnyType(parser);
333 if (typeNameSpelling == "calibrated")
334 return parseCalibratedType(parser);
335
336 parser.emitError(parser.getNameLoc(),
337 "unknown quantized type " + typeNameSpelling);
338 return nullptr;
339}
340
341static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
342 // storage type
343 unsigned storageWidth = type.getStorageTypeIntegralWidth();
344 bool isSigned = type.isSigned();
345 if (isSigned) {
346 out << "i" << storageWidth;
347 } else {
348 out << "u" << storageWidth;
349 }
350
351 // storageTypeMin and storageTypeMax if not default.
352 int64_t defaultIntegerMin =
353 QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
354 int64_t defaultIntegerMax =
355 QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
356 if (defaultIntegerMin != type.getStorageTypeMin() ||
357 defaultIntegerMax != type.getStorageTypeMax()) {
358 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
359 << ">";
360 }
361}
362
363static void printQuantParams(double scale, int64_t zeroPoint,
364 DialectAsmPrinter &out) {
365 out << scale;
366 if (zeroPoint != 0) {
367 out << ":" << zeroPoint;
368 }
369}
370
371/// Helper that prints a AnyQuantizedType.
372static void printAnyQuantizedType(AnyQuantizedType type,
373 DialectAsmPrinter &out) {
374 out << "any<";
375 printStorageType(type, out);
376 if (Type expressedType = type.getExpressedType()) {
377 out << ":" << expressedType;
378 }
379 out << ">";
380}
381
382/// Helper that prints a UniformQuantizedType.
383static void printUniformQuantizedType(UniformQuantizedType type,
384 DialectAsmPrinter &out) {
385 out << "uniform<";
386 printStorageType(type, out);
387 out << ":" << type.getExpressedType() << ", ";
388
389 // scheme specific parameters
390 printQuantParams(type.getScale(), type.getZeroPoint(), out);
391 out << ">";
392}
393
394/// Helper that prints a UniformQuantizedPerAxisType.
395static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
396 DialectAsmPrinter &out) {
397 out << "uniform<";
398 printStorageType(type, out);
399 out << ":" << type.getExpressedType() << ":";
400 out << type.getQuantizedDimension();
401 out << ", ";
402
403 // scheme specific parameters
404 ArrayRef<double> scales = type.getScales();
405 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
406 out << "{";
407 llvm::interleave(
408 llvm::seq<size_t>(0, scales.size()), out,
409 [&](size_t index) {
410 printQuantParams(scales[index], zeroPoints[index], out);
411 },
412 ",");
413 out << "}>";
414}
415
416/// Helper that prints a CalibratedQuantizedType.
417static void printCalibratedQuantizedType(CalibratedQuantizedType type,
418 DialectAsmPrinter &out) {
419 out << "calibrated<" << type.getExpressedType();
420 out << "<" << type.getMin() << ":" << type.getMax() << ">";
421 out << ">";
422}
423
424/// Print a type registered to this dialect.
425void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
426 if (auto anyType = type.dyn_cast<AnyQuantizedType>())
427 printAnyQuantizedType(anyType, os);
428 else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
429 printUniformQuantizedType(uniformType, os);
430 else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
431 printUniformQuantizedPerAxisType(perAxisType, os);
432 else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
433 printCalibratedQuantizedType(calibratedType, os);
434 else
435 llvm_unreachable("Unhandled quantized type")::llvm::llvm_unreachable_internal("Unhandled quantized type",
"mlir/lib/Dialect/Quant/IR/TypeParser.cpp", 435)
;
436}

/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/include/mlir/IR/OpDefinition.h

1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types are to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24#include "llvm/Support/PointerLikeTypeTraits.h"
25
26#include <type_traits>
27
28namespace mlir {
// Forward declarations for the builder types used by
// impl::ensureRegionTerminator below.
class Builder;
class OpBuilder;
31
32/// This class represents success/failure for operation parsing. It is
33/// essentially a simple wrapper class around LogicalResult that allows for
34/// explicit conversion to bool. This allows for the parser to chain together
35/// parse rules without the clutter of "failed/succeeded".
36class ParseResult : public LogicalResult {
37public:
38 ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
39
40 // Allow diagnostics emitted during parsing to be converted to failure.
41 ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
42 ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
43
44 /// Failure is true in a boolean context.
45 explicit operator bool() const { return failed(); }
7
Calling 'LogicalResult::failed'
10
Returning from 'LogicalResult::failed'
11
Returning zero, which participates in a condition later
46};
47/// This class implements `Optional` functionality for ParseResult. We don't
48/// directly use Optional here, because it provides an implicit conversion
49/// to 'bool' which we want to avoid. This class is used to implement tri-state
50/// 'parseOptional' functions that may have a failure mode when parsing that
51/// shouldn't be attributed to "not present".
52class OptionalParseResult {
53public:
54 OptionalParseResult() = default;
55 OptionalParseResult(LogicalResult result) : impl(result) {}
56 OptionalParseResult(ParseResult result) : impl(result) {}
57 OptionalParseResult(const InFlightDiagnostic &)
58 : OptionalParseResult(failure()) {}
59 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
60
61 /// Returns true if we contain a valid ParseResult value.
62 bool hasValue() const { return impl.hasValue(); }
63
64 /// Access the internal ParseResult value.
65 ParseResult getValue() const { return impl.getValue(); }
66 ParseResult operator*() const { return getValue(); }
67
68private:
69 Optional<ParseResult> impl;
70};
71
72// These functions are out-of-line utilities, which avoids them being template
73// instantiated/duplicated.
74namespace impl {
75/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
76/// region's only block if it does not have a terminator already. If the region
77/// is empty, insert a new block first. `buildTerminatorOp` should return the
78/// terminator operation to insert.
79void ensureRegionTerminator(
80 Region &region, OpBuilder &builder, Location loc,
81 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
82void ensureRegionTerminator(
83 Region &region, Builder &builder, Location loc,
84 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
85
86} // namespace impl
87
88/// This is the concrete base class that holds the operation pointer and has
89/// non-generic methods that only depend on State (to avoid having them
90/// instantiated on template types that don't affect them.
91///
92/// This also has the fallback implementations of customization hooks for when
93/// they aren't customized.
94class OpState {
95public:
96 /// Ops are pointer-like, so we allow conversion to bool.
97 explicit operator bool() { return getOperation() != nullptr; }
98
99 /// This implicitly converts to Operation*.
100 operator Operation *() const { return state; }
101
102 /// Shortcut of `->` to access a member of Operation.
103 Operation *operator->() const { return state; }
104
105 /// Return the operation that this refers to.
106 Operation *getOperation() { return state; }
107
108 /// Return the context this operation belongs to.
109 MLIRContext *getContext() { return getOperation()->getContext(); }
110
111 /// Print the operation to the given stream.
112 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
113 state->print(os, flags);
114 }
115 void print(raw_ostream &os, AsmState &asmState,
116 OpPrintingFlags flags = llvm::None) {
117 state->print(os, asmState, flags);
118 }
119
120 /// Dump this operation.
121 void dump() { state->dump(); }
122
123 /// The source location the operation was defined or derived from.
124 Location getLoc() { return state->getLoc(); }
125
126 /// Return true if there are no users of any results of this operation.
127 bool use_empty() { return state->use_empty(); }
128
129 /// Remove this operation from its parent block and delete it.
130 void erase() { state->erase(); }
131
132 /// Emit an error with the op name prefixed, like "'dim' op " which is
133 /// convenient for verifiers.
134 InFlightDiagnostic emitOpError(const Twine &message = {});
135
136 /// Emit an error about fatal conditions with this operation, reporting up to
137 /// any diagnostic handlers that may be listening.
138 InFlightDiagnostic emitError(const Twine &message = {});
139
140 /// Emit a warning about this operation, reporting up to any diagnostic
141 /// handlers that may be listening.
142 InFlightDiagnostic emitWarning(const Twine &message = {});
143
144 /// Emit a remark about this operation, reporting up to any diagnostic
145 /// handlers that may be listening.
146 InFlightDiagnostic emitRemark(const Twine &message = {});
147
148 /// Walk the operation by calling the callback for each nested operation
149 /// (including this one), block or region, depending on the callback provided.
150 /// Regions, blocks and operations at the same nesting level are visited in
151 /// lexicographical order. The walk order for enclosing regions, blocks and
152 /// operations with respect to their nested ones is specified by 'Order'
153 /// (post-order by default). A callback on a block or operation is allowed to
154 /// erase that block or operation if either:
155 /// * the walk is in post-order, or
156 /// * the walk is in pre-order and the walk is skipped after the erasure.
157 /// See Operation::walk for more details.
158 template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
159 typename RetT = detail::walkResultType<FnT>>
160 typename std::enable_if<
161 llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
162 walk(FnT &&callback) {
163 return state->walk<Order>(std::forward<FnT>(callback));
164 }
165
166 /// Generic walker with a stage aware callback. Walk the operation by calling
167 /// the callback for each nested operation (including this one) N+1 times,
168 /// where N is the number of regions attached to that operation.
169 ///
170 /// The callback method can take any of the following forms:
171 /// void(Operation *, const WalkStage &) : Walk all operation opaquely
172 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
173 /// void(OpT, const WalkStage &) : Walk all operations of the given derived
174 /// type.
175 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
176 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
177 /// but allow for interruption/skipping.
178 /// * op.walk([](... op, const WalkStage &stage) {
179 /// // Skip the walk of this op based on some invariant.
180 /// if (some_invariant)
181 /// return WalkResult::skip();
182 /// // Interrupt, i.e cancel, the walk based on some invariant.
183 /// if (another_invariant)
184 /// return WalkResult::interrupt();
185 /// return WalkResult::advance();
186 /// });
187 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
188 typename std::enable_if<
189 llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
190 walk(FnT &&callback) {
191 return state->walk(std::forward<FnT>(callback));
192 }
193
194 // These are default implementations of customization hooks.
195public:
196 /// This hook returns any canonicalization pattern rewrites that the operation
197 /// supports, for use by the canonicalization pass.
198 static void getCanonicalizationPatterns(RewritePatternSet &results,
199 MLIRContext *context) {}
200
201protected:
202 /// If the concrete type didn't implement a custom verifier hook, just fall
203 /// back to this one which accepts everything.
204 LogicalResult verify() { return success(); }
205
206 /// Parse the custom form of an operation. Unless overridden, this method will
207 /// first try to get an operation parser from the op's dialect. Otherwise the
208 /// custom assembly form of an op is always rejected. Op implementations
209 /// should implement this to return failure. On success, they should fill in
210 /// result with the fields to use.
211 static ParseResult parse(OpAsmParser &parser, OperationState &result);
212
213 /// Print the operation. Unless overridden, this method will first try to get
214 /// an operation printer from the dialect. Otherwise, it prints the operation
215 /// in generic form.
216 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
217
218 /// Print an operation name, eliding the dialect prefix if necessary.
219 static void printOpName(Operation *op, OpAsmPrinter &p,
220 StringRef defaultDialect);
221
222 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
223 /// so we can cast it away here.
224 explicit OpState(Operation *state) : state(state) {}
225
226private:
227 Operation *state;
228
229 /// Allow access to internal hook implementation methods.
230 friend RegisteredOperationName;
231};
232
233// Allow comparing operators.
234inline bool operator==(OpState lhs, OpState rhs) {
235 return lhs.getOperation() == rhs.getOperation();
236}
237inline bool operator!=(OpState lhs, OpState rhs) {
238 return lhs.getOperation() != rhs.getOperation();
239}
240
241raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
242
243/// This class represents a single result from folding an operation.
244class OpFoldResult : public PointerUnion<Attribute, Value> {
245 using PointerUnion<Attribute, Value>::PointerUnion;
246
247public:
248 void dump() { llvm::errs() << *this << "\n"; }
249};
250
251/// Allow printing to a stream.
252inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
253 if (Value value = ofr.dyn_cast<Value>())
254 value.print(os);
255 else
256 ofr.dyn_cast<Attribute>().print(os);
257 return os;
258}
259
260/// Allow printing to a stream.
261inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
262 op.print(os, OpPrintingFlags().useLocalScope());
263 return os;
264}
265
266//===----------------------------------------------------------------------===//
267// Operation Trait Types
268//===----------------------------------------------------------------------===//
269
270namespace OpTrait {
271
272// These functions are out-of-line implementations of the methods in the
273// corresponding trait classes. This avoids them being template
274// instantiated/duplicated.
275namespace impl {
276OpFoldResult foldIdempotent(Operation *op);
277OpFoldResult foldInvolution(Operation *op);
278LogicalResult verifyZeroOperands(Operation *op);
279LogicalResult verifyOneOperand(Operation *op);
280LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
281LogicalResult verifyIsIdempotent(Operation *op);
282LogicalResult verifyIsInvolution(Operation *op);
283LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
284LogicalResult verifyOperandsAreFloatLike(Operation *op);
285LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
286LogicalResult verifySameTypeOperands(Operation *op);
287LogicalResult verifyZeroRegion(Operation *op);
288LogicalResult verifyOneRegion(Operation *op);
289LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
290LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
291LogicalResult verifyZeroResult(Operation *op);
292LogicalResult verifyOneResult(Operation *op);
293LogicalResult verifyNResults(Operation *op, unsigned numOperands);
294LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
295LogicalResult verifySameOperandsShape(Operation *op);
296LogicalResult verifySameOperandsAndResultShape(Operation *op);
297LogicalResult verifySameOperandsElementType(Operation *op);
298LogicalResult verifySameOperandsAndResultElementType(Operation *op);
299LogicalResult verifySameOperandsAndResultType(Operation *op);
300LogicalResult verifyResultsAreBoolLike(Operation *op);
301LogicalResult verifyResultsAreFloatLike(Operation *op);
302LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
303LogicalResult verifyIsTerminator(Operation *op);
304LogicalResult verifyZeroSuccessor(Operation *op);
305LogicalResult verifyOneSuccessor(Operation *op);
306LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
307LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
308LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
309 StringRef valueGroupName,
310 size_t expectedCount);
311LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
312LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
313LogicalResult verifyNoRegionArguments(Operation *op);
314LogicalResult verifyElementwise(Operation *op);
315LogicalResult verifyIsIsolatedFromAbove(Operation *op);
316} // namespace impl
317
318/// Helper class for implementing traits. Clients are not expected to interact
319/// with this directly, so its members are all protected.
320template <typename ConcreteType, template <typename> class TraitType>
321class TraitBase {
322protected:
323 /// Return the ultimate Operation being worked on.
324 Operation *getOperation() {
325 // We have to cast up to the trait type, then to the concrete type, then to
326 // the BaseState class in explicit hops because the concrete type will
327 // multiply derive from the (content free) TraitBase class, and we need to
328 // be able to disambiguate the path for the C++ compiler.
329 auto *trait = static_cast<TraitType<ConcreteType> *>(this);
330 auto *concrete = static_cast<ConcreteType *>(trait);
331 auto *base = static_cast<OpState *>(concrete);
332 return base->getOperation();
333 }
334};
335
336//===----------------------------------------------------------------------===//
337// Operand Traits
338
339namespace detail {
340/// Utility trait base that provides accessors for derived traits that have
341/// multiple operands.
342template <typename ConcreteType, template <typename> class TraitType>
343struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
344 using operand_iterator = Operation::operand_iterator;
345 using operand_range = Operation::operand_range;
346 using operand_type_iterator = Operation::operand_type_iterator;
347 using operand_type_range = Operation::operand_type_range;
348
349 /// Return the number of operands.
350 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
351
352 /// Return the operand at index 'i'.
353 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
354
355 /// Set the operand at index 'i' to 'value'.
356 void setOperand(unsigned i, Value value) {
357 this->getOperation()->setOperand(i, value);
358 }
359
360 /// Operand iterator access.
361 operand_iterator operand_begin() {
362 return this->getOperation()->operand_begin();
363 }
364 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
365 operand_range getOperands() { return this->getOperation()->getOperands(); }
366
367 /// Operand type access.
368 operand_type_iterator operand_type_begin() {
369 return this->getOperation()->operand_type_begin();
370 }
371 operand_type_iterator operand_type_end() {
372 return this->getOperation()->operand_type_end();
373 }
374 operand_type_range getOperandTypes() {
375 return this->getOperation()->getOperandTypes();
376 }
377};
378} // namespace detail
379
380/// This class provides the API for ops that are known to have no
381/// SSA operand.
382template <typename ConcreteType>
383class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
384public:
385 static LogicalResult verifyTrait(Operation *op) {
386 return impl::verifyZeroOperands(op);
387 }
388
389private:
390 // Disable these.
391 void getOperand() {}
392 void setOperand() {}
393};
394
395/// This class provides the API for ops that are known to have exactly one
396/// SSA operand.
397template <typename ConcreteType>
398class OneOperand : public TraitBase<ConcreteType, OneOperand> {
399public:
400 Value getOperand() { return this->getOperation()->getOperand(0); }
401
402 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
403
404 static LogicalResult verifyTrait(Operation *op) {
405 return impl::verifyOneOperand(op);
406 }
407};
408
409/// This class provides the API for ops that are known to have a specified
410/// number of operands. This is used as a trait like this:
411///
412/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
413///
414template <unsigned N>
415class NOperands {
416public:
417 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
418
419 template <typename ConcreteType>
420 class Impl
421 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
422 public:
423 static LogicalResult verifyTrait(Operation *op) {
424 return impl::verifyNOperands(op, N);
425 }
426 };
427};
428
429/// This class provides the API for ops that are known to have a at least a
430/// specified number of operands. This is used as a trait like this:
431///
432/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
433///
434template <unsigned N>
435class AtLeastNOperands {
436public:
437 template <typename ConcreteType>
438 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
439 AtLeastNOperands<N>::Impl> {
440 public:
441 static LogicalResult verifyTrait(Operation *op) {
442 return impl::verifyAtLeastNOperands(op, N);
443 }
444 };
445};
446
447/// This class provides the API for ops which have an unknown number of
448/// SSA operands.
449template <typename ConcreteType>
450class VariadicOperands
451 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
452
453//===----------------------------------------------------------------------===//
454// Region Traits
455
456/// This class provides verification for ops that are known to have zero
457/// regions.
458template <typename ConcreteType>
459class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
460public:
461 static LogicalResult verifyTrait(Operation *op) {
462 return impl::verifyZeroRegion(op);
463 }
464};
465
466namespace detail {
467/// Utility trait base that provides accessors for derived traits that have
468/// multiple regions.
469template <typename ConcreteType, template <typename> class TraitType>
470struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
471 using region_iterator = MutableArrayRef<Region>;
472 using region_range = RegionRange;
473
474 /// Return the number of regions.
475 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
476
477 /// Return the region at `index`.
478 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
479
480 /// Region iterator access.
481 region_iterator region_begin() {
482 return this->getOperation()->region_begin();
483 }
484 region_iterator region_end() { return this->getOperation()->region_end(); }
485 region_range getRegions() { return this->getOperation()->getRegions(); }
486};
487} // namespace detail
488
489/// This class provides APIs for ops that are known to have a single region.
490template <typename ConcreteType>
491class OneRegion : public TraitBase<ConcreteType, OneRegion> {
492public:
493 Region &getRegion() { return this->getOperation()->getRegion(0); }
494
495 /// Returns a range of operations within the region of this operation.
496 auto getOps() { return getRegion().getOps(); }
497 template <typename OpT>
498 auto getOps() {
499 return getRegion().template getOps<OpT>();
500 }
501
502 static LogicalResult verifyTrait(Operation *op) {
503 return impl::verifyOneRegion(op);
504 }
505};
506
507/// This class provides the API for ops that are known to have a specified
508/// number of regions.
509template <unsigned N>
510class NRegions {
511public:
512 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
513
514 template <typename ConcreteType>
515 class Impl
516 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
517 public:
518 static LogicalResult verifyTrait(Operation *op) {
519 return impl::verifyNRegions(op, N);
520 }
521 };
522};
523
524/// This class provides APIs for ops that are known to have at least a specified
525/// number of regions.
526template <unsigned N>
527class AtLeastNRegions {
528public:
529 template <typename ConcreteType>
530 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
531 AtLeastNRegions<N>::Impl> {
532 public:
533 static LogicalResult verifyTrait(Operation *op) {
534 return impl::verifyAtLeastNRegions(op, N);
535 }
536 };
537};
538
539/// This class provides the API for ops which have an unknown number of
540/// regions.
541template <typename ConcreteType>
542class VariadicRegions
543 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
544
545//===----------------------------------------------------------------------===//
546// Result Traits
547
548/// This class provides return value APIs for ops that are known to have
549/// zero results.
550template <typename ConcreteType>
551class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
552public:
553 static LogicalResult verifyTrait(Operation *op) {
554 return impl::verifyZeroResult(op);
555 }
556};
557
558namespace detail {
559/// Utility trait base that provides accessors for derived traits that have
560/// multiple results.
561template <typename ConcreteType, template <typename> class TraitType>
562struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
563 using result_iterator = Operation::result_iterator;
564 using result_range = Operation::result_range;
565 using result_type_iterator = Operation::result_type_iterator;
566 using result_type_range = Operation::result_type_range;
567
568 /// Return the number of results.
569 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
570
571 /// Return the result at index 'i'.
572 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
573
574 /// Replace all uses of results of this operation with the provided 'values'.
575 /// 'values' may correspond to an existing operation, or a range of 'Value'.
576 template <typename ValuesT>
577 void replaceAllUsesWith(ValuesT &&values) {
578 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
579 }
580
581 /// Return the type of the `i`-th result.
582 Type getType(unsigned i) { return getResult(i).getType(); }
583
584 /// Result iterator access.
585 result_iterator result_begin() {
586 return this->getOperation()->result_begin();
587 }
588 result_iterator result_end() { return this->getOperation()->result_end(); }
589 result_range getResults() { return this->getOperation()->getResults(); }
590
591 /// Result type access.
592 result_type_iterator result_type_begin() {
593 return this->getOperation()->result_type_begin();
594 }
595 result_type_iterator result_type_end() {
596 return this->getOperation()->result_type_end();
597 }
598 result_type_range getResultTypes() {
599 return this->getOperation()->getResultTypes();
600 }
601};
602} // namespace detail
603
604/// This class provides return value APIs for ops that are known to have a
605/// single result. ResultType is the concrete type returned by getType().
606template <typename ConcreteType>
607class OneResult : public TraitBase<ConcreteType, OneResult> {
608public:
609 Value getResult() { return this->getOperation()->getResult(0); }
610
611 /// If the operation returns a single value, then the Op can be implicitly
612 /// converted to an Value. This yields the value of the only result.
613 operator Value() { return getResult(); }
614
615 /// Replace all uses of 'this' value with the new value, updating anything
616 /// in the IR that uses 'this' to use the other value instead. When this
617 /// returns there are zero uses of 'this'.
618 void replaceAllUsesWith(Value newValue) {
619 getResult().replaceAllUsesWith(newValue);
620 }
621
622 /// Replace all uses of 'this' value with the result of 'op'.
623 void replaceAllUsesWith(Operation *op) {
624 this->getOperation()->replaceAllUsesWith(op);
625 }
626
627 static LogicalResult verifyTrait(Operation *op) {
628 return impl::verifyOneResult(op);
629 }
630};
631
632/// This trait is used for return value APIs for ops that are known to have a
633/// specific type other than `Type`. This allows the "getType()" member to be
634/// more specific for an op. This should be used in conjunction with OneResult,
635/// and occur in the trait list before OneResult.
636template <typename ResultType>
637class OneTypedResult {
638public:
639 /// This class provides return value APIs for ops that are known to have a
640 /// single result. ResultType is the concrete type returned by getType().
641 template <typename ConcreteType>
642 class Impl
643 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
644 public:
645 ResultType getType() {
646 auto resultTy = this->getOperation()->getResult(0).getType();
647 return resultTy.template cast<ResultType>();
648 }
649 };
650};
651
652/// This class provides the API for ops that are known to have a specified
653/// number of results. This is used as a trait like this:
654///
655/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
656///
657template <unsigned N>
658class NResults {
659public:
660 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
661
662 template <typename ConcreteType>
663 class Impl
664 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
665 public:
666 static LogicalResult verifyTrait(Operation *op) {
667 return impl::verifyNResults(op, N);
668 }
669 };
670};
671
672/// This class provides the API for ops that are known to have at least a
673/// specified number of results. This is used as a trait like this:
674///
675/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
676///
677template <unsigned N>
678class AtLeastNResults {
679public:
680 template <typename ConcreteType>
681 class Impl : public detail::MultiResultTraitBase<ConcreteType,
682 AtLeastNResults<N>::Impl> {
683 public:
684 static LogicalResult verifyTrait(Operation *op) {
685 return impl::verifyAtLeastNResults(op, N);
686 }
687 };
688};
689
690/// This class provides the API for ops which have an unknown number of
691/// results.
692template <typename ConcreteType>
693class VariadicResults
694 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
695
696//===----------------------------------------------------------------------===//
697// Terminator Traits
698
699/// This class indicates that the regions associated with this op don't have
700/// terminators.
701template <typename ConcreteType>
702class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
703
704/// This class provides the API for ops that are known to be terminators.
705template <typename ConcreteType>
706class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
707public:
708 static LogicalResult verifyTrait(Operation *op) {
709 return impl::verifyIsTerminator(op);
710 }
711};
712
713/// This class provides verification for ops that are known to have zero
714/// successors.
715template <typename ConcreteType>
716class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
717public:
718 static LogicalResult verifyTrait(Operation *op) {
719 return impl::verifyZeroSuccessor(op);
720 }
721};
722
723namespace detail {
724/// Utility trait base that provides accessors for derived traits that have
725/// multiple successors.
726template <typename ConcreteType, template <typename> class TraitType>
727struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
728 using succ_iterator = Operation::succ_iterator;
729 using succ_range = SuccessorRange;
730
731 /// Return the number of successors.
732 unsigned getNumSuccessors() {
733 return this->getOperation()->getNumSuccessors();
734 }
735
736 /// Return the successor at `index`.
737 Block *getSuccessor(unsigned i) {
738 return this->getOperation()->getSuccessor(i);
739 }
740
741 /// Set the successor at `index`.
742 void setSuccessor(Block *block, unsigned i) {
743 return this->getOperation()->setSuccessor(block, i);
744 }
745
746 /// Successor iterator access.
747 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
748 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
749 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
750};
751} // namespace detail
752
753/// This class provides APIs for ops that are known to have a single successor.
754template <typename ConcreteType>
755class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
756public:
757 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
758 void setSuccessor(Block *succ) {
759 this->getOperation()->setSuccessor(succ, 0);
760 }
761
762 static LogicalResult verifyTrait(Operation *op) {
763 return impl::verifyOneSuccessor(op);
764 }
765};
766
767/// This class provides the API for ops that are known to have a specified
768/// number of successors.
769template <unsigned N>
770class NSuccessors {
771public:
772 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
773
774 template <typename ConcreteType>
775 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
776 NSuccessors<N>::Impl> {
777 public:
778 static LogicalResult verifyTrait(Operation *op) {
779 return impl::verifyNSuccessors(op, N);
780 }
781 };
782};
783
784/// This class provides APIs for ops that are known to have at least a specified
785/// number of successors.
786template <unsigned N>
787class AtLeastNSuccessors {
788public:
789 template <typename ConcreteType>
790 class Impl
791 : public detail::MultiSuccessorTraitBase<ConcreteType,
792 AtLeastNSuccessors<N>::Impl> {
793 public:
794 static LogicalResult verifyTrait(Operation *op) {
795 return impl::verifyAtLeastNSuccessors(op, N);
796 }
797 };
798};
799
800/// This class provides the API for ops which have an unknown number of
801/// successors.
802template <typename ConcreteType>
803class VariadicSuccessors
804 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
805};
806
807//===----------------------------------------------------------------------===//
808// SingleBlock
809
810/// This class provides APIs and verifiers for ops with regions having a single
811/// block.
812template <typename ConcreteType>
813struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
814public:
815 static LogicalResult verifyTrait(Operation *op) {
816 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
817 Region &region = op->getRegion(i);
818
819 // Empty regions are fine.
820 if (region.empty())
821 continue;
822
823 // Non-empty regions must contain a single basic block.
824 if (!llvm::hasSingleElement(region))
825 return op->emitOpError("expects region #")
826 << i << " to have 0 or 1 blocks";
827
828 if (!ConcreteType::template hasTrait<NoTerminator>()) {
829 Block &block = region.front();
830 if (block.empty())
831 return op->emitOpError() << "expects a non-empty block";
832 }
833 }
834 return success();
835 }
836
837 Block *getBody(unsigned idx = 0) {
838 Region &region = this->getOperation()->getRegion(idx);
839 assert(!region.empty() && "unexpected empty region")(static_cast <bool> (!region.empty() && "unexpected empty region"
) ? void (0) : __assert_fail ("!region.empty() && \"unexpected empty region\""
, "mlir/include/mlir/IR/OpDefinition.h", 839, __extension__ __PRETTY_FUNCTION__
))
;
840 return &region.front();
841 }
842 Region &getBodyRegion(unsigned idx = 0) {
843 return this->getOperation()->getRegion(idx);
844 }
845
846 //===------------------------------------------------------------------===//
847 // Single Region Utilities
848 //===------------------------------------------------------------------===//
849
850 /// The following are a set of methods only enabled when the parent
851 /// operation has a single region. Each of these methods take an additional
852 /// template parameter that represents the concrete operation so that we
853 /// can use SFINAE to disable the methods for non-single region operations.
854 template <typename OpT, typename T = void>
855 using enable_if_single_region =
856 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
857
858 template <typename OpT = ConcreteType>
859 enable_if_single_region<OpT, Block::iterator> begin() {
860 return getBody()->begin();
861 }
862 template <typename OpT = ConcreteType>
863 enable_if_single_region<OpT, Block::iterator> end() {
864 return getBody()->end();
865 }
866 template <typename OpT = ConcreteType>
867 enable_if_single_region<OpT, Operation &> front() {
868 return *begin();
869 }
870
871 /// Insert the operation into the back of the body.
872 template <typename OpT = ConcreteType>
873 enable_if_single_region<OpT> push_back(Operation *op) {
874 insert(Block::iterator(getBody()->end()), op);
875 }
876
877 /// Insert the operation at the given insertion point.
878 template <typename OpT = ConcreteType>
879 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
880 insert(Block::iterator(insertPt), op);
881 }
882 template <typename OpT = ConcreteType>
883 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
884 getBody()->getOperations().insert(insertPt, op);
885 }
886};
887
888//===----------------------------------------------------------------------===//
889// SingleBlockImplicitTerminator
890
891/// This class provides APIs and verifiers for ops with regions having a single
892/// block that must terminate with `TerminatorOpType`.
893template <typename TerminatorOpType>
894struct SingleBlockImplicitTerminator {
895 template <typename ConcreteType>
896 class Impl : public SingleBlock<ConcreteType> {
897 private:
898 using Base = SingleBlock<ConcreteType>;
899 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
900 /// cyclic header inclusion.
901 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
902 OperationState state(loc, TerminatorOpType::getOperationName());
903 TerminatorOpType::build(builder, state);
904 return Operation::create(state);
905 }
906
907 public:
908 /// The type of the operation used as the implicit terminator type.
909 using ImplicitTerminatorOpT = TerminatorOpType;
910
911 static LogicalResult verifyTrait(Operation *op) {
912 if (failed(Base::verifyTrait(op)))
913 return failure();
914 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
915 Region &region = op->getRegion(i);
916 // Empty regions are fine.
917 if (region.empty())
918 continue;
919 Operation &terminator = region.front().back();
920 if (isa<TerminatorOpType>(terminator))
921 continue;
922
923 return op->emitOpError("expects regions to end with '" +
924 TerminatorOpType::getOperationName() +
925 "', found '" +
926 terminator.getName().getStringRef() + "'")
927 .attachNote()
928 << "in custom textual format, the absence of terminator implies "
929 "'"
930 << TerminatorOpType::getOperationName() << '\'';
931 }
932
933 return success();
934 }
935
936 /// Ensure that the given region has the terminator required by this trait.
937 /// If OpBuilder is provided, use it to build the terminator and notify the
938 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
939 /// construct an OpBuilder with no listeners; this should only be used if no
940 /// OpBuilder is available at the call site, e.g., in the parser.
941 static void ensureTerminator(Region &region, Builder &builder,
942 Location loc) {
943 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
944 buildTerminator);
945 }
946 static void ensureTerminator(Region &region, OpBuilder &builder,
947 Location loc) {
948 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
949 buildTerminator);
950 }
951
952 //===------------------------------------------------------------------===//
953 // Single Region Utilities
954 //===------------------------------------------------------------------===//
955 using Base::getBody;
956
957 template <typename OpT, typename T = void>
958 using enable_if_single_region =
959 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
960
961 /// Insert the operation into the back of the body, before the terminator.
962 template <typename OpT = ConcreteType>
963 enable_if_single_region<OpT> push_back(Operation *op) {
964 insert(Block::iterator(getBody()->getTerminator()), op);
965 }
966
967 /// Insert the operation at the given insertion point. Note: The operation
968 /// is never inserted after the terminator, even if the insertion point is
969 /// end().
970 template <typename OpT = ConcreteType>
971 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
972 insert(Block::iterator(insertPt), op);
973 }
974 template <typename OpT = ConcreteType>
975 enable_if_single_region<OpT> insert(Block::iterator insertPt,
976 Operation *op) {
977 auto *body = getBody();
978 if (insertPt == body->end())
979 insertPt = Block::iterator(body->getTerminator());
980 body->getOperations().insert(insertPt, op);
981 }
982 };
983};
984
/// Names `T::ImplicitTerminatorOpT`. It is well-formed only for ops that
/// declare that member, and is intended for use with `llvm::is_detected`.
template <typename T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
989
990/// Support to check if an operation has the SingleBlockImplicitTerminator
991/// trait. We can't just use `hasTrait` because this class is templated on a
992/// specific terminator op.
993template <class Op, bool hasTerminator =
994 llvm::is_detected<has_implicit_terminator_t, Op>::value>
995struct hasSingleBlockImplicitTerminator {
996 static constexpr bool value = std::is_base_of<
997 typename OpTrait::SingleBlockImplicitTerminator<
998 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
999 Op>::value;
1000};
1001template <class Op>
1002struct hasSingleBlockImplicitTerminator<Op, false> {
1003 static constexpr bool value = false;
1004};
1005
1006//===----------------------------------------------------------------------===//
1007// Misc Traits
1008
1009/// This class provides verification for ops that are known to have the same
1010/// operand shape: all operands are scalars, vectors/tensors of the same
1011/// shape.
1012template <typename ConcreteType>
1013class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
1014public:
1015 static LogicalResult verifyTrait(Operation *op) {
1016 return impl::verifySameOperandsShape(op);
1017 }
1018};
1019
1020/// This class provides verification for ops that are known to have the same
1021/// operand and result shape: both are scalars, vectors/tensors of the same
1022/// shape.
1023template <typename ConcreteType>
1024class SameOperandsAndResultShape
1025 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
1026public:
1027 static LogicalResult verifyTrait(Operation *op) {
1028 return impl::verifySameOperandsAndResultShape(op);
1029 }
1030};
1031
1032/// This class provides verification for ops that are known to have the same
1033/// operand element type (or the type itself if it is scalar).
1034///
1035template <typename ConcreteType>
1036class SameOperandsElementType
1037 : public TraitBase<ConcreteType, SameOperandsElementType> {
1038public:
1039 static LogicalResult verifyTrait(Operation *op) {
1040 return impl::verifySameOperandsElementType(op);
1041 }
1042};
1043
1044/// This class provides verification for ops that are known to have the same
1045/// operand and result element type (or the type itself if it is scalar).
1046///
1047template <typename ConcreteType>
1048class SameOperandsAndResultElementType
1049 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1050public:
1051 static LogicalResult verifyTrait(Operation *op) {
1052 return impl::verifySameOperandsAndResultElementType(op);
1053 }
1054};
1055
1056/// This class provides verification for ops that are known to have the same
1057/// operand and result type.
1058///
1059/// Note: this trait subsumes the SameOperandsAndResultShape and
1060/// SameOperandsAndResultElementType traits.
1061template <typename ConcreteType>
1062class SameOperandsAndResultType
1063 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1064public:
1065 static LogicalResult verifyTrait(Operation *op) {
1066 return impl::verifySameOperandsAndResultType(op);
1067 }
1068};
1069
1070/// This class verifies that any results of the specified op have a boolean
1071/// type, a vector thereof, or a tensor thereof.
1072template <typename ConcreteType>
1073class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1074public:
1075 static LogicalResult verifyTrait(Operation *op) {
1076 return impl::verifyResultsAreBoolLike(op);
1077 }
1078};
1079
1080/// This class verifies that any results of the specified op have a floating
1081/// point type, a vector thereof, or a tensor thereof.
1082template <typename ConcreteType>
1083class ResultsAreFloatLike
1084 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1085public:
1086 static LogicalResult verifyTrait(Operation *op) {
1087 return impl::verifyResultsAreFloatLike(op);
1088 }
1089};
1090
1091/// This class verifies that any results of the specified op have a signless
1092/// integer or index type, a vector thereof, or a tensor thereof.
1093template <typename ConcreteType>
1094class ResultsAreSignlessIntegerLike
1095 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1096public:
1097 static LogicalResult verifyTrait(Operation *op) {
1098 return impl::verifyResultsAreSignlessIntegerLike(op);
1099 }
1100};
1101
1102/// This class adds property that the operation is commutative.
1103template <typename ConcreteType>
1104class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1105
1106/// This class adds property that the operation is an involution.
1107/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1108template <typename ConcreteType>
1109class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1110public:
1111 static LogicalResult verifyTrait(Operation *op) {
1112 static_assert(ConcreteType::template hasTrait<OneResult>(),
1113 "expected operation to produce one result");
1114 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1115 "expected operation to take one operand");
1116 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1117 "expected operation to preserve type");
1118 // Involution requires the operation to be side effect free as well
1119 // but currently this check is under a FIXME and is not actually done.
1120 return impl::verifyIsInvolution(op);
1121 }
1122
1123 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1124 return impl::foldInvolution(op);
1125 }
1126};
1127
1128/// This class adds property that the operation is idempotent.
1129/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1130/// or a binary operation "g" that satisfies g(x, x) = x.
1131template <typename ConcreteType>
1132class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1133public:
1134 static LogicalResult verifyTrait(Operation *op) {
1135 static_assert(ConcreteType::template hasTrait<OneResult>(),
1136 "expected operation to produce one result");
1137 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1138 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1139 "expected operation to take one or two operands");
1140 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1141 "expected operation to preserve type");
1142 // Idempotent requires the operation to be side effect free as well
1143 // but currently this check is under a FIXME and is not actually done.
1144 return impl::verifyIsIdempotent(op);
1145 }
1146
1147 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1148 return impl::foldIdempotent(op);
1149 }
1150};
1151
1152/// This class verifies that all operands of the specified op have a float type,
1153/// a vector thereof, or a tensor thereof.
1154template <typename ConcreteType>
1155class OperandsAreFloatLike
1156 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1157public:
1158 static LogicalResult verifyTrait(Operation *op) {
1159 return impl::verifyOperandsAreFloatLike(op);
1160 }
1161};
1162
1163/// This class verifies that all operands of the specified op have a signless
1164/// integer or index type, a vector thereof, or a tensor thereof.
1165template <typename ConcreteType>
1166class OperandsAreSignlessIntegerLike
1167 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1168public:
1169 static LogicalResult verifyTrait(Operation *op) {
1170 return impl::verifyOperandsAreSignlessIntegerLike(op);
1171 }
1172};
1173
1174/// This class verifies that all operands of the specified op have the same
1175/// type.
1176template <typename ConcreteType>
1177class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1178public:
1179 static LogicalResult verifyTrait(Operation *op) {
1180 return impl::verifySameTypeOperands(op);
1181 }
1182};
1183
1184/// This class provides the API for a sub-set of ops that are known to be
1185/// constant-like. These are non-side effecting operations with one result and
1186/// zero operands that can always be folded to a specific attribute value.
1187template <typename ConcreteType>
1188class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1189public:
1190 static LogicalResult verifyTrait(Operation *op) {
1191 static_assert(ConcreteType::template hasTrait<OneResult>(),
1192 "expected operation to produce one result");
1193 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1194 "expected operation to take zero operands");
1195 // TODO: We should verify that the operation can always be folded, but this
1196 // requires that the attributes of the op already be verified. We should add
1197 // support for verifying traits "after" the operation to enable this use
1198 // case.
1199 return success();
1200 }
1201};
1202
1203/// This class provides the API for ops that are known to be isolated from
1204/// above.
1205template <typename ConcreteType>
1206class IsIsolatedFromAbove
1207 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1208public:
1209 static LogicalResult verifyTrait(Operation *op) {
1210 return impl::verifyIsIsolatedFromAbove(op);
1211 }
1212};
1213
1214/// A trait of region holding operations that defines a new scope for polyhedral
1215/// optimization purposes. Any SSA values of 'index' type that either dominate
1216/// such an operation or are used at the top-level of such an operation
1217/// automatically become valid symbols for the polyhedral scope defined by that
1218/// operation. For more details, see `Traits.md#AffineScope`.
1219template <typename ConcreteType>
1220class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1221public:
1222 static LogicalResult verifyTrait(Operation *op) {
1223 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1224 "expected operation to have one or more regions");
1225 return success();
1226 }
1227};
1228
1229/// A trait of region holding operations that define a new scope for automatic
1230/// allocations, i.e., allocations that are freed when control is transferred
1231/// back from the operation's region. Any operations performing such allocations
1232/// (for eg. memref.alloca) will have their allocations automatically freed at
1233/// their closest enclosing operation with this trait.
1234template <typename ConcreteType>
1235class AutomaticAllocationScope
1236 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1237public:
1238 static LogicalResult verifyTrait(Operation *op) {
1239 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1240 "expected operation to have one or more regions");
1241 return success();
1242 }
1243};
1244
1245/// This class provides a verifier for ops that are expecting their parent
1246/// to be one of the given parent ops
1247template <typename... ParentOpTypes>
1248struct HasParent {
1249 template <typename ConcreteType>
1250 class Impl : public TraitBase<ConcreteType, Impl> {
1251 public:
1252 static LogicalResult verifyTrait(Operation *op) {
1253 if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1254 return success();
1255
1256 return op->emitOpError()
1257 << "expects parent op "
1258 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1259 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1260 << "'";
1261 }
1262 };
1263};
1264
1265/// A trait for operations that have an attribute specifying operand segments.
1266///
1267/// Certain operations can have multiple variadic operands and their size
1268/// relationship is not always known statically. For such cases, we need
1269/// a per-op-instance specification to divide the operands into logical groups
1270/// or segments. This can be modeled by attributes. The attribute will be named
1271/// as `operand_segment_sizes`.
1272///
1273/// This trait verifies the attribute for specifying operand segments has
1274/// the correct type (1D vector) and values (non-negative), etc.
1275template <typename ConcreteType>
1276class AttrSizedOperandSegments
1277 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1278public:
1279 static StringRef getOperandSegmentSizeAttr() {
1280 return "operand_segment_sizes";
1281 }
1282
1283 static LogicalResult verifyTrait(Operation *op) {
1284 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1285 op, getOperandSegmentSizeAttr());
1286 }
1287};
1288
1289/// Similar to AttrSizedOperandSegments but used for results.
1290template <typename ConcreteType>
1291class AttrSizedResultSegments
1292 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1293public:
1294 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1295
1296 static LogicalResult verifyTrait(Operation *op) {
1297 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1298 op, getResultSegmentSizeAttr());
1299 }
1300};
1301
1302/// This trait provides a verifier for ops that are expecting their regions to
1303/// not have any arguments
1304template <typename ConcrentType>
1305struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
1306 static LogicalResult verifyTrait(Operation *op) {
1307 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1308 }
1309};
1310
1311// This trait is used to flag operations that consume or produce
1312// values of `MemRef` type where those references can be 'normalized'.
1313// TODO: Right now, the operands of an operation are either all normalizable,
1314// or not. In the future, we may want to allow some of the operands to be
1315// normalizable.
1316template <typename ConcrentType>
1317struct MemRefsNormalizable
1318 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1319
1320/// This trait tags element-wise ops on vectors or tensors.
1321///
1322/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1323/// trait. In particular, broadcasting behavior is not allowed.
1324///
1325/// An `Elementwise` op must satisfy the following properties:
1326///
1327/// 1. If any result is a vector/tensor then at least one operand must also be a
1328/// vector/tensor.
1329/// 2. If any operand is a vector/tensor then there must be at least one result
1330/// and all results must be vectors/tensors.
1331/// 3. All operand and result vector/tensor types must be of the same shape. The
1332/// shape may be dynamic in which case the op's behaviour is undefined for
1333/// non-matching shapes.
1334/// 4. The operation must be elementwise on its vector/tensor operands and
1335/// results. When applied to single-element vectors/tensors, the result must
1336/// be the same per elememnt.
1337///
1338/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1339/// interface `ElementwiseTypeInterface` that describes the container types for
1340/// which the operation is elementwise.
1341///
1342/// Rationale:
1343/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1344/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1345/// generic definition of the iteration space.
1346/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1347/// the same pattern, as otherwise lots of special handling for type
1348/// mismatches would be needed.
1349/// - 4. guarantees that no error handling is needed. Higher-level dialects
1350/// should reify any needed guards or error handling code before lowering to
1351/// an `Elementwise` op.
1352template <typename ConcreteType>
1353struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1354 static LogicalResult verifyTrait(Operation *op) {
1355 return ::mlir::OpTrait::impl::verifyElementwise(op);
1356 }
1357};
1358
1359/// This trait tags `Elementwise` operatons that can be systematically
1360/// scalarized. All vector/tensor operands and results are then replaced by
1361/// scalars of the respective element type. Semantically, this is the operation
1362/// on a single element of the vector/tensor.
1363///
1364/// Rationale:
1365/// Allow to define the vector/tensor semantics of elementwise operations based
1366/// on the same op's behavior on scalars. This provides a constructive procedure
1367/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1368///
1369/// Example:
1370/// ```
1371/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
1372/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1373/// -> tensor<?xf32>
1374/// ```
1375/// can be scalarized to
1376///
1377/// ```
1378/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1379/// : (i1, f32, f32) -> f32
1380/// ```
1381template <typename ConcreteType>
1382struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1383 static LogicalResult verifyTrait(Operation *op) {
1384 static_assert(
1385 ConcreteType::template hasTrait<Elementwise>(),
1386 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1387 return success();
1388 }
1389};
1390
1391/// This trait tags `Elementwise` operatons that can be systematically
1392/// vectorized. All scalar operands and results are then replaced by vectors
1393/// with the respective element type. Semantically, this is the operation on
1394/// multiple elements simultaneously. See also `Tensorizable`.
1395///
1396/// Rationale:
1397/// Provide the reverse to `Scalarizable` which, when chained together, allows
1398/// reasoning about the relationship between the tensor and vector case.
1399/// Additionally, it permits reasoning about promoting scalars to vectors via
1400/// broadcasting in cases like `%select_scalar_pred` below.
1401template <typename ConcreteType>
1402struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1403 static LogicalResult verifyTrait(Operation *op) {
1404 static_assert(
1405 ConcreteType::template hasTrait<Elementwise>(),
1406 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1407 return success();
1408 }
1409};
1410
1411/// This trait tags `Elementwise` operatons that can be systematically
1412/// tensorized. All scalar operands and results are then replaced by tensors
1413/// with the respective element type. Semantically, this is the operation on
1414/// multiple elements simultaneously. See also `Vectorizable`.
1415///
1416/// Rationale:
1417/// Provide the reverse to `Scalarizable` which, when chained together, allows
1418/// reasoning about the relationship between the tensor and vector case.
1419/// Additionally, it permits reasoning about promoting scalars to tensors via
1420/// broadcasting in cases like `%select_scalar_pred` below.
1421///
1422/// Examples:
1423/// ```
1424/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1425/// ```
1426/// can be tensorized to
1427/// ```
1428/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1429/// -> tensor<?xf32>
1430/// ```
1431///
1432/// ```
1433/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
1434/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1435/// ```
1436/// can be tensorized to
1437/// ```
1438/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
1439/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1440/// -> tensor<?xf32>
1441/// ```
1442template <typename ConcreteType>
1443struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1444 static LogicalResult verifyTrait(Operation *op) {
1445 static_assert(
1446 ConcreteType::template hasTrait<Elementwise>(),
1447 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1448 return success();
1449 }
1450};
1451
1452/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1453/// provide an easy way for scalar operations to conveniently generalize their
1454/// behavior to vectors/tensors, and systematize conversion between these forms.
1455bool hasElementwiseMappableTraits(Operation *op);
1456
1457} // namespace OpTrait
1458
1459//===----------------------------------------------------------------------===//
1460// Internal Trait Utilities
1461//===----------------------------------------------------------------------===//
1462
1463namespace op_definition_impl {
1464//===----------------------------------------------------------------------===//
1465// Trait Existence
1466
1467/// Returns true if this given Trait ID matches the IDs of any of the provided
1468/// trait types `Traits`.
1469template <template <typename T> class... Traits>
1470static bool hasTrait(TypeID traitID) {
1471 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1472 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1473 if (traitIDs[i] == traitID)
1474 return true;
1475 return false;
1476}
1477
1478//===----------------------------------------------------------------------===//
1479// Trait Folding
1480
1481/// Trait to check if T provides a 'foldTrait' method for single result
1482/// operations.
1483template <typename T, typename... Args>
1484using has_single_result_fold_trait = decltype(T::foldTrait(
1485 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1486template <typename T>
1487using detect_has_single_result_fold_trait =
1488 llvm::is_detected<has_single_result_fold_trait, T>;
1489/// Trait to check if T provides a general 'foldTrait' method.
1490template <typename T, typename... Args>
1491using has_fold_trait =
1492 decltype(T::foldTrait(std::declval<Operation *>(),
1493 std::declval<ArrayRef<Attribute>>(),
1494 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1495template <typename T>
1496using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1497/// Trait to check if T provides any `foldTrait` method.
1498/// NOTE: This should use std::disjunction when C++17 is available.
1499template <typename T>
1500using detect_has_any_fold_trait =
1501 std::conditional_t<bool(detect_has_fold_trait<T>::value),
1502 detect_has_fold_trait<T>,
1503 detect_has_single_result_fold_trait<T>>;
1504
1505/// Returns the result of folding a trait that implements a `foldTrait` function
1506/// that is specialized for operations that have a single result.
1507template <typename Trait>
1508static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1509 LogicalResult>
1510foldTrait(Operation *op, ArrayRef<Attribute> operands,
1511 SmallVectorImpl<OpFoldResult> &results) {
1512 assert(op->hasTrait<OpTrait::OneResult>() &&(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__
))
1513 "expected trait on non single-result operation to implement the "(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__
))
1514 "general `foldTrait` method")(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__
))
;
1515 // If a previous trait has already been folded and replaced this operation, we
1516 // fail to fold this trait.
1517 if (!results.empty())
1518 return failure();
1519
1520 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1521 if (result.template dyn_cast<Value>() != op->getResult(0))
1522 results.push_back(result);
1523 return success();
1524 }
1525 return failure();
1526}
1527/// Returns the result of folding a trait that implements a generalized
1528/// `foldTrait` function that is supports any operation type.
1529template <typename Trait>
1530static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1531foldTrait(Operation *op, ArrayRef<Attribute> operands,
1532 SmallVectorImpl<OpFoldResult> &results) {
1533 // If a previous trait has already been folded and replaced this operation, we
1534 // fail to fold this trait.
1535 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1536}
1537
1538/// The internal implementation of `foldTraits` below that returns the result of
1539/// folding a set of trait types `Ts` that implement a `foldTrait` method.
1540template <typename... Ts>
1541static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1542 SmallVectorImpl<OpFoldResult> &results,
1543 std::tuple<Ts...> *) {
1544 bool anyFolded = false;
1545 (void)std::initializer_list<int>{
1546 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1547 return success(anyFolded);
1548}
1549
1550/// Given a tuple type containing a set of traits that contain a `foldTrait`
1551/// method, return the result of folding the given operation.
1552template <typename TraitTupleT>
1553static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
1554foldTraits(Operation *op, ArrayRef<Attribute> operands,
1555 SmallVectorImpl<OpFoldResult> &results) {
1556 return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1557}
1558/// A variant of the method above that is specialized when there are no traits
1559/// that contain a `foldTrait` method.
1560template <typename TraitTupleT>
1561static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
1562foldTraits(Operation *op, ArrayRef<Attribute> operands,
1563 SmallVectorImpl<OpFoldResult> &results) {
1564 return failure();
1565}
1566
1567//===----------------------------------------------------------------------===//
1568// Trait Verification
1569
1570/// Trait to check if T provides a `verifyTrait` method.
1571template <typename T, typename... Args>
1572using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1573template <typename T>
1574using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1575
1576/// The internal implementation of `verifyTraits` below that returns the result
1577/// of verifying the current operation with all of the provided trait types
1578/// `Ts`.
1579template <typename... Ts>
1580static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1581 LogicalResult result = success();
1582 (void)std::initializer_list<int>{
1583 (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1584 return result;
1585}
1586
1587/// Given a tuple type containing a set of traits that contain a
1588/// `verifyTrait` method, return the result of verifying the given operation.
1589template <typename TraitTupleT>
1590static LogicalResult verifyTraits(Operation *op) {
1591 return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1592}
1593} // namespace op_definition_impl
1594
1595//===----------------------------------------------------------------------===//
1596// Operation Definition classes
1597//===----------------------------------------------------------------------===//
1598
1599/// This provides public APIs that all operations should have. The template
1600/// argument 'ConcreteType' should be the concrete type by CRTP and the others
1601/// are base classes by the policy pattern.
1602template <typename ConcreteType, template <typename T> class... Traits>
1603class Op : public OpState, public Traits<ConcreteType>... {
1604public:
1605 /// Inherit getOperation from `OpState`.
1606 using OpState::getOperation;
1607
1608 /// Return if this operation contains the provided trait.
1609 template <template <typename T> class Trait>
1610 static constexpr bool hasTrait() {
1611 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1612 }
1613
1614 /// Create a deep copy of this operation.
1615 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1616
1617 /// Create a partial copy of this operation without traversing into attached
1618 /// regions. The new operation will have the same number of regions as the
1619 /// original one, but they will be left empty.
1620 ConcreteType cloneWithoutRegions() {
1621 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1622 }
1623
1624 /// Return true if this "op class" can match against the specified operation.
1625 static bool classof(Operation *op) {
1626 if (auto info = op->getRegisteredInfo())
1627 return TypeID::get<ConcreteType>() == info->getTypeID();
1628#ifndef NDEBUG
1629 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1630 llvm::report_fatal_error(
1631 "classof on '" + ConcreteType::getOperationName() +
1632 "' failed due to the operation not being registered");
1633#endif
1634 return false;
1635 }
1636 /// Provide `classof` support for other OpBase derived classes, such as
1637 /// Interfaces.
1638 template <typename T>
1639 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1640 classof(const T *op) {
1641 return classof(const_cast<T *>(op)->getOperation());
1642 }
1643
1644 /// Expose the type we are instantiated on to template machinery that may want
1645 /// to introspect traits on this operation.
1646 using ConcreteOpType = ConcreteType;
1647
1648 /// This is a public constructor. Any op can be initialized to null.
1649 explicit Op() : OpState(nullptr) {}
1650 Op(std::nullptr_t) : OpState(nullptr) {}
1651
1652 /// This is a public constructor to enable access via the llvm::cast family of
1653 /// methods. This should not be used directly.
1654 explicit Op(Operation *state) : OpState(state) {}
1655
1656 /// Methods for supporting PointerLikeTypeTraits.
1657 const void *getAsOpaquePointer() const {
1658 return static_cast<const void *>((Operation *)*this);
1659 }
1660 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1661 return ConcreteOpType(
1662 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1663 }
1664
1665 /// Attach the given models as implementations of the corresponding interfaces
1666 /// for the concrete operation.
1667 template <typename... Models>
1668 static void attachInterface(MLIRContext &context) {
1669 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
1670 ConcreteType::getOperationName(), &context);
1671 if (!info)
1672 llvm::report_fatal_error(
1673 "Attempting to attach an interface to an unregistered operation " +
1674 ConcreteType::getOperationName() + ".");
1675 info->attachInterface<Models...>();
1676 }
1677
1678private:
1679 /// Trait to check if T provides a 'fold' method for a single result op.
1680 template <typename T, typename... Args>
1681 using has_single_result_fold =
1682 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1683 template <typename T>
1684 using detect_has_single_result_fold =
1685 llvm::is_detected<has_single_result_fold, T>;
1686 /// Trait to check if T provides a general 'fold' method.
1687 template <typename T, typename... Args>
1688 using has_fold = decltype(std::declval<T>().fold(
1689 std::declval<ArrayRef<Attribute>>(),
1690 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1691 template <typename T>
1692 using detect_has_fold = llvm::is_detected<has_fold, T>;
1693 /// Trait to check if T provides a 'print' method.
1694 template <typename T, typename... Args>
1695 using has_print =
1696 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1697 template <typename T>
1698 using detect_has_print = llvm::is_detected<has_print, T>;
1699 /// A tuple type containing the traits that have a `foldTrait` function.
1700 using FoldableTraitsTupleT = typename detail::FilterTypes<
1701 op_definition_impl::detect_has_any_fold_trait,
1702 Traits<ConcreteType>...>::type;
1703 /// A tuple type containing the traits that have a verify function.
1704 using VerifiableTraitsTupleT =
1705 typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1706 Traits<ConcreteType>...>::type;
1707
1708 /// Returns an interface map containing the interfaces registered to this
1709 /// operation.
1710 static detail::InterfaceMap getInterfaceMap() {
1711 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1712 }
1713
1714 /// Return the internal implementations of each of the OperationName
1715 /// hooks.
1716 /// Implementation of `FoldHookFn` OperationName hook.
1717 static OperationName::FoldHookFn getFoldHookFn() {
1718 return getFoldHookFnImpl<ConcreteType>();
1719 }
1720 /// The internal implementation of `getFoldHookFn` above that is invoked if
1721 /// the operation is single result and defines a `fold` method.
1722 template <typename ConcreteOpT>
1723 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1724 Traits<ConcreteOpT>...>::value &&
1725 detect_has_single_result_fold<ConcreteOpT>::value,
1726 OperationName::FoldHookFn>
1727 getFoldHookFnImpl() {
1728 return [](Operation *op, ArrayRef<Attribute> operands,
1729 SmallVectorImpl<OpFoldResult> &results) {
1730 return foldSingleResultHook<ConcreteOpT>(op, operands, results);
1731 };
1732 }
1733 /// The internal implementation of `getFoldHookFn` above that is invoked if
1734 /// the operation is not single result and defines a `fold` method.
1735 template <typename ConcreteOpT>
1736 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1737 Traits<ConcreteOpT>...>::value &&
1738 detect_has_fold<ConcreteOpT>::value,
1739 OperationName::FoldHookFn>
1740 getFoldHookFnImpl() {
1741 return [](Operation *op, ArrayRef<Attribute> operands,
1742 SmallVectorImpl<OpFoldResult> &results) {
1743 return foldHook<ConcreteOpT>(op, operands, results);
1744 };
1745 }
1746 /// The internal implementation of `getFoldHookFn` above that is invoked if
1747 /// the operation does not define a `fold` method.
1748 template <typename ConcreteOpT>
1749 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1750 !detect_has_fold<ConcreteOpT>::value,
1751 OperationName::FoldHookFn>
1752 getFoldHookFnImpl() {
1753 return [](Operation *op, ArrayRef<Attribute> operands,
1754 SmallVectorImpl<OpFoldResult> &results) {
1755 // In this case, we only need to fold the traits of the operation.
1756 return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
1757 results);
1758 };
1759 }
1760 /// Return the result of folding a single result operation that defines a
1761 /// `fold` method.
1762 template <typename ConcreteOpT>
1763 static LogicalResult
1764 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1765 SmallVectorImpl<OpFoldResult> &results) {
1766 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1767
1768 // If the fold failed or was in-place, try to fold the traits of the
1769 // operation.
1770 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1771 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1772 op, operands, results)))
1773 return success();
1774 return success(static_cast<bool>(result));
1775 }
1776 results.push_back(result);
1777 return success();
1778 }
1779 /// Return the result of folding an operation that defines a `fold` method.
1780 template <typename ConcreteOpT>
1781 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1782 SmallVectorImpl<OpFoldResult> &results) {
1783 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1784
1785 // If the fold failed or was in-place, try to fold the traits of the
1786 // operation.
1787 if (failed(result) || results.empty()) {
1788 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1789 op, operands, results)))
1790 return success();
1791 }
1792 return result;
1793 }
1794
1795 /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1796 static OperationName::GetCanonicalizationPatternsFn
1797 getGetCanonicalizationPatternsFn() {
1798 return &ConcreteType::getCanonicalizationPatterns;
1799 }
1800 /// Implementation of `GetHasTraitFn`
1801 static OperationName::HasTraitFn getHasTraitFn() {
1802 return
1803 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1804 }
1805 /// Implementation of `ParseAssemblyFn` OperationName hook.
1806 static OperationName::ParseAssemblyFn getParseAssemblyFn() {
1807 return &ConcreteType::parse;
1808 }
1809 /// Implementation of `PrintAssemblyFn` OperationName hook.
1810 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1811 return getPrintAssemblyFnImpl<ConcreteType>();
1812 }
1813 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1814 /// the concrete operation does not define a `print` method.
1815 template <typename ConcreteOpT>
1816 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1817 OperationName::PrintAssemblyFn>
1818 getPrintAssemblyFnImpl() {
1819 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1820 return OpState::print(op, printer, defaultDialect);
1821 };
1822 }
1823 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1824 /// the concrete operation defines a `print` method.
1825 template <typename ConcreteOpT>
1826 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1827 OperationName::PrintAssemblyFn>
1828 getPrintAssemblyFnImpl() {
1829 return &printAssembly;
1830 }
1831 static void printAssembly(Operation *op, OpAsmPrinter &p,
1832 StringRef defaultDialect) {
1833 OpState::printOpName(op, p, defaultDialect);
1834 return cast<ConcreteType>(op).print(p);
1835 }
1836 /// Implementation of `VerifyInvariantsFn` OperationName hook.
1837 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
1838 return &verifyInvariants;
1839 }
1840
1841 static constexpr bool hasNoDataMembers() {
1842 // Checking that the derived class does not define any member by comparing
1843 // its size to an ad-hoc EmptyOp.
1844 class EmptyOp : public Op<EmptyOp, Traits...> {};
1845 return sizeof(ConcreteType) == sizeof(EmptyOp);
1846 }
1847
1848 static LogicalResult verifyInvariants(Operation *op) {
1849 static_assert(hasNoDataMembers(),
1850 "Op class shouldn't define new data members");
1851 return failure(
1852 failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1853 failed(cast<ConcreteType>(op).verify()));
1854 }
1855
1856 /// Allow access to internal implementation methods.
1857 friend RegisteredOperationName;
1858};
1859
1860/// This class represents the base of an operation interface. See the definition
1861/// of `detail::Interface` for requirements on the `Traits` type.
1862template <typename ConcreteType, typename Traits>
1863class OpInterface
1864 : public detail::Interface<ConcreteType, Operation *, Traits,
1865 Op<ConcreteType>, OpTrait::TraitBase> {
1866public:
1867 using Base = OpInterface<ConcreteType, Traits>;
1868 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1869 Op<ConcreteType>, OpTrait::TraitBase>;
1870
1871 /// Inherit the base class constructor.
1872 using InterfaceBase::InterfaceBase;
1873
1874protected:
1875 /// Returns the impl interface instance for the given operation.
1876 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1877 OperationName name = op->getName();
1878
1879 // Access the raw interface from the operation info.
1880 if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1881 if (auto *opIface = rInfo->getInterface<ConcreteType>())
1882 return opIface;
1883 // Fallback to the dialect to provide it with a chance to implement this
1884 // interface for this operation.
1885 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
1886 op->getName());
1887 }
1888 // Fallback to the dialect to provide it with a chance to implement this
1889 // interface for this operation.
1890 if (Dialect *dialect = name.getDialect())
1891 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1892 return nullptr;
1893 }
1894
1895 /// Allow access to `getInterfaceFor`.
1896 friend InterfaceBase;
1897};
1898
1899//===----------------------------------------------------------------------===//
1900// Common Operation Folders/Parsers/Printers
1901//===----------------------------------------------------------------------===//
1902
1903// These functions are out-of-line implementations of the methods in UnaryOp and
1904// BinaryOp, which avoids them being template instantiated/duplicated.
1905namespace impl {
1906ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
1907 OperationState &result);
1908
1909void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1910 Value rhs);
1911ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1912 OperationState &result);
1913
1914// Prints the given binary `op` in custom assembly form if both the two operands
1915// and the result have the same time. Otherwise, prints the generic assembly
1916// form.
1917void printOneResultOp(Operation *op, OpAsmPrinter &p);
1918} // namespace impl
1919
1920// These functions are out-of-line implementations of the methods in
1921// CastOpInterface, which avoids them being template instantiated/duplicated.
1922namespace impl {
1923/// Attempt to fold the given cast operation.
1924LogicalResult foldCastInterfaceOp(Operation *op,
1925 ArrayRef<Attribute> attrOperands,
1926 SmallVectorImpl<OpFoldResult> &foldResults);
1927/// Attempt to verify the given cast operation.
1928LogicalResult verifyCastInterfaceOp(
1929 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1930
1931// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
1932// need for them, but some older ODS code in `std` still depends on them).
1933void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1934 Type destType);
1935ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
1936void printCastOp(Operation *op, OpAsmPrinter &p);
1937// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
1938// when all uses have been updated. Also, consider adding functionality to
1939// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
1940// generically.
1941Value foldCastOp(Operation *op);
1942LogicalResult verifyCastOp(Operation *op,
1943 function_ref<bool(Type, Type)> areCastCompatible);
1944} // namespace impl
1945} // namespace mlir
1946
1947namespace llvm {
1948
1949template <typename T>
1950struct DenseMapInfo<
1951 T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
1952 static inline T getEmptyKey() {
1953 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1954 return T::getFromOpaquePointer(pointer);
1955 }
1956 static inline T getTombstoneKey() {
1957 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1958 return T::getFromOpaquePointer(pointer);
1959 }
1960 static unsigned getHashValue(T val) {
1961 return hash_value(val.getAsOpaquePointer());
1962 }
1963 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
1964};
1965
1966} // namespace llvm
1967
1968#endif

/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/include/mlir/Support/LogicalResult.h

1//===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_SUPPORT_LOGICALRESULT_H
10#define MLIR_SUPPORT_LOGICALRESULT_H
11
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/Optional.h"
14
15namespace mlir {
16
17/// This class represents an efficient way to signal success or failure. It
18/// should be preferred over the use of `bool` when appropriate, as it avoids
19/// all of the ambiguity that arises in interpreting a boolean result. This
20/// class is marked as NODISCARD to ensure that the result is processed. Users
21/// may explicitly discard a result by using `(void)`, e.g.
22/// `(void)functionThatReturnsALogicalResult();`. Given the intended nature of
23/// this class, it generally shouldn't be used as the result of functions that
24/// very frequently have the result ignored. This class is intended to be used
25/// in conjunction with the utility functions below.
26struct LLVM_NODISCARD[[clang::warn_unused_result]] LogicalResult {
27public:
28 /// If isSuccess is true a `success` result is generated, otherwise a
29 /// 'failure' result is generated.
30 static LogicalResult success(bool isSuccess = true) {
31 return LogicalResult(isSuccess);
32 }
33
34 /// If isFailure is true a `failure` result is generated, otherwise a
35 /// 'success' result is generated.
36 static LogicalResult failure(bool isFailure = true) {
37 return success(!isFailure);
38 }
39
40 /// Returns true if the provided LogicalResult corresponds to a success value.
41 bool succeeded() const { return isSuccess; }
42
43 /// Returns true if the provided LogicalResult corresponds to a failure value.
44 bool failed() const { return !succeeded(); }
8
Assuming the condition is false
9
Returning zero, which participates in a condition later
24
Assuming the condition is false
25
Returning zero, which participates in a condition later
45
46private:
47 LogicalResult(bool isSuccess) : isSuccess(isSuccess) {}
48
49 /// Boolean indicating if this is a success result, if false this is a
50 /// failure result.
51 bool isSuccess;
52};
53
54/// Utility function to generate a LogicalResult. If isSuccess is true a
55/// `success` result is generated, otherwise a 'failure' result is generated.
56inline LogicalResult success(bool isSuccess = true) {
57 return LogicalResult::success(isSuccess);
58}
59
60/// Utility function to generate a LogicalResult. If isFailure is true a
61/// `failure` result is generated, otherwise a 'success' result is generated.
62inline LogicalResult failure(bool isFailure = true) {
63 return LogicalResult::failure(isFailure);
64}
65
66/// Utility function that returns true if the provided LogicalResult corresponds
67/// to a success value.
68inline bool succeeded(LogicalResult result) { return result.succeeded(); }
69
70/// Utility function that returns true if the provided LogicalResult corresponds
71/// to a failure value.
72inline bool failed(LogicalResult result) { return result.failed(); }
23
Calling 'LogicalResult::failed'
26
Returning from 'LogicalResult::failed'
27
Returning zero, which participates in a condition later
73
74/// This class provides support for representing a failure result, or a valid
75/// value of type `T`. This allows for integrating with LogicalResult, while
76/// also providing a value on the success path.
77template <typename T> class LLVM_NODISCARD[[clang::warn_unused_result]] FailureOr : public Optional<T> {
78public:
79 /// Allow constructing from a LogicalResult. The result *must* be a failure.
80 /// Success results should use a proper instance of type `T`.
81 FailureOr(LogicalResult result) {
82 assert(failed(result) &&(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'"
) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\""
, "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__
__PRETTY_FUNCTION__))
83 "success should be constructed with an instance of 'T'")(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'"
) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\""
, "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__
__PRETTY_FUNCTION__))
;
84 }
85 FailureOr() : FailureOr(failure()) {}
86 FailureOr(T &&y) : Optional<T>(std::forward<T>(y)) {}
87 FailureOr(const T &y) : Optional<T>(y) {}
88 template <typename U,
89 std::enable_if_t<std::is_constructible<T, U>::value> * = nullptr>
90 FailureOr(const FailureOr<U> &other)
91 : Optional<T>(failed(other) ? Optional<T>() : Optional<T>(*other)) {}
92
93 operator LogicalResult() const { return success(this->hasValue()); }
94
95private:
96 /// Hide the bool conversion as it easily creates confusion.
97 using Optional<T>::operator bool;
98 using Optional<T>::hasValue;
99};
100
101} // namespace mlir
102
103#endif // MLIR_SUPPORT_LOGICALRESULT_H

/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/include/mlir/IR/Types.h

1//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_TYPES_H
10#define MLIR_IR_TYPES_H
11
12#include "mlir/IR/TypeSupport.h"
13#include "llvm/ADT/ArrayRef.h"
14#include "llvm/ADT/DenseMapInfo.h"
15#include "llvm/Support/PointerLikeTypeTraits.h"
16
17namespace mlir {
18/// Instances of the Type class are uniqued, have an immutable identifier and an
19/// optional mutable component. They wrap a pointer to the storage object owned
20/// by MLIRContext. Therefore, instances of Type are passed around by value.
21///
22/// Some types are "primitives" meaning they do not have any parameters, for
23/// example the Index type. Parametric types have additional information that
24/// differentiates the types of the same class, for example the Integer type has
25/// bitwidth, making i8 and i16 belong to the same kind by be different
26/// instances of the IntegerType. Type parameters are part of the unique
27/// immutable key. The mutable component of the type can be modified after the
28/// type is created, but cannot affect the identity of the type.
29///
30/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
31///
32/// Derived type classes are expected to implement several required
33/// implementation hooks:
34/// * Optional:
35/// - static LogicalResult verify(
36/// function_ref<InFlightDiagnostic()> emitError,
37/// Args... args)
38/// * This method is invoked when calling the 'TypeBase::get/getChecked'
39/// methods to ensure that the arguments passed in are valid to construct
40/// a type instance with.
41/// * This method is expected to return failure if a type cannot be
42/// constructed with 'args', success otherwise.
43/// * 'args' must correspond with the arguments passed into the
44/// 'TypeBase::get' call.
45///
46///
47/// Type storage objects inherit from TypeStorage and contain the following:
48/// - The dialect that defined the type.
49/// - Any parameters of the type.
50/// - An optional mutable component.
51/// For non-parametric types, a convenience DefaultTypeStorage is provided.
52/// Parametric storage types must derive TypeStorage and respect the following:
53/// - Define a type alias, KeyTy, to a type that uniquely identifies the
54/// instance of the type.
55/// * The key type must be constructible from the values passed into the
56/// detail::TypeUniquer::get call.
57/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
58/// storage class must define a hashing method:
59/// 'static unsigned hashKey(const KeyTy &)'
60///
61/// - Provide a method, 'bool operator==(const KeyTy &) const', to
62/// compare the storage instance against an instance of the key type.
63///
64/// - Provide a static construction method:
65/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
66/// that builds a unique instance of the derived storage. The arguments to
67/// this function are an allocator to store any uniqued data within the
68/// context and the key type for this storage.
69///
70/// - If they have a mutable component, this component must not be a part of
71// the key.
72class Type {
73public:
74 /// Utility class for implementing types.
75 template <typename ConcreteType, typename BaseType, typename StorageType,
76 template <typename T> class... Traits>
77 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
78 detail::TypeUniquer, Traits...>;
79
80 using ImplType = TypeStorage;
81
82 using AbstractTy = AbstractType;
83
84 constexpr Type() {}
85 /* implicit */ Type(const ImplType *impl)
86 : impl(const_cast<ImplType *>(impl)) {}
87
88 Type(const Type &other) = default;
89 Type &operator=(const Type &other) = default;
90
91 bool operator==(Type other) const { return impl == other.impl; }
92 bool operator!=(Type other) const { return !(*this == other); }
93 explicit operator bool() const { return impl; }
94
95 bool operator!() const { return impl == nullptr; }
15
Assuming the condition is false
16
Returning zero, which participates in a condition later
96
97 template <typename U> bool isa() const;
98 template <typename First, typename Second, typename... Rest> bool isa() const;
99 template <typename U> U dyn_cast() const;
100 template <typename U> U dyn_cast_or_null() const;
101 template <typename U> U cast() const;
102
103 // Support type casting Type to itself.
104 static bool classof(Type) { return true; }
105
106 /// Return a unique identifier for the concrete type. This is used to support
107 /// dynamic type casting.
108 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
109
110 /// Return the MLIRContext in which this type was uniqued.
111 MLIRContext *getContext() const;
112
113 /// Get the dialect this type is registered to.
114 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
115
116 // Convenience predicates. This is only for floating point types,
117 // derived types should use isa/dyn_cast.
118 bool isIndex() const;
119 bool isBF16() const;
120 bool isF16() const;
121 bool isF32() const;
122 bool isF64() const;
123 bool isF80() const;
124 bool isF128() const;
125
126 /// Return true if this is an integer type with the specified width.
127 bool isInteger(unsigned width) const;
128 /// Return true if this is a signless integer type (with the specified width).
129 bool isSignlessInteger() const;
130 bool isSignlessInteger(unsigned width) const;
131 /// Return true if this is a signed integer type (with the specified width).
132 bool isSignedInteger() const;
133 bool isSignedInteger(unsigned width) const;
134 /// Return true if this is an unsigned integer type (with the specified
135 /// width).
136 bool isUnsignedInteger() const;
137 bool isUnsignedInteger(unsigned width) const;
138
139 /// Return the bit width of an integer or a float type, assert failure on
140 /// other types.
141 unsigned getIntOrFloatBitWidth() const;
142
143 /// Return true if this is a signless integer or index type.
144 bool isSignlessIntOrIndex() const;
145 /// Return true if this is a signless integer, index, or float type.
146 bool isSignlessIntOrIndexOrFloat() const;
147 /// Return true of this is a signless integer or a float type.
148 bool isSignlessIntOrFloat() const;
149
150 /// Return true if this is an integer (of any signedness) or an index type.
151 bool isIntOrIndex() const;
152 /// Return true if this is an integer (of any signedness) or a float type.
153 bool isIntOrFloat() const;
154 /// Return true if this is an integer (of any signedness), index, or float
155 /// type.
156 bool isIntOrIndexOrFloat() const;
157
158 /// Print the current type.
159 void print(raw_ostream &os) const;
160 void dump() const;
161
162 friend ::llvm::hash_code hash_value(Type arg);
163
164 /// Methods for supporting PointerLikeTypeTraits.
165 const void *getAsOpaquePointer() const {
166 return static_cast<const void *>(impl);
167 }
168 static Type getFromOpaquePointer(const void *pointer) {
169 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
170 }
171
172 /// Returns true if the type was registered with a particular trait.
173 template <template <typename T> class Trait>
174 bool hasTrait() {
175 return getAbstractType().hasTrait<Trait>();
176 }
177
178 /// Return the abstract type descriptor for this type.
179 const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
180
181protected:
182 ImplType *impl{nullptr};
183};
184
185inline raw_ostream &operator<<(raw_ostream &os, Type type) {
186 type.print(os);
187 return os;
188}
189
190//===----------------------------------------------------------------------===//
191// TypeTraitBase
192//===----------------------------------------------------------------------===//
193
194namespace TypeTrait {
195/// This class represents the base of a type trait.
196template <typename ConcreteType, template <typename> class TraitType>
197using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
198} // namespace TypeTrait
199
200//===----------------------------------------------------------------------===//
201// TypeInterface
202//===----------------------------------------------------------------------===//
203
204/// This class represents the base of a type interface. See the definition of
205/// `detail::Interface` for requirements on the `Traits` type.
206template <typename ConcreteType, typename Traits>
207class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
208 TypeTrait::TraitBase> {
209public:
210 using Base = TypeInterface<ConcreteType, Traits>;
211 using InterfaceBase =
212 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
213 using InterfaceBase::InterfaceBase;
214
215private:
216 /// Returns the impl interface instance for the given type.
217 static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
218 return type.getAbstractType().getInterface<ConcreteType>();
219 }
220
221 /// Allow access to 'getInterfaceFor'.
222 friend InterfaceBase;
223};
224
225//===----------------------------------------------------------------------===//
226// Type Utils
227//===----------------------------------------------------------------------===//
228
229// Make Type hashable.
230inline ::llvm::hash_code hash_value(Type arg) {
231 return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
232}
233
234template <typename U> bool Type::isa() const {
235 assert(impl && "isa<> used on a null type.")(static_cast <bool> (impl && "isa<> used on a null type."
) ? void (0) : __assert_fail ("impl && \"isa<> used on a null type.\""
, "mlir/include/mlir/IR/Types.h", 235, __extension__ __PRETTY_FUNCTION__
))
;
236 return U::classof(*this);
237}
238
239template <typename First, typename Second, typename... Rest>
240bool Type::isa() const {
241 return isa<First>() || isa<Second, Rest...>();
242}
243
244template <typename U> U Type::dyn_cast() const {
245 return isa<U>() ? U(impl) : U(nullptr);
246}
247template <typename U> U Type::dyn_cast_or_null() const {
248 return (impl && isa<U>()) ? U(impl) : U(nullptr);
249}
250template <typename U> U Type::cast() const {
251 assert(isa<U>())(static_cast <bool> (isa<U>()) ? void (0) : __assert_fail
("isa<U>()", "mlir/include/mlir/IR/Types.h", 251, __extension__
__PRETTY_FUNCTION__))
;
252 return U(impl);
253}
254
255} // namespace mlir
256
257namespace llvm {
258
259// Type hash just like pointers.
260template <> struct DenseMapInfo<mlir::Type> {
261 static mlir::Type getEmptyKey() {
262 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
263 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
264 }
265 static mlir::Type getTombstoneKey() {
266 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
267 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
268 }
269 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
270 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
271};
272template <typename T>
273struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value>>
274 : public DenseMapInfo<mlir::Type> {
275 static T getEmptyKey() {
276 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
277 return T::getFromOpaquePointer(pointer);
278 }
279 static T getTombstoneKey() {
280 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
281 return T::getFromOpaquePointer(pointer);
282 }
283};
284
285/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
286template <> struct PointerLikeTypeTraits<mlir::Type> {
287public:
288 static inline void *getAsVoidPointer(mlir::Type I) {
289 return const_cast<void *>(I.getAsOpaquePointer());
290 }
291 static inline mlir::Type getFromVoidPointer(void *P) {
292 return mlir::Type::getFromOpaquePointer(P);
293 }
294 static constexpr int NumLowBitsAvailable = 3;
295};
296
297} // namespace llvm
298
299#endif // MLIR_IR_TYPES_H

/build/llvm-toolchain-snapshot-14~++20220125101009+ceec4383681c/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23
24class Builder;
25
26//===----------------------------------------------------------------------===//
27// AsmPrinter
28//===----------------------------------------------------------------------===//
29
30/// This base class exposes generic asm printer hooks, usable across the various
31/// derived printers.
32class AsmPrinter {
33public:
34 /// This class contains the internal default implementation of the base
35 /// printer methods.
36 class Impl;
37
38 /// Initialize the printer with the given internal implementation.
39 AsmPrinter(Impl &impl) : impl(&impl) {}
40 virtual ~AsmPrinter();
41
42 /// Return the raw output stream used by this printer.
43 virtual raw_ostream &getStream() const;
44
45 /// Print the given floating point value in a stabilized form that can be
46 /// roundtripped through the IR. This is the companion to the 'parseFloat'
47 /// hook on the AsmParser.
48 virtual void printFloat(const APFloat &value);
49
50 virtual void printType(Type type);
51 virtual void printAttribute(Attribute attr);
52
53 /// Trait to check if `AttrType` provides a `print` method.
54 template <typename AttrOrType>
55 using has_print_method =
56 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
57 template <typename AttrOrType>
58 using detect_has_print_method =
59 llvm::is_detected<has_print_method, AttrOrType>;
60
61 /// Print the provided attribute in the context of an operation custom
62 /// printer/parser: this will invoke directly the print method on the
63 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
64 template <typename AttrOrType,
65 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
66 *sfinae = nullptr>
67 void printStrippedAttrOrType(AttrOrType attrOrType) {
68 if (succeeded(printAlias(attrOrType)))
69 return;
70 attrOrType.print(*this);
71 }
72
73 /// SFINAE for printing the provided attribute in the context of an operation
74 /// custom printer in the case where the attribute does not define a print
75 /// method.
76 template <typename AttrOrType,
77 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
78 *sfinae = nullptr>
79 void printStrippedAttrOrType(AttrOrType attrOrType) {
80 *this << attrOrType;
81 }
82
83 /// Print the given attribute without its type. The corresponding parser must
84 /// provide a valid type for the attribute.
85 virtual void printAttributeWithoutType(Attribute attr);
86
87 /// Print the given string as a keyword, or a quoted and escaped string if it
88 /// has any special or non-printable characters in it.
89 virtual void printKeywordOrString(StringRef keyword);
90
91 /// Print the given string as a symbol reference, i.e. a form representable by
92 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
93 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
94 /// special or non-printable characters in it.
95 virtual void printSymbolName(StringRef symbolRef);
96
97 /// Print an optional arrow followed by a type list.
98 template <typename TypeRange>
99 void printOptionalArrowTypeList(TypeRange &&types) {
100 if (types.begin() != types.end())
101 printArrowTypeList(types);
102 }
103 template <typename TypeRange>
104 void printArrowTypeList(TypeRange &&types) {
105 auto &os = getStream() << " -> ";
106
107 bool wrapped = !llvm::hasSingleElement(types) ||
108 (*types.begin()).template isa<FunctionType>();
109 if (wrapped)
110 os << '(';
111 llvm::interleaveComma(types, *this);
112 if (wrapped)
113 os << ')';
114 }
115
116 /// Print the two given type ranges in a functional form.
117 template <typename InputRangeT, typename ResultRangeT>
118 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
119 auto &os = getStream();
120 os << '(';
121 llvm::interleaveComma(inputs, *this);
122 os << ')';
123 printArrowTypeList(results);
124 }
125
126protected:
127 /// Initialize the printer with no internal implementation. In this case, all
128 /// virtual methods of this class must be overriden.
129 AsmPrinter() {}
130
131private:
132 AsmPrinter(const AsmPrinter &) = delete;
133 void operator=(const AsmPrinter &) = delete;
134
135 /// Print the alias for the given attribute, return failure if no alias could
136 /// be printed.
137 virtual LogicalResult printAlias(Attribute attr);
138
139 /// Print the alias for the given type, return failure if no alias could
140 /// be printed.
141 virtual LogicalResult printAlias(Type type);
142
143 /// The internal implementation of the printer.
144 Impl *impl{nullptr};
145};
146
147template <typename AsmPrinterT>
148inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
149 AsmPrinterT &>
150operator<<(AsmPrinterT &p, Type type) {
151 p.printType(type);
152 return p;
153}
154
155template <typename AsmPrinterT>
156inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
157 AsmPrinterT &>
158operator<<(AsmPrinterT &p, Attribute attr) {
159 p.printAttribute(attr);
160 return p;
161}
162
163template <typename AsmPrinterT>
164inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
165 AsmPrinterT &>
166operator<<(AsmPrinterT &p, const APFloat &value) {
167 p.printFloat(value);
168 return p;
169}
170template <typename AsmPrinterT>
171inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
172 AsmPrinterT &>
173operator<<(AsmPrinterT &p, float value) {
174 return p << APFloat(value);
175}
176template <typename AsmPrinterT>
177inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
178 AsmPrinterT &>
179operator<<(AsmPrinterT &p, double value) {
180 return p << APFloat(value);
181}
182
183// Support printing anything that isn't convertible to one of the other
184// streamable types, even if it isn't exactly one of them. For example, we want
185// to print FunctionType with the Type version above, not have it match this.
186template <
187 typename AsmPrinterT, typename T,
188 typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
189 !std::is_convertible<T &, Type &>::value &&
190 !std::is_convertible<T &, Attribute &>::value &&
191 !std::is_convertible<T &, ValueRange>::value &&
192 !std::is_convertible<T &, APFloat &>::value &&
193 !llvm::is_one_of<T, bool, float, double>::value,
194 T>::type * = nullptr>
195inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
196 AsmPrinterT &>
197operator<<(AsmPrinterT &p, const T &other) {
198 p.getStream() << other;
199 return p;
200}
201
202template <typename AsmPrinterT>
203inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
204 AsmPrinterT &>
205operator<<(AsmPrinterT &p, bool value) {
206 return p << (value ? StringRef("true") : "false");
207}
208
209template <typename AsmPrinterT, typename ValueRangeT>
210inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
211 AsmPrinterT &>
212operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
213 llvm::interleaveComma(types, p);
214 return p;
215}
216template <typename AsmPrinterT>
217inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
218 AsmPrinterT &>
219operator<<(AsmPrinterT &p, const TypeRange &types) {
220 llvm::interleaveComma(types, p);
221 return p;
222}
223template <typename AsmPrinterT, typename ElementT>
224inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
225 AsmPrinterT &>
226operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
227 llvm::interleaveComma(types, p);
228 return p;
229}
230
231//===----------------------------------------------------------------------===//
232// OpAsmPrinter
233//===----------------------------------------------------------------------===//
234
235/// This is a pure-virtual base class that exposes the asmprinter hooks
236/// necessary to implement a custom print() method.
237class OpAsmPrinter : public AsmPrinter {
238public:
239 using AsmPrinter::AsmPrinter;
240 ~OpAsmPrinter() override;
241
242 /// Print a newline and indent the printer to the start of the current
243 /// operation.
244 virtual void printNewline() = 0;
245
246 /// Print a block argument in the usual format of:
247 /// %ssaName : type {attr1=42} loc("here")
248 /// where location printing is controlled by the standard internal option.
249 /// You may pass omitType=true to not print a type, and pass an empty
250 /// attribute list if you don't care for attributes.
251 virtual void printRegionArgument(BlockArgument arg,
252 ArrayRef<NamedAttribute> argAttrs = {},
253 bool omitType = false) = 0;
254
255 /// Print implementations for various things an operation contains.
256 virtual void printOperand(Value value) = 0;
257 virtual void printOperand(Value value, raw_ostream &os) = 0;
258
259 /// Print a comma separated list of operands.
260 template <typename ContainerType>
261 void printOperands(const ContainerType &container) {
262 printOperands(container.begin(), container.end());
263 }
264
265 /// Print a comma separated list of operands.
266 template <typename IteratorType>
267 void printOperands(IteratorType it, IteratorType end) {
268 if (it == end)
269 return;
270 printOperand(*it);
271 for (++it; it != end; ++it) {
272 getStream() << ", ";
273 printOperand(*it);
274 }
275 }
276
277 /// Print the given successor.
278 virtual void printSuccessor(Block *successor) = 0;
279
280 /// Print the successor and its operands.
281 virtual void printSuccessorAndUseList(Block *successor,
282 ValueRange succOperands) = 0;
283
284 /// If the specified operation has attributes, print out an attribute
285 /// dictionary with their values. elidedAttrs allows the client to ignore
286 /// specific well known attributes, commonly used if the attribute value is
287 /// printed some other way (like as a fixed operand).
288 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
289 ArrayRef<StringRef> elidedAttrs = {}) = 0;
290
291 /// If the specified operation has attributes, print out an attribute
292 /// dictionary prefixed with 'attributes'.
293 virtual void
294 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
295 ArrayRef<StringRef> elidedAttrs = {}) = 0;
296
297 /// Print the entire operation with the default generic assembly form.
298 /// If `printOpName` is true, then the operation name is printed (the default)
299 /// otherwise it is omitted and the print will start with the operand list.
300 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
301
302 /// Prints a region.
303 /// If 'printEntryBlockArgs' is false, the arguments of the
304 /// block are not printed. If 'printBlockTerminator' is false, the terminator
305 /// operation of the block is not printed. If printEmptyBlock is true, then
306 /// the block header is printed even if the block is empty.
307 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
308 bool printBlockTerminators = true,
309 bool printEmptyBlock = false) = 0;
310
311 /// Renumber the arguments for the specified region to the same names as the
312 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
313 /// operations. If any entry in namesToUse is null, the corresponding
314 /// argument name is left alone.
315 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
316
317 /// Prints an affine map of SSA ids, where SSA id names are used in place
318 /// of dims/symbols.
319 /// Operand values must come from single-result sources, and be valid
320 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
321 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
322 ValueRange operands) = 0;
323
324 /// Prints an affine expression of SSA ids with SSA id names used instead of
325 /// dims and symbols.
326 /// Operand values must come from single-result sources, and be valid
327 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
328 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
329 ValueRange symOperands) = 0;
330
331 /// Print the complete type of an operation in functional form.
332 void printFunctionalType(Operation *op);
333 using AsmPrinter::printFunctionalType;
334};
335
336// Make the implementations convenient to use.
337inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
338 p.printOperand(value);
339 return p;
340}
341
342template <typename T,
343 typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
344 !std::is_convertible<T &, Value &>::value,
345 T>::type * = nullptr>
346inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
347 p.printOperands(values);
348 return p;
349}
350
351inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
352 p.printSuccessor(value);
353 return p;
354}
355
356//===----------------------------------------------------------------------===//
357// AsmParser
358//===----------------------------------------------------------------------===//
359
360/// This base class exposes generic asm parser hooks, usable across the various
361/// derived parsers.
362class AsmParser {
363public:
364 AsmParser() = default;
365 virtual ~AsmParser();
366
367 MLIRContext *getContext() const;
368
369 /// Return the location of the original name token.
370 virtual llvm::SMLoc getNameLoc() const = 0;
371
372 //===--------------------------------------------------------------------===//
373 // Utilities
374 //===--------------------------------------------------------------------===//
375
376 /// Emit a diagnostic at the specified location and return failure.
377 virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
378 const Twine &message = {}) = 0;
379
380 /// Return a builder which provides useful access to MLIRContext, global
381 /// objects like types and attributes.
382 virtual Builder &getBuilder() const = 0;
383
384 /// Get the location of the next token and store it into the argument. This
385 /// always succeeds.
386 virtual llvm::SMLoc getCurrentLocation() = 0;
387 ParseResult getCurrentLocation(llvm::SMLoc *loc) {
388 *loc = getCurrentLocation();
389 return success();
390 }
391
392 /// Re-encode the given source location as an MLIR location and return it.
393 /// Note: This method should only be used when a `Location` is necessary, as
394 /// the encoding process is not efficient.
395 virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
396
397 //===--------------------------------------------------------------------===//
398 // Token Parsing
399 //===--------------------------------------------------------------------===//
400
401 /// Parse a '->' token.
402 virtual ParseResult parseArrow() = 0;
403
404 /// Parse a '->' token if present
405 virtual ParseResult parseOptionalArrow() = 0;
406
407 /// Parse a `{` token.
408 virtual ParseResult parseLBrace() = 0;
409
410 /// Parse a `{` token if present.
411 virtual ParseResult parseOptionalLBrace() = 0;
412
413 /// Parse a `}` token.
414 virtual ParseResult parseRBrace() = 0;
415
416 /// Parse a `}` token if present.
417 virtual ParseResult parseOptionalRBrace() = 0;
418
419 /// Parse a `:` token.
420 virtual ParseResult parseColon() = 0;
421
422 /// Parse a `:` token if present.
423 virtual ParseResult parseOptionalColon() = 0;
424
425 /// Parse a `,` token.
426 virtual ParseResult parseComma() = 0;
427
428 /// Parse a `,` token if present.
429 virtual ParseResult parseOptionalComma() = 0;
430
431 /// Parse a `=` token.
432 virtual ParseResult parseEqual() = 0;
433
434 /// Parse a `=` token if present.
435 virtual ParseResult parseOptionalEqual() = 0;
436
437 /// Parse a '<' token.
438 virtual ParseResult parseLess() = 0;
439
440 /// Parse a '<' token if present.
441 virtual ParseResult parseOptionalLess() = 0;
442
443 /// Parse a '>' token.
444 virtual ParseResult parseGreater() = 0;
445
446 /// Parse a '>' token if present.
447 virtual ParseResult parseOptionalGreater() = 0;
448
449 /// Parse a '?' token.
450 virtual ParseResult parseQuestion() = 0;
451
452 /// Parse a '?' token if present.
453 virtual ParseResult parseOptionalQuestion() = 0;
454
455 /// Parse a '+' token.
456 virtual ParseResult parsePlus() = 0;
457
458 /// Parse a '+' token if present.
459 virtual ParseResult parseOptionalPlus() = 0;
460
461 /// Parse a '*' token.
462 virtual ParseResult parseStar() = 0;
463
464 /// Parse a '*' token if present.
465 virtual ParseResult parseOptionalStar() = 0;
466
467 /// Parse a quoted string token.
468 ParseResult parseString(std::string *string) {
469 auto loc = getCurrentLocation();
470 if (parseOptionalString(string))
471 return emitError(loc, "expected string");
472 return success();
473 }
474
475 /// Parse a quoted string token if present.
476 virtual ParseResult parseOptionalString(std::string *string) = 0;
477
478 /// Parse a given keyword.
479 ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
480 auto loc = getCurrentLocation();
481 if (parseOptionalKeyword(keyword))
482 return emitError(loc, "expected '") << keyword << "'" << msg;
483 return success();
484 }
485
486 /// Parse a keyword into 'keyword'.
487 ParseResult parseKeyword(StringRef *keyword) {
488 auto loc = getCurrentLocation();
489 if (parseOptionalKeyword(keyword))
490 return emitError(loc, "expected valid keyword");
491 return success();
492 }
493
494 /// Parse the given keyword if present.
495 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
496
497 /// Parse a keyword, if present, into 'keyword'.
498 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
499
500 /// Parse a keyword, if present, and if one of the 'allowedValues',
501 /// into 'keyword'
502 virtual ParseResult
503 parseOptionalKeyword(StringRef *keyword,
504 ArrayRef<StringRef> allowedValues) = 0;
505
506 /// Parse a keyword or a quoted string.
507 ParseResult parseKeywordOrString(std::string *result) {
508 if (failed(parseOptionalKeywordOrString(result)))
509 return emitError(getCurrentLocation())
510 << "expected valid keyword or string";
511 return success();
512 }
513
514 /// Parse an optional keyword or string.
515 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
516
517 /// Parse a `(` token.
518 virtual ParseResult parseLParen() = 0;
519
520 /// Parse a `(` token if present.
521 virtual ParseResult parseOptionalLParen() = 0;
522
523 /// Parse a `)` token.
524 virtual ParseResult parseRParen() = 0;
525
526 /// Parse a `)` token if present.
527 virtual ParseResult parseOptionalRParen() = 0;
528
529 /// Parse a `[` token.
530 virtual ParseResult parseLSquare() = 0;
531
532 /// Parse a `[` token if present.
533 virtual ParseResult parseOptionalLSquare() = 0;
534
535 /// Parse a `]` token.
536 virtual ParseResult parseRSquare() = 0;
537
538 /// Parse a `]` token if present.
539 virtual ParseResult parseOptionalRSquare() = 0;
540
541 /// Parse a `...` token if present;
542 virtual ParseResult parseOptionalEllipsis() = 0;
543
544 /// Parse a floating point value from the stream.
545 virtual ParseResult parseFloat(double &result) = 0;
546
547 /// Parse an integer value from the stream.
548 template <typename IntT>
549 ParseResult parseInteger(IntT &result) {
550 auto loc = getCurrentLocation();
551 OptionalParseResult parseResult = parseOptionalInteger(result);
31
Calling 'AsmParser::parseOptionalInteger'
35
Returning from 'AsmParser::parseOptionalInteger'
552 if (!parseResult.hasValue())
36
Taking false branch
553 return emitError(loc, "expected integer value");
554 return *parseResult;
37
Returning without writing to 'result'
555 }
556
557 /// Parse an optional integer value from the stream.
558 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
559
560 template <typename IntT>
561 OptionalParseResult parseOptionalInteger(IntT &result) {
562 auto loc = getCurrentLocation();
563
564 // Parse the unsigned variant.
565 APInt uintResult;
566 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
567 if (!parseResult.hasValue() || failed(*parseResult))
32
Assuming the condition is false
33
Taking true branch
568 return parseResult;
34
Returning without writing to 'result'
569
570 // Try to convert to the provided integer type. sextOrTrunc is correct even
571 // for unsigned types because parseOptionalInteger ensures the sign bit is
572 // zero for non-negated integers.
573 result =
574 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
575 if (APInt(uintResult.getBitWidth(), result) != uintResult)
576 return emitError(loc, "integer value too large");
577 return success();
578 }
579
580 /// These are the supported delimiters around operand lists and region
581 /// argument lists, used by parseOperandList and parseRegionArgumentList.
582 enum class Delimiter {
583 /// Zero or more operands with no delimiters.
584 None,
585 /// Parens surrounding zero or more operands.
586 Paren,
587 /// Square brackets surrounding zero or more operands.
588 Square,
589 /// <> brackets surrounding zero or more operands.
590 LessGreater,
591 /// {} brackets surrounding zero or more operands.
592 Braces,
593 /// Parens supporting zero or more operands, or nothing.
594 OptionalParen,
595 /// Square brackets supporting zero or more ops, or nothing.
596 OptionalSquare,
597 /// <> brackets supporting zero or more ops, or nothing.
598 OptionalLessGreater,
599 /// {} brackets surrounding zero or more operands, or nothing.
600 OptionalBraces,
601 };
602
603 /// Parse a list of comma-separated items with an optional delimiter. If a
604 /// delimiter is provided, then an empty list is allowed. If not, then at
605 /// least one element will be parsed.
606 ///
607 /// contextMessage is an optional message appended to "expected '('" sorts of
608 /// diagnostics when parsing the delimeters.
609 virtual ParseResult
610 parseCommaSeparatedList(Delimiter delimiter,
611 function_ref<ParseResult()> parseElementFn,
612 StringRef contextMessage = StringRef()) = 0;
613
614 /// Parse a comma separated list of elements that must have at least one entry
615 /// in it.
616 ParseResult
617 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
618 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
619 }
620
621 //===--------------------------------------------------------------------===//
622 // Attribute/Type Parsing
623 //===--------------------------------------------------------------------===//
624
625 /// Invoke the `getChecked` method of the given Attribute or Type class, using
626 /// the provided location to emit errors in the case of failure. Note that
627 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
628 /// context parameter.
629 template <typename T, typename... ParamsT>
630 T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
631 return T::getChecked([&] { return emitError(loc); },
632 std::forward<ParamsT>(params)...);
633 }
634 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
635 /// errors.
636 template <typename T, typename... ParamsT>
637 T getChecked(ParamsT &&... params) {
638 return T::getChecked([&] { return emitError(getNameLoc()); },
639 std::forward<ParamsT>(params)...);
640 }
641
642 //===--------------------------------------------------------------------===//
643 // Attribute Parsing
644 //===--------------------------------------------------------------------===//
645
646 /// Parse an arbitrary attribute of a given type and return it in result.
647 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
648
649 /// Parse a custom attribute with the provided callback, unless the next
650 /// token is `#`, in which case the generic parser is invoked.
651 virtual ParseResult parseCustomAttributeWithFallback(
652 Attribute &result, Type type,
653 function_ref<ParseResult(Attribute &result, Type type)>
654 parseAttribute) = 0;
655
656 /// Parse an attribute of a specific kind and type.
657 template <typename AttrType>
658 ParseResult parseAttribute(AttrType &result, Type type = {}) {
659 llvm::SMLoc loc = getCurrentLocation();
660
661 // Parse any kind of attribute.
662 Attribute attr;
663 if (parseAttribute(attr, type))
664 return failure();
665
666 // Check for the right kind of attribute.
667 if (!(result = attr.dyn_cast<AttrType>()))
668 return emitError(loc, "invalid kind of attribute specified");
669
670 return success();
671 }
672
673 /// Parse an arbitrary attribute and return it in result. This also adds the
674 /// attribute to the specified attribute list with the specified name.
675 ParseResult parseAttribute(Attribute &result, StringRef attrName,
676 NamedAttrList &attrs) {
677 return parseAttribute(result, Type(), attrName, attrs);
678 }
679
680 /// Parse an attribute of a specific kind and type.
681 template <typename AttrType>
682 ParseResult parseAttribute(AttrType &result, StringRef attrName,
683 NamedAttrList &attrs) {
684 return parseAttribute(result, Type(), attrName, attrs);
685 }
686
687 /// Parse an arbitrary attribute of a given type and populate it in `result`.
688 /// This also adds the attribute to the specified attribute list with the
689 /// specified name.
690 template <typename AttrType>
691 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
692 NamedAttrList &attrs) {
693 llvm::SMLoc loc = getCurrentLocation();
694
695 // Parse any kind of attribute.
696 Attribute attr;
697 if (parseAttribute(attr, type))
698 return failure();
699
700 // Check for the right kind of attribute.
701 result = attr.dyn_cast<AttrType>();
702 if (!result)
703 return emitError(loc, "invalid kind of attribute specified");
704
705 attrs.append(attrName, result);
706 return success();
707 }
708
709 /// Trait to check if `AttrType` provides a `parse` method.
710 template <typename AttrType>
711 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
712 std::declval<Type>()));
713 template <typename AttrType>
714 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
715
716 /// Parse a custom attribute of a given type unless the next token is `#`, in
717 /// which case the generic parser is invoked. The parsed attribute is
718 /// populated in `result` and also added to the specified attribute list with
719 /// the specified name.
720 template <typename AttrType>
721 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
722 parseCustomAttributeWithFallback(AttrType &result, Type type,
723 StringRef attrName, NamedAttrList &attrs) {
724 llvm::SMLoc loc = getCurrentLocation();
725
726 // Parse any kind of attribute.
727 Attribute attr;
728 if (parseCustomAttributeWithFallback(
729 attr, type, [&](Attribute &result, Type type) -> ParseResult {
730 result = AttrType::parse(*this, type);
731 if (!result)
732 return failure();
733 return success();
734 }))
735 return failure();
736
737 // Check for the right kind of attribute.
738 result = attr.dyn_cast<AttrType>();
739 if (!result)
740 return emitError(loc, "invalid kind of attribute specified");
741
742 attrs.append(attrName, result);
743 return success();
744 }
745
746 /// SFINAE parsing method for Attribute that don't implement a parse method.
747 template <typename AttrType>
748 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
749 parseCustomAttributeWithFallback(AttrType &result, Type type,
750 StringRef attrName, NamedAttrList &attrs) {
751 return parseAttribute(result, type, attrName, attrs);
752 }
753
754 /// Parse a custom attribute of a given type unless the next token is `#`, in
755 /// which case the generic parser is invoked. The parsed attribute is
756 /// populated in `result`.
757 template <typename AttrType>
758 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
759 parseCustomAttributeWithFallback(AttrType &result) {
760 llvm::SMLoc loc = getCurrentLocation();
761
762 // Parse any kind of attribute.
763 Attribute attr;
764 if (parseCustomAttributeWithFallback(
765 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
766 result = AttrType::parse(*this, type);
767 return success(!!result);
768 }))
769 return failure();
770
771 // Check for the right kind of attribute.
772 result = attr.dyn_cast<AttrType>();
773 if (!result)
774 return emitError(loc, "invalid kind of attribute specified");
775 return success();
776 }
777
778 /// SFINAE parsing method for Attribute that don't implement a parse method.
779 template <typename AttrType>
780 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
781 parseCustomAttributeWithFallback(AttrType &result) {
782 return parseAttribute(result);
783 }
784
785 /// Parse an arbitrary optional attribute of a given type and return it in
786 /// result.
787 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
788 Type type = {}) = 0;
789
790 /// Parse an optional array attribute and return it in result.
791 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
792 Type type = {}) = 0;
793
794 /// Parse an optional string attribute and return it in result.
795 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
796 Type type = {}) = 0;
797
798 /// Parse an optional attribute of a specific type and add it to the list with
799 /// the specified name.
800 template <typename AttrType>
801 OptionalParseResult parseOptionalAttribute(AttrType &result,
802 StringRef attrName,
803 NamedAttrList &attrs) {
804 return parseOptionalAttribute(result, Type(), attrName, attrs);
805 }
806
807 /// Parse an optional attribute of a specific type and add it to the list with
808 /// the specified name.
809 template <typename AttrType>
810 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
811 StringRef attrName,
812 NamedAttrList &attrs) {
813 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
814 if (parseResult.hasValue() && succeeded(*parseResult))
815 attrs.append(attrName, result);
816 return parseResult;
817 }
818
819 /// Parse a named dictionary into 'result' if it is present.
820 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
821
822 /// Parse a named dictionary into 'result' if the `attributes` keyword is
823 /// present.
824 virtual ParseResult
825 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
826
827 /// Parse an affine map instance into 'map'.
828 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
829
830 /// Parse an integer set instance into 'set'.
831 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
832
833 //===--------------------------------------------------------------------===//
834 // Identifier Parsing
835 //===--------------------------------------------------------------------===//
836
837 /// Parse an @-identifier and store it (without the '@' symbol) in a string
838 /// attribute named 'attrName'.
839 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
840 NamedAttrList &attrs) {
841 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
842 return emitError(getCurrentLocation())
843 << "expected valid '@'-identifier for symbol name";
844 return success();
845 }
846
847 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
848 /// string attribute named 'attrName'.
849 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
850 StringRef attrName,
851 NamedAttrList &attrs) = 0;
852
853 //===--------------------------------------------------------------------===//
854 // Type Parsing
855 //===--------------------------------------------------------------------===//
856
857 /// Parse a type.
858 virtual ParseResult parseType(Type &result) = 0;
859
860 /// Parse a custom type with the provided callback, unless the next
861 /// token is `#`, in which case the generic parser is invoked.
862 virtual ParseResult parseCustomTypeWithFallback(
863 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
864
865 /// Parse an optional type.
866 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
867
868 /// Parse a type of a specific type.
869 template <typename TypeT>
870 ParseResult parseType(TypeT &result) {
871 llvm::SMLoc loc = getCurrentLocation();
872
873 // Parse any kind of type.
874 Type type;
875 if (parseType(type))
876 return failure();
877
878 // Check for the right kind of type.
879 result = type.dyn_cast<TypeT>();
880 if (!result)
881 return emitError(loc, "invalid kind of type specified");
882
883 return success();
884 }
885
886 /// Trait to check if `TypeT` provides a `parse` method.
887 template <typename TypeT>
888 using type_has_parse_method =
889 decltype(TypeT::parse(std::declval<AsmParser &>()));
890 template <typename TypeT>
891 using detect_type_has_parse_method =
892 llvm::is_detected<type_has_parse_method, TypeT>;
893
894 /// Parse a custom Type of a given type unless the next token is `#`, in
895 /// which case the generic parser is invoked. The parsed Type is
896 /// populated in `result`.
897 template <typename TypeT>
898 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
899 parseCustomTypeWithFallback(TypeT &result) {
900 llvm::SMLoc loc = getCurrentLocation();
901
902 // Parse any kind of Type.
903 Type type;
904 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
905 result = TypeT::parse(*this);
906 return success(!!result);
907 }))
908 return failure();
909
910 // Check for the right kind of Type.
911 result = type.dyn_cast<TypeT>();
912 if (!result)
913 return emitError(loc, "invalid kind of Type specified");
914 return success();
915 }
916
917 /// SFINAE parsing method for Type that don't implement a parse method.
918 template <typename TypeT>
919 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
920 parseCustomTypeWithFallback(TypeT &result) {
921 return parseType(result);
922 }
923
924 /// Parse a type list.
925 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
926 do {
927 Type type;
928 if (parseType(type))
929 return failure();
930 result.push_back(type);
931 } while (succeeded(parseOptionalComma()));
932 return success();
933 }
934
935 /// Parse an arrow followed by a type list.
936 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
937
938 /// Parse an optional arrow followed by a type list.
939 virtual ParseResult
940 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
941
942 /// Parse a colon followed by a type.
943 virtual ParseResult parseColonType(Type &result) = 0;
944
945 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
946 template <typename TypeType>
947 ParseResult parseColonType(TypeType &result) {
948 llvm::SMLoc loc = getCurrentLocation();
949
950 // Parse any kind of type.
951 Type type;
952 if (parseColonType(type))
953 return failure();
954
955 // Check for the right kind of type.
956 result = type.dyn_cast<TypeType>();
957 if (!result)
958 return emitError(loc, "invalid kind of type specified");
959
960 return success();
961 }
962
963 /// Parse a colon followed by a type list, which must have at least one type.
964 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
965
966 /// Parse an optional colon followed by a type list, which if present must
967 /// have at least one type.
968 virtual ParseResult
969 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
970
971 /// Parse a keyword followed by a type.
972 ParseResult parseKeywordType(const char *keyword, Type &result) {
973 return failure(parseKeyword(keyword) || parseType(result));
974 }
975
976 /// Add the specified type to the end of the specified type list and return
977 /// success. This is a helper designed to allow parse methods to be simple
978 /// and chain through || operators.
979 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
980 result.push_back(type);
981 return success();
982 }
983
984 /// Add the specified types to the end of the specified type list and return
985 /// success. This is a helper designed to allow parse methods to be simple
986 /// and chain through || operators.
987 ParseResult addTypesToList(ArrayRef<Type> types,
988 SmallVectorImpl<Type> &result) {
989 result.append(types.begin(), types.end());
990 return success();
991 }
992
993 /// Parse a 'x' separated dimension list. This populates the dimension list,
994 /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
995 /// `?` otherwise.
996 ///
997 /// dimension-list ::= (dimension `x`)*
998 /// dimension ::= `?` | integer
999 ///
1000 /// When `allowDynamic` is not set, this is used to parse:
1001 ///
1002 /// static-dimension-list ::= (integer `x`)*
1003 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1004 bool allowDynamic = true) = 0;
1005
1006 /// Parse an 'x' token in a dimension list, handling the case where the x is
1007 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1008 /// next token.
1009 virtual ParseResult parseXInDimensionList() = 0;
1010
1011private:
1012 AsmParser(const AsmParser &) = delete;
1013 void operator=(const AsmParser &) = delete;
1014};
1015
1016//===----------------------------------------------------------------------===//
1017// OpAsmParser
1018//===----------------------------------------------------------------------===//
1019
1020/// The OpAsmParser has methods for interacting with the asm parser: parsing
1021/// things from it, emitting errors etc. It has an intentionally high-level API
1022/// that is designed to reduce/constrain syntax innovation in individual
1023/// operations.
1024///
1025/// For example, consider an op like this:
1026///
1027/// %x = load %p[%1, %2] : memref<...>
1028///
1029/// The "%x = load" tokens are already parsed and therefore invisible to the
1030/// custom op parser. This can be supported by calling `parseOperandList` to
1031/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1032/// parse the indices, then calling `parseColonTypeList` to parse the result
1033/// type.
1034///
1035class OpAsmParser : public AsmParser {
1036public:
1037 using AsmParser::AsmParser;
1038 ~OpAsmParser() override;
1039
1040 /// Parse a loc(...) specifier if present, filling in result if so.
1041 /// Location for BlockArgument and Operation may be deferred with an alias, in
1042 /// which case an OpaqueLoc is set and will be resolved when parsing
1043 /// completes.
1044 virtual ParseResult
1045 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1046
1047 /// Return the name of the specified result in the specified syntax, as well
1048 /// as the sub-element in the name. It returns an empty string and ~0U for
1049 /// invalid result numbers. For example, in this operation:
1050 ///
1051 /// %x, %y:2, %z = foo.op
1052 ///
1053 /// getResultName(0) == {"x", 0 }
1054 /// getResultName(1) == {"y", 0 }
1055 /// getResultName(2) == {"y", 1 }
1056 /// getResultName(3) == {"z", 0 }
1057 /// getResultName(4) == {"", ~0U }
1058 virtual std::pair<StringRef, unsigned>
1059 getResultName(unsigned resultNo) const = 0;
1060
1061 /// Return the number of declared SSA results. This returns 4 for the foo.op
1062 /// example in the comment for `getResultName`.
1063 virtual size_t getNumResults() const = 0;
1064
1065 // These methods emit an error and return failure or success. This allows
1066 // these to be chained together into a linear sequence of || expressions in
1067 // many cases.
1068
1069 /// Parse an operation in its generic form.
1070 /// The parsed operation is parsed in the current context and inserted in the
1071 /// provided block and insertion point. The results produced by this operation
1072 /// aren't mapped to any named value in the parser. Returns nullptr on
1073 /// failure.
1074 virtual Operation *parseGenericOperation(Block *insertBlock,
1075 Block::iterator insertPt) = 0;
1076
1077 /// Parse the name of an operation, in the custom form. On success, return a
1078 /// an object of type 'OperationName'. Otherwise, failure is returned.
1079 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1080
1081 //===--------------------------------------------------------------------===//
1082 // Operand Parsing
1083 //===--------------------------------------------------------------------===//
1084
1085 /// This is the representation of an operand reference.
1086 struct OperandType {
1087 llvm::SMLoc location; // Location of the token.
1088 StringRef name; // Value name, e.g. %42 or %abc
1089 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1090 };
1091
1092 /// Parse different components, viz., use-info of operand(s), successor(s),
1093 /// region(s), attribute(s) and function-type, of the generic form of an
1094 /// operation instance and populate the input operation-state 'result' with
1095 /// those components. If any of the components is explicitly provided, then
1096 /// skip parsing that component.
1097 virtual ParseResult parseGenericOperationAfterOpName(
1098 OperationState &result,
1099 Optional<ArrayRef<OperandType>> parsedOperandType = llvm::None,
1100 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1101 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1102 llvm::None,
1103 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1104 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1105
1106 /// Parse a single operand.
1107 virtual ParseResult parseOperand(OperandType &result) = 0;
1108
1109 /// Parse a single operand if present.
1110 virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
1111
1112 /// Parse zero or more SSA comma-separated operand references with a specified
1113 /// surrounding delimiter, and an optional required operand count.
1114 virtual ParseResult
1115 parseOperandList(SmallVectorImpl<OperandType> &result,
1116 int requiredOperandCount = -1,
1117 Delimiter delimiter = Delimiter::None) = 0;
1118 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
1119 Delimiter delimiter) {
1120 return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
1121 }
1122
1123 /// Parse zero or more trailing SSA comma-separated trailing operand
1124 /// references with a specified surrounding delimiter, and an optional
1125 /// required operand count. A leading comma is expected before the operands.
1126 virtual ParseResult
1127 parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
1128 int requiredOperandCount = -1,
1129 Delimiter delimiter = Delimiter::None) = 0;
1130 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
1131 Delimiter delimiter) {
1132 return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
1133 delimiter);
1134 }
1135
1136 /// Resolve an operand to an SSA value, emitting an error on failure.
1137 virtual ParseResult resolveOperand(const OperandType &operand, Type type,
1138 SmallVectorImpl<Value> &result) = 0;
1139
1140 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1141 /// appending the results to the list on success. This method should be used
1142 /// when all operands have the same type.
1143 ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
1144 SmallVectorImpl<Value> &result) {
1145 for (auto elt : operands)
1146 if (resolveOperand(elt, type, result))
1147 return failure();
1148 return success();
1149 }
1150
1151 /// Resolve a list of operands and a list of operand types to SSA values,
1152 /// emitting an error and returning failure, or appending the results
1153 /// to the list on success.
1154 ParseResult resolveOperands(ArrayRef<OperandType> operands,
1155 ArrayRef<Type> types, llvm::SMLoc loc,
1156 SmallVectorImpl<Value> &result) {
1157 if (operands.size() != types.size())
1158 return emitError(loc)
1159 << operands.size() << " operands present, but expected "
1160 << types.size();
1161
1162 for (unsigned i = 0, e = operands.size(); i != e; ++i)
1163 if (resolveOperand(operands[i], types[i], result))
1164 return failure();
1165 return success();
1166 }
1167 template <typename Operands>
1168 ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
1169 SmallVectorImpl<Value> &result) {
1170 return resolveOperands(std::forward<Operands>(operands),
1171 ArrayRef<Type>(type), loc, result);
1172 }
1173 template <typename Operands, typename Types>
1174 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1175 resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc,
1176 SmallVectorImpl<Value> &result) {
1177 size_t operandSize = std::distance(operands.begin(), operands.end());
1178 size_t typeSize = std::distance(types.begin(), types.end());
1179 if (operandSize != typeSize)
1180 return emitError(loc)
1181 << operandSize << " operands present, but expected " << typeSize;
1182
1183 for (auto it : llvm::zip(operands, types))
1184 if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
1185 return failure();
1186 return success();
1187 }
1188
1189 /// Parses an affine map attribute where dims and symbols are SSA operands.
1190 /// Operand values must come from single-result sources, and be valid
1191 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1192 virtual ParseResult
1193 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
1194 StringRef attrName, NamedAttrList &attrs,
1195 Delimiter delimiter = Delimiter::Square) = 0;
1196
1197 /// Parses an affine expression where dims and symbols are SSA operands.
1198 /// Operand values must come from single-result sources, and be valid
1199 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1200 virtual ParseResult
1201 parseAffineExprOfSSAIds(SmallVectorImpl<OperandType> &dimOperands,
1202 SmallVectorImpl<OperandType> &symbOperands,
1203 AffineExpr &expr) = 0;
1204
1205 //===--------------------------------------------------------------------===//
1206 // Region Parsing
1207 //===--------------------------------------------------------------------===//
1208
1209 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1210 /// moved to the op regions after the op is created. The first block of the
1211 /// region takes 'arguments' of types 'argTypes'. If `argLocations` is
1212 /// non-empty it contains a location to be attached to each argument. If
1213 /// 'enableNameShadowing' is set to true, the argument names are allowed to
1214 /// shadow the names of other existing SSA values defined above the region
1215 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1216 /// to operations that are 'IsolatedFromAbove'.
1217 virtual ParseResult parseRegion(Region &region,
1218 ArrayRef<OperandType> arguments = {},
1219 ArrayRef<Type> argTypes = {},
1220 ArrayRef<Location> argLocations = {},
1221 bool enableNameShadowing = false) = 0;
1222
1223 /// Parses a region if present.
1224 virtual OptionalParseResult
1225 parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments = {},
1226 ArrayRef<Type> argTypes = {},
1227 ArrayRef<Location> argLocations = {},
1228 bool enableNameShadowing = false) = 0;
1229
1230 /// Parses a region if present. If the region is present, a new region is
1231 /// allocated and placed in `region`. If no region is present or on failure,
1232 /// `region` remains untouched.
1233 virtual OptionalParseResult parseOptionalRegion(
1234 std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
1235 ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
1236
1237 /// Parse a region argument, this argument is resolved when calling
1238 /// 'parseRegion'.
1239 virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
1240
1241 /// Parse zero or more region arguments with a specified surrounding
1242 /// delimiter, and an optional required argument count. Region arguments
1243 /// define new values; so this also checks if values with the same names have
1244 /// not been defined yet.
1245 virtual ParseResult
1246 parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
1247 int requiredOperandCount = -1,
1248 Delimiter delimiter = Delimiter::None) = 0;
1249 virtual ParseResult
1250 parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
1251 Delimiter delimiter) {
1252 return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
1253 delimiter);
1254 }
1255
1256 /// Parse a region argument if present.
1257 virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
1258
1259 //===--------------------------------------------------------------------===//
1260 // Successor Parsing
1261 //===--------------------------------------------------------------------===//
1262
1263 /// Parse a single operation successor.
1264 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1265
1266 /// Parse an optional operation successor.
1267 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1268
1269 /// Parse a single operation successor and its operand list.
1270 virtual ParseResult
1271 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1272
1273 //===--------------------------------------------------------------------===//
1274 // Type Parsing
1275 //===--------------------------------------------------------------------===//
1276
1277 /// Parse a list of assignments of the form
1278 /// (%x1 = %y1, %x2 = %y2, ...)
1279 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
1280 SmallVectorImpl<OperandType> &rhs) {
1281 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1282 if (!result.hasValue())
1283 return emitError(getCurrentLocation(), "expected '('");
1284 return result.getValue();
1285 }
1286
1287 virtual OptionalParseResult
1288 parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
1289 SmallVectorImpl<OperandType> &rhs) = 0;
1290
1291 /// Parse a list of assignments of the form
1292 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
1293 ParseResult parseAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
1294 SmallVectorImpl<OperandType> &rhs,
1295 SmallVectorImpl<Type> &types) {
1296 OptionalParseResult result =
1297 parseOptionalAssignmentListWithTypes(lhs, rhs, types);
1298 if (!result.hasValue())
1299 return emitError(getCurrentLocation(), "expected '('");
1300 return result.getValue();
1301 }
1302
1303 virtual OptionalParseResult
1304 parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
1305 SmallVectorImpl<OperandType> &rhs,
1306 SmallVectorImpl<Type> &types) = 0;
1307
1308private:
1309 /// Parse either an operand list or a region argument list depending on
1310 /// whether isOperandList is true.
1311 ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
1312 bool isOperandList,
1313 int requiredOperandCount,
1314 Delimiter delimiter);
1315};
1316
1317//===--------------------------------------------------------------------===//
1318// Dialect OpAsm interface.
1319//===--------------------------------------------------------------------===//
1320
1321/// A functor used to set the name of the start of a result group of an
1322/// operation. See 'getAsmResultNames' below for more details.
1323using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1324
1325class OpAsmDialectInterface
1326 : public DialectInterface::Base<OpAsmDialectInterface> {
1327public:
1328 /// Holds the result of `getAlias` hook call.
1329 enum class AliasResult {
1330 /// The object (type or attribute) is not supported by the hook
1331 /// and an alias was not provided.
1332 NoAlias,
1333 /// An alias was provided, but it might be overriden by other hook.
1334 OverridableAlias,
1335 /// An alias was provided and it should be used
1336 /// (no other hooks will be checked).
1337 FinalAlias
1338 };
1339
1340 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1341
1342 /// Hooks for getting an alias identifier alias for a given symbol, that is
1343 /// not necessarily a part of this dialect. The identifier is used in place of
1344 /// the symbol when printing textual IR. These aliases must not contain `.` or
1345 /// end with a numeric digit([0-9]+).
1346 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1347 return AliasResult::NoAlias;
1348 }
1349 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1350 return AliasResult::NoAlias;
1351 }
1352
1353};
1354} // namespace mlir
1355
1356//===--------------------------------------------------------------------===//
1357// Operation OpAsm interface.
1358//===--------------------------------------------------------------------===//
1359
1360/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1361#include "mlir/IR/OpAsmInterface.h.inc"
1362
1363#endif