Bug Summary

File:mlir/lib/Dialect/Quant/IR/TypeParser.cpp
Warning: line 86, column 22
The left operand of '<' is a garbage value (`storageTypeMin` may be read uninitialized)

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name TypeParser.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Quant -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Quant -I include -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/llvm/include -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-01-19-134126-35450-1 -x c++ /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Quant/IR/TypeParser.cpp

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Quant/IR/TypeParser.cpp

1//===- TypeParser.cpp - Quantization Type Parser ----------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Quant/QuantOps.h"
10#include "mlir/Dialect/Quant/QuantTypes.h"
11#include "mlir/IR/BuiltinTypes.h"
12#include "mlir/IR/DialectImplementation.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Types.h"
15#include "llvm/ADT/APFloat.h"
16#include "llvm/ADT/StringSwitch.h"
17#include "llvm/Support/Format.h"
18#include "llvm/Support/MathExtras.h"
19#include "llvm/Support/SourceMgr.h"
20#include "llvm/Support/raw_ostream.h"
21
22using namespace mlir;
23using namespace quant;
24
25static IntegerType parseStorageType(DialectAsmParser &parser, bool &isSigned) {
26 auto typeLoc = parser.getCurrentLocation();
27 IntegerType type;
28
29 // Parse storage type (alpha_ident, integer_literal).
30 StringRef identifier;
31 unsigned storageTypeWidth = 0;
32 OptionalParseResult result = parser.parseOptionalType(type);
33 if (result.hasValue()) {
34 if (!succeeded(*result)) {
35 parser.parseType(type);
36 return nullptr;
37 }
38 isSigned = !type.isUnsigned();
39 storageTypeWidth = type.getWidth();
40 } else if (succeeded(parser.parseKeyword(&identifier))) {
41 // Otherwise, this must be an unsigned integer (`u` integer-literal).
42 if (!identifier.consume_front("u")) {
43 parser.emitError(typeLoc, "illegal storage type prefix");
44 return nullptr;
45 }
46 if (identifier.getAsInteger(10, storageTypeWidth)) {
47 parser.emitError(typeLoc, "expected storage type width");
48 return nullptr;
49 }
50 isSigned = false;
51 type = parser.getBuilder().getIntegerType(storageTypeWidth);
52 } else {
53 return nullptr;
54 }
55
56 if (storageTypeWidth == 0 ||
57 storageTypeWidth > QuantizedType::MaxStorageBits) {
58 parser.emitError(typeLoc, "illegal storage type size: ")
59 << storageTypeWidth;
60 return nullptr;
61 }
62
63 return type;
64}
65
66static ParseResult parseStorageRange(DialectAsmParser &parser,
67 IntegerType storageType, bool isSigned,
68 int64_t &storageTypeMin,
69 int64_t &storageTypeMax) {
70 int64_t defaultIntegerMin = QuantizedType::getDefaultMinimumForInteger(
71 isSigned, storageType.getWidth());
72 int64_t defaultIntegerMax = QuantizedType::getDefaultMaximumForInteger(
73 isSigned, storageType.getWidth());
74 if (failed(parser.parseOptionalLess())) {
22
Calling 'failed'
28
Returning from 'failed'
29
Taking false branch
75 storageTypeMin = defaultIntegerMin;
76 storageTypeMax = defaultIntegerMax;
77 return success();
78 }
79
80 // Explicit storage min and storage max.
81 llvm::SMLoc minLoc = parser.getCurrentLocation(), maxLoc;
82 if (parser.parseInteger(storageTypeMin) || parser.parseColon() ||
30
Calling 'AsmParser::parseInteger'
38
Returning from 'AsmParser::parseInteger'
39
Taking false branch
83 parser.getCurrentLocation(&maxLoc) ||
84 parser.parseInteger(storageTypeMax) || parser.parseGreater())
85 return failure();
86 if (storageTypeMin < defaultIntegerMin) {
40
The left operand of '<' is a garbage value
87 return parser.emitError(minLoc, "illegal storage type minimum: ")
88 << storageTypeMin;
89 }
90 if (storageTypeMax > defaultIntegerMax) {
91 return parser.emitError(maxLoc, "illegal storage type maximum: ")
92 << storageTypeMax;
93 }
94 return success();
95}
96
97static FloatType parseExpressedTypeAndRange(DialectAsmParser &parser,
98 double &min, double &max) {
99 auto typeLoc = parser.getCurrentLocation();
100 FloatType type;
101
102 if (failed(parser.parseType(type))) {
103 parser.emitError(typeLoc, "expecting float expressed type");
104 return nullptr;
105 }
106
107 // Calibrated min and max values.
108 if (parser.parseLess() || parser.parseFloat(min) || parser.parseColon() ||
109 parser.parseFloat(max) || parser.parseGreater()) {
110 parser.emitError(typeLoc, "calibrated values must be present");
111 return nullptr;
112 }
113 return type;
114}
115
116/// Parses an AnyQuantizedType.
117///
118/// any ::= `any<` storage-spec (expressed-type-spec)?`>`
119/// storage-spec ::= storage-type (`<` storage-range `>`)?
120/// storage-range ::= integer-literal `:` integer-literal
121/// storage-type ::= (`i` | `u`) integer-literal
122/// expressed-type-spec ::= `:` `f` integer-literal
123static Type parseAnyType(DialectAsmParser &parser) {
124 IntegerType storageType;
125 FloatType expressedType;
126 unsigned typeFlags = 0;
127 int64_t storageTypeMin;
128 int64_t storageTypeMax;
129
130 // Type specification.
131 if (parser.parseLess())
132 return nullptr;
133
134 // Storage type.
135 bool isSigned = false;
136 storageType = parseStorageType(parser, isSigned);
137 if (!storageType) {
138 return nullptr;
139 }
140 if (isSigned) {
141 typeFlags |= QuantizationFlags::Signed;
142 }
143
144 // Storage type range.
145 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
146 storageTypeMax)) {
147 return nullptr;
148 }
149
150 // Optional expressed type.
151 if (succeeded(parser.parseOptionalColon())) {
152 if (parser.parseType(expressedType)) {
153 return nullptr;
154 }
155 }
156
157 if (parser.parseGreater()) {
158 return nullptr;
159 }
160
161 return parser.getChecked<AnyQuantizedType>(
162 typeFlags, storageType, expressedType, storageTypeMin, storageTypeMax);
163}
164
165static ParseResult parseQuantParams(DialectAsmParser &parser, double &scale,
166 int64_t &zeroPoint) {
167 // scale[:zeroPoint]?
168 // scale.
169 if (parser.parseFloat(scale))
170 return failure();
171
172 // zero point.
173 zeroPoint = 0;
174 if (failed(parser.parseOptionalColon())) {
175 // Default zero point.
176 return success();
177 }
178
179 return parser.parseInteger(zeroPoint);
180}
181
182/// Parses a UniformQuantizedType.
183///
184/// uniform_type ::= uniform_per_layer
185/// | uniform_per_axis
186/// uniform_per_layer ::= `uniform<` storage-spec expressed-type-spec
187/// `,` scale-zero `>`
188/// uniform_per_axis ::= `uniform<` storage-spec expressed-type-spec
189/// axis-spec `,` scale-zero-list `>`
190/// storage-spec ::= storage-type (`<` storage-range `>`)?
191/// storage-range ::= integer-literal `:` integer-literal
192/// storage-type ::= (`i` | `u`) integer-literal
193/// expressed-type-spec ::= `:` `f` integer-literal
194/// axis-spec ::= `:` integer-literal
195/// scale-zero ::= float-literal `:` integer-literal
196/// scale-zero-list ::= `{` scale-zero (`,` scale-zero)* `}`
197static Type parseUniformType(DialectAsmParser &parser) {
198 IntegerType storageType;
199 FloatType expressedType;
200 unsigned typeFlags = 0;
201 int64_t storageTypeMin;
5
'storageTypeMin' declared without an initial value
202 int64_t storageTypeMax;
203 bool isPerAxis = false;
204 int32_t quantizedDimension;
205 SmallVector<double, 1> scales;
206 SmallVector<int64_t, 1> zeroPoints;
207
208 // Type specification.
209 if (parser.parseLess()) {
6
Calling 'ParseResult::operator bool'
12
Returning from 'ParseResult::operator bool'
13
Taking false branch
210 return nullptr;
211 }
212
213 // Storage type.
214 bool isSigned = false;
215 storageType = parseStorageType(parser, isSigned);
216 if (!storageType) {
14
Calling 'Type::operator!'
17
Returning from 'Type::operator!'
18
Taking false branch
217 return nullptr;
218 }
219 if (isSigned
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
18.1
'isSigned' is false
) {
19
Taking false branch
220 typeFlags |= QuantizationFlags::Signed;
221 }
222
223 // Storage type range.
224 if (parseStorageRange(parser, storageType, isSigned, storageTypeMin,
20
Passing value via 4th parameter 'storageTypeMin'
21
Calling 'parseStorageRange'
225 storageTypeMax)) {
226 return nullptr;
227 }
228
229 // Expressed type.
230 if (parser.parseColon() || parser.parseType(expressedType)) {
231 return nullptr;
232 }
233
234 // Optionally parse quantized dimension for per-axis quantization.
235 if (succeeded(parser.parseOptionalColon())) {
236 if (parser.parseInteger(quantizedDimension))
237 return nullptr;
238 isPerAxis = true;
239 }
240
241 // Comma leading into range_spec.
242 if (parser.parseComma()) {
243 return nullptr;
244 }
245
246 // Parameter specification.
247 // For per-axis, ranges are in a {} delimitted list.
248 if (isPerAxis) {
249 if (parser.parseLBrace()) {
250 return nullptr;
251 }
252 }
253
254 // Parse scales/zeroPoints.
255 llvm::SMLoc scaleZPLoc = parser.getCurrentLocation();
256 do {
257 scales.resize(scales.size() + 1);
258 zeroPoints.resize(zeroPoints.size() + 1);
259 if (parseQuantParams(parser, scales.back(), zeroPoints.back())) {
260 return nullptr;
261 }
262 } while (isPerAxis && succeeded(parser.parseOptionalComma()));
263
264 if (isPerAxis) {
265 if (parser.parseRBrace()) {
266 return nullptr;
267 }
268 }
269
270 if (parser.parseGreater()) {
271 return nullptr;
272 }
273
274 if (!isPerAxis && scales.size() > 1) {
275 return (parser.emitError(scaleZPLoc,
276 "multiple scales/zeroPoints provided, but "
277 "quantizedDimension wasn't specified"),
278 nullptr);
279 }
280
281 if (isPerAxis) {
282 ArrayRef<double> scalesRef(scales.begin(), scales.end());
283 ArrayRef<int64_t> zeroPointsRef(zeroPoints.begin(), zeroPoints.end());
284 return parser.getChecked<UniformQuantizedPerAxisType>(
285 typeFlags, storageType, expressedType, scalesRef, zeroPointsRef,
286 quantizedDimension, storageTypeMin, storageTypeMax);
287 }
288
289 return parser.getChecked<UniformQuantizedType>(
290 typeFlags, storageType, expressedType, scales.front(), zeroPoints.front(),
291 storageTypeMin, storageTypeMax);
292}
293
294/// Parses an CalibratedQuantizedType.
295///
296/// calibrated ::= `calibrated<` expressed-spec `>`
297/// expressed-spec ::= expressed-type `<` calibrated-range `>`
298/// expressed-type ::= `f` integer-literal
299/// calibrated-range ::= float-literal `:` float-literal
300static Type parseCalibratedType(DialectAsmParser &parser) {
301 FloatType expressedType;
302 double min;
303 double max;
304
305 // Type specification.
306 if (parser.parseLess())
307 return nullptr;
308
309 // Expressed type.
310 expressedType = parseExpressedTypeAndRange(parser, min, max);
311 if (!expressedType) {
312 return nullptr;
313 }
314
315 if (parser.parseGreater()) {
316 return nullptr;
317 }
318
319 return parser.getChecked<CalibratedQuantizedType>(expressedType, min, max);
320}
321
322/// Parse a type registered to this dialect.
323Type QuantizationDialect::parseType(DialectAsmParser &parser) const {
324 // All types start with an identifier that we switch on.
325 StringRef typeNameSpelling;
326 if (failed(parser.parseKeyword(&typeNameSpelling)))
1
Taking false branch
327 return nullptr;
328
329 if (typeNameSpelling == "uniform")
2
Assuming the condition is true
3
Taking true branch
330 return parseUniformType(parser);
4
Calling 'parseUniformType'
331 if (typeNameSpelling == "any")
332 return parseAnyType(parser);
333 if (typeNameSpelling == "calibrated")
334 return parseCalibratedType(parser);
335
336 parser.emitError(parser.getNameLoc(),
337 "unknown quantized type " + typeNameSpelling);
338 return nullptr;
339}
340
341static void printStorageType(QuantizedType type, DialectAsmPrinter &out) {
342 // storage type
343 unsigned storageWidth = type.getStorageTypeIntegralWidth();
344 bool isSigned = type.isSigned();
345 if (isSigned) {
346 out << "i" << storageWidth;
347 } else {
348 out << "u" << storageWidth;
349 }
350
351 // storageTypeMin and storageTypeMax if not default.
352 int64_t defaultIntegerMin =
353 QuantizedType::getDefaultMinimumForInteger(isSigned, storageWidth);
354 int64_t defaultIntegerMax =
355 QuantizedType::getDefaultMaximumForInteger(isSigned, storageWidth);
356 if (defaultIntegerMin != type.getStorageTypeMin() ||
357 defaultIntegerMax != type.getStorageTypeMax()) {
358 out << "<" << type.getStorageTypeMin() << ":" << type.getStorageTypeMax()
359 << ">";
360 }
361}
362
363static void printQuantParams(double scale, int64_t zeroPoint,
364 DialectAsmPrinter &out) {
365 out << scale;
366 if (zeroPoint != 0) {
367 out << ":" << zeroPoint;
368 }
369}
370
371/// Helper that prints a AnyQuantizedType.
372static void printAnyQuantizedType(AnyQuantizedType type,
373 DialectAsmPrinter &out) {
374 out << "any<";
375 printStorageType(type, out);
376 if (Type expressedType = type.getExpressedType()) {
377 out << ":" << expressedType;
378 }
379 out << ">";
380}
381
382/// Helper that prints a UniformQuantizedType.
383static void printUniformQuantizedType(UniformQuantizedType type,
384 DialectAsmPrinter &out) {
385 out << "uniform<";
386 printStorageType(type, out);
387 out << ":" << type.getExpressedType() << ", ";
388
389 // scheme specific parameters
390 printQuantParams(type.getScale(), type.getZeroPoint(), out);
391 out << ">";
392}
393
394/// Helper that prints a UniformQuantizedPerAxisType.
395static void printUniformQuantizedPerAxisType(UniformQuantizedPerAxisType type,
396 DialectAsmPrinter &out) {
397 out << "uniform<";
398 printStorageType(type, out);
399 out << ":" << type.getExpressedType() << ":";
400 out << type.getQuantizedDimension();
401 out << ", ";
402
403 // scheme specific parameters
404 ArrayRef<double> scales = type.getScales();
405 ArrayRef<int64_t> zeroPoints = type.getZeroPoints();
406 out << "{";
407 llvm::interleave(
408 llvm::seq<size_t>(0, scales.size()), out,
409 [&](size_t index) {
410 printQuantParams(scales[index], zeroPoints[index], out);
411 },
412 ",");
413 out << "}>";
414}
415
416/// Helper that prints a CalibratedQuantizedType.
417static void printCalibratedQuantizedType(CalibratedQuantizedType type,
418 DialectAsmPrinter &out) {
419 out << "calibrated<" << type.getExpressedType();
420 out << "<" << type.getMin() << ":" << type.getMax() << ">";
421 out << ">";
422}
423
424/// Print a type registered to this dialect.
425void QuantizationDialect::printType(Type type, DialectAsmPrinter &os) const {
426 if (auto anyType = type.dyn_cast<AnyQuantizedType>())
427 printAnyQuantizedType(anyType, os);
428 else if (auto uniformType = type.dyn_cast<UniformQuantizedType>())
429 printUniformQuantizedType(uniformType, os);
430 else if (auto perAxisType = type.dyn_cast<UniformQuantizedPerAxisType>())
431 printUniformQuantizedPerAxisType(perAxisType, os);
432 else if (auto calibratedType = type.dyn_cast<CalibratedQuantizedType>())
433 printCalibratedQuantizedType(calibratedType, os);
434 else
435 llvm_unreachable("Unhandled quantized type")::llvm::llvm_unreachable_internal("Unhandled quantized type",
"mlir/lib/Dialect/Quant/IR/TypeParser.cpp", 435)
;
436}

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include/mlir/IR/OpDefinition.h

1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types are to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24#include "llvm/Support/PointerLikeTypeTraits.h"
25
26#include <type_traits>
27
28namespace mlir {
// Only forward-declared here; used by reference in declarations below.
class OpBuilder;
class Builder;
31
32/// This class represents success/failure for operation parsing. It is
33/// essentially a simple wrapper class around LogicalResult that allows for
34/// explicit conversion to bool. This allows for the parser to chain together
35/// parse rules without the clutter of "failed/succeeded".
36class ParseResult : public LogicalResult {
37public:
38 ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
39
40 // Allow diagnostics emitted during parsing to be converted to failure.
41 ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
42 ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
43
44 /// Failure is true in a boolean context.
45 explicit operator bool() const { return failed(); }
7
Calling 'LogicalResult::failed'
10
Returning from 'LogicalResult::failed'
11
Returning zero, which participates in a condition later
46};
47/// This class implements `Optional` functionality for ParseResult. We don't
48/// directly use Optional here, because it provides an implicit conversion
49/// to 'bool' which we want to avoid. This class is used to implement tri-state
50/// 'parseOptional' functions that may have a failure mode when parsing that
51/// shouldn't be attributed to "not present".
52class OptionalParseResult {
53public:
54 OptionalParseResult() = default;
55 OptionalParseResult(LogicalResult result) : impl(result) {}
56 OptionalParseResult(ParseResult result) : impl(result) {}
57 OptionalParseResult(const InFlightDiagnostic &)
58 : OptionalParseResult(failure()) {}
59 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
60
61 /// Returns true if we contain a valid ParseResult value.
62 bool hasValue() const { return impl.hasValue(); }
63
64 /// Access the internal ParseResult value.
65 ParseResult getValue() const { return impl.getValue(); }
66 ParseResult operator*() const { return getValue(); }
67
68private:
69 Optional<ParseResult> impl;
70};
71
72// These functions are out-of-line utilities, which avoids them being template
73// instantiated/duplicated.
74namespace impl {
75/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
76/// region's only block if it does not have a terminator already. If the region
77/// is empty, insert a new block first. `buildTerminatorOp` should return the
78/// terminator operation to insert.
79void ensureRegionTerminator(
80 Region &region, OpBuilder &builder, Location loc,
81 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
82void ensureRegionTerminator(
83 Region &region, Builder &builder, Location loc,
84 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
85
86} // namespace impl
87
88/// This is the concrete base class that holds the operation pointer and has
89/// non-generic methods that only depend on State (to avoid having them
90/// instantiated on template types that don't affect them.
91///
92/// This also has the fallback implementations of customization hooks for when
93/// they aren't customized.
94class OpState {
95public:
96 /// Ops are pointer-like, so we allow conversion to bool.
97 explicit operator bool() { return getOperation() != nullptr; }
98
99 /// This implicitly converts to Operation*.
100 operator Operation *() const { return state; }
101
102 /// Shortcut of `->` to access a member of Operation.
103 Operation *operator->() const { return state; }
104
105 /// Return the operation that this refers to.
106 Operation *getOperation() { return state; }
107
108 /// Return the context this operation belongs to.
109 MLIRContext *getContext() { return getOperation()->getContext(); }
110
111 /// Print the operation to the given stream.
112 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
113 state->print(os, flags);
114 }
115 void print(raw_ostream &os, AsmState &asmState,
116 OpPrintingFlags flags = llvm::None) {
117 state->print(os, asmState, flags);
118 }
119
120 /// Dump this operation.
121 void dump() { state->dump(); }
122
123 /// The source location the operation was defined or derived from.
124 Location getLoc() { return state->getLoc(); }
125
126 /// Return true if there are no users of any results of this operation.
127 bool use_empty() { return state->use_empty(); }
128
129 /// Remove this operation from its parent block and delete it.
130 void erase() { state->erase(); }
131
132 /// Emit an error with the op name prefixed, like "'dim' op " which is
133 /// convenient for verifiers.
134 InFlightDiagnostic emitOpError(const Twine &message = {});
135
136 /// Emit an error about fatal conditions with this operation, reporting up to
137 /// any diagnostic handlers that may be listening.
138 InFlightDiagnostic emitError(const Twine &message = {});
139
140 /// Emit a warning about this operation, reporting up to any diagnostic
141 /// handlers that may be listening.
142 InFlightDiagnostic emitWarning(const Twine &message = {});
143
144 /// Emit a remark about this operation, reporting up to any diagnostic
145 /// handlers that may be listening.
146 InFlightDiagnostic emitRemark(const Twine &message = {});
147
148 /// Walk the operation by calling the callback for each nested operation
149 /// (including this one), block or region, depending on the callback provided.
150 /// Regions, blocks and operations at the same nesting level are visited in
151 /// lexicographical order. The walk order for enclosing regions, blocks and
152 /// operations with respect to their nested ones is specified by 'Order'
153 /// (post-order by default). A callback on a block or operation is allowed to
154 /// erase that block or operation if either:
155 /// * the walk is in post-order, or
156 /// * the walk is in pre-order and the walk is skipped after the erasure.
157 /// See Operation::walk for more details.
158 template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
159 typename RetT = detail::walkResultType<FnT>>
160 RetT walk(FnT &&callback) {
161 return state->walk<Order>(std::forward<FnT>(callback));
162 }
163
164 // These are default implementations of customization hooks.
165public:
166 /// This hook returns any canonicalization pattern rewrites that the operation
167 /// supports, for use by the canonicalization pass.
168 static void getCanonicalizationPatterns(RewritePatternSet &results,
169 MLIRContext *context) {}
170
171protected:
172 /// If the concrete type didn't implement a custom verifier hook, just fall
173 /// back to this one which accepts everything.
174 LogicalResult verify() { return success(); }
175
176 /// Parse the custom form of an operation. Unless overridden, this method will
177 /// first try to get an operation parser from the op's dialect. Otherwise the
178 /// custom assembly form of an op is always rejected. Op implementations
179 /// should implement this to return failure. On success, they should fill in
180 /// result with the fields to use.
181 static ParseResult parse(OpAsmParser &parser, OperationState &result);
182
183 /// Print the operation. Unless overridden, this method will first try to get
184 /// an operation printer from the dialect. Otherwise, it prints the operation
185 /// in generic form.
186 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
187
188 /// Print an operation name, eliding the dialect prefix if necessary.
189 static void printOpName(Operation *op, OpAsmPrinter &p,
190 StringRef defaultDialect);
191
192 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
193 /// so we can cast it away here.
194 explicit OpState(Operation *state) : state(state) {}
195
196private:
197 Operation *state;
198
199 /// Allow access to internal hook implementation methods.
200 friend RegisteredOperationName;
201};
202
203// Allow comparing operators.
204inline bool operator==(OpState lhs, OpState rhs) {
205 return lhs.getOperation() == rhs.getOperation();
206}
207inline bool operator!=(OpState lhs, OpState rhs) {
208 return lhs.getOperation() != rhs.getOperation();
209}
210
211raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
212
213/// This class represents a single result from folding an operation.
214class OpFoldResult : public PointerUnion<Attribute, Value> {
215 using PointerUnion<Attribute, Value>::PointerUnion;
216
217public:
218 void dump() { llvm::errs() << *this << "\n"; }
219};
220
221/// Allow printing to a stream.
222inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
223 if (Value value = ofr.dyn_cast<Value>())
224 value.print(os);
225 else
226 ofr.dyn_cast<Attribute>().print(os);
227 return os;
228}
229
230/// Allow printing to a stream.
231inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
232 op.print(os, OpPrintingFlags().useLocalScope());
233 return os;
234}
235
236//===----------------------------------------------------------------------===//
237// Operation Trait Types
238//===----------------------------------------------------------------------===//
239
240namespace OpTrait {
241
242// These functions are out-of-line implementations of the methods in the
243// corresponding trait classes. This avoids them being template
244// instantiated/duplicated.
245namespace impl {
246OpFoldResult foldIdempotent(Operation *op);
247OpFoldResult foldInvolution(Operation *op);
248LogicalResult verifyZeroOperands(Operation *op);
249LogicalResult verifyOneOperand(Operation *op);
250LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
251LogicalResult verifyIsIdempotent(Operation *op);
252LogicalResult verifyIsInvolution(Operation *op);
253LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
254LogicalResult verifyOperandsAreFloatLike(Operation *op);
255LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
256LogicalResult verifySameTypeOperands(Operation *op);
257LogicalResult verifyZeroRegion(Operation *op);
258LogicalResult verifyOneRegion(Operation *op);
259LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
260LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
261LogicalResult verifyZeroResult(Operation *op);
262LogicalResult verifyOneResult(Operation *op);
263LogicalResult verifyNResults(Operation *op, unsigned numOperands);
264LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
265LogicalResult verifySameOperandsShape(Operation *op);
266LogicalResult verifySameOperandsAndResultShape(Operation *op);
267LogicalResult verifySameOperandsElementType(Operation *op);
268LogicalResult verifySameOperandsAndResultElementType(Operation *op);
269LogicalResult verifySameOperandsAndResultType(Operation *op);
270LogicalResult verifyResultsAreBoolLike(Operation *op);
271LogicalResult verifyResultsAreFloatLike(Operation *op);
272LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
273LogicalResult verifyIsTerminator(Operation *op);
274LogicalResult verifyZeroSuccessor(Operation *op);
275LogicalResult verifyOneSuccessor(Operation *op);
276LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
277LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
278LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
279 StringRef valueGroupName,
280 size_t expectedCount);
281LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
282LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
283LogicalResult verifyNoRegionArguments(Operation *op);
284LogicalResult verifyElementwise(Operation *op);
285LogicalResult verifyIsIsolatedFromAbove(Operation *op);
286} // namespace impl
287
288/// Helper class for implementing traits. Clients are not expected to interact
289/// with this directly, so its members are all protected.
290template <typename ConcreteType, template <typename> class TraitType>
291class TraitBase {
292protected:
293 /// Return the ultimate Operation being worked on.
294 Operation *getOperation() {
295 // We have to cast up to the trait type, then to the concrete type, then to
296 // the BaseState class in explicit hops because the concrete type will
297 // multiply derive from the (content free) TraitBase class, and we need to
298 // be able to disambiguate the path for the C++ compiler.
299 auto *trait = static_cast<TraitType<ConcreteType> *>(this);
300 auto *concrete = static_cast<ConcreteType *>(trait);
301 auto *base = static_cast<OpState *>(concrete);
302 return base->getOperation();
303 }
304};
305
306//===----------------------------------------------------------------------===//
307// Operand Traits
308
309namespace detail {
310/// Utility trait base that provides accessors for derived traits that have
311/// multiple operands.
312template <typename ConcreteType, template <typename> class TraitType>
313struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
314 using operand_iterator = Operation::operand_iterator;
315 using operand_range = Operation::operand_range;
316 using operand_type_iterator = Operation::operand_type_iterator;
317 using operand_type_range = Operation::operand_type_range;
318
319 /// Return the number of operands.
320 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
321
322 /// Return the operand at index 'i'.
323 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
324
325 /// Set the operand at index 'i' to 'value'.
326 void setOperand(unsigned i, Value value) {
327 this->getOperation()->setOperand(i, value);
328 }
329
330 /// Operand iterator access.
331 operand_iterator operand_begin() {
332 return this->getOperation()->operand_begin();
333 }
334 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
335 operand_range getOperands() { return this->getOperation()->getOperands(); }
336
337 /// Operand type access.
338 operand_type_iterator operand_type_begin() {
339 return this->getOperation()->operand_type_begin();
340 }
341 operand_type_iterator operand_type_end() {
342 return this->getOperation()->operand_type_end();
343 }
344 operand_type_range getOperandTypes() {
345 return this->getOperation()->getOperandTypes();
346 }
347};
348} // namespace detail
349
350/// This class provides the API for ops that are known to have no
351/// SSA operand.
352template <typename ConcreteType>
353class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
354public:
355 static LogicalResult verifyTrait(Operation *op) {
356 return impl::verifyZeroOperands(op);
357 }
358
359private:
360 // Disable these.
361 void getOperand() {}
362 void setOperand() {}
363};
364
365/// This class provides the API for ops that are known to have exactly one
366/// SSA operand.
367template <typename ConcreteType>
368class OneOperand : public TraitBase<ConcreteType, OneOperand> {
369public:
370 Value getOperand() { return this->getOperation()->getOperand(0); }
371
372 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
373
374 static LogicalResult verifyTrait(Operation *op) {
375 return impl::verifyOneOperand(op);
376 }
377};
378
379/// This class provides the API for ops that are known to have a specified
380/// number of operands. This is used as a trait like this:
381///
382/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
383///
384template <unsigned N>
385class NOperands {
386public:
387 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
388
389 template <typename ConcreteType>
390 class Impl
391 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
392 public:
393 static LogicalResult verifyTrait(Operation *op) {
394 return impl::verifyNOperands(op, N);
395 }
396 };
397};
398
399/// This class provides the API for ops that are known to have a at least a
400/// specified number of operands. This is used as a trait like this:
401///
402/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
403///
404template <unsigned N>
405class AtLeastNOperands {
406public:
407 template <typename ConcreteType>
408 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
409 AtLeastNOperands<N>::Impl> {
410 public:
411 static LogicalResult verifyTrait(Operation *op) {
412 return impl::verifyAtLeastNOperands(op, N);
413 }
414 };
415};
416
417/// This class provides the API for ops which have an unknown number of
418/// SSA operands.
419template <typename ConcreteType>
420class VariadicOperands
421 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
422
423//===----------------------------------------------------------------------===//
424// Region Traits
425
426/// This class provides verification for ops that are known to have zero
427/// regions.
428template <typename ConcreteType>
429class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
430public:
431 static LogicalResult verifyTrait(Operation *op) {
432 return impl::verifyZeroRegion(op);
433 }
434};
435
436namespace detail {
437/// Utility trait base that provides accessors for derived traits that have
438/// multiple regions.
439template <typename ConcreteType, template <typename> class TraitType>
440struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
441 using region_iterator = MutableArrayRef<Region>;
442 using region_range = RegionRange;
443
444 /// Return the number of regions.
445 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
446
447 /// Return the region at `index`.
448 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
449
450 /// Region iterator access.
451 region_iterator region_begin() {
452 return this->getOperation()->region_begin();
453 }
454 region_iterator region_end() { return this->getOperation()->region_end(); }
455 region_range getRegions() { return this->getOperation()->getRegions(); }
456};
457} // namespace detail
458
459/// This class provides APIs for ops that are known to have a single region.
460template <typename ConcreteType>
461class OneRegion : public TraitBase<ConcreteType, OneRegion> {
462public:
463 Region &getRegion() { return this->getOperation()->getRegion(0); }
464
465 /// Returns a range of operations within the region of this operation.
466 auto getOps() { return getRegion().getOps(); }
467 template <typename OpT>
468 auto getOps() {
469 return getRegion().template getOps<OpT>();
470 }
471
472 static LogicalResult verifyTrait(Operation *op) {
473 return impl::verifyOneRegion(op);
474 }
475};
476
477/// This class provides the API for ops that are known to have a specified
478/// number of regions.
479template <unsigned N>
480class NRegions {
481public:
482 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
483
484 template <typename ConcreteType>
485 class Impl
486 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
487 public:
488 static LogicalResult verifyTrait(Operation *op) {
489 return impl::verifyNRegions(op, N);
490 }
491 };
492};
493
494/// This class provides APIs for ops that are known to have at least a specified
495/// number of regions.
496template <unsigned N>
497class AtLeastNRegions {
498public:
499 template <typename ConcreteType>
500 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
501 AtLeastNRegions<N>::Impl> {
502 public:
503 static LogicalResult verifyTrait(Operation *op) {
504 return impl::verifyAtLeastNRegions(op, N);
505 }
506 };
507};
508
509/// This class provides the API for ops which have an unknown number of
510/// regions.
511template <typename ConcreteType>
512class VariadicRegions
513 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
514
515//===----------------------------------------------------------------------===//
516// Result Traits
517
518/// This class provides return value APIs for ops that are known to have
519/// zero results.
520template <typename ConcreteType>
521class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
522public:
523 static LogicalResult verifyTrait(Operation *op) {
524 return impl::verifyZeroResult(op);
525 }
526};
527
528namespace detail {
529/// Utility trait base that provides accessors for derived traits that have
530/// multiple results.
531template <typename ConcreteType, template <typename> class TraitType>
532struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
533 using result_iterator = Operation::result_iterator;
534 using result_range = Operation::result_range;
535 using result_type_iterator = Operation::result_type_iterator;
536 using result_type_range = Operation::result_type_range;
537
538 /// Return the number of results.
539 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
540
541 /// Return the result at index 'i'.
542 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
543
544 /// Replace all uses of results of this operation with the provided 'values'.
545 /// 'values' may correspond to an existing operation, or a range of 'Value'.
546 template <typename ValuesT>
547 void replaceAllUsesWith(ValuesT &&values) {
548 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
549 }
550
551 /// Return the type of the `i`-th result.
552 Type getType(unsigned i) { return getResult(i).getType(); }
553
554 /// Result iterator access.
555 result_iterator result_begin() {
556 return this->getOperation()->result_begin();
557 }
558 result_iterator result_end() { return this->getOperation()->result_end(); }
559 result_range getResults() { return this->getOperation()->getResults(); }
560
561 /// Result type access.
562 result_type_iterator result_type_begin() {
563 return this->getOperation()->result_type_begin();
564 }
565 result_type_iterator result_type_end() {
566 return this->getOperation()->result_type_end();
567 }
568 result_type_range getResultTypes() {
569 return this->getOperation()->getResultTypes();
570 }
571};
572} // namespace detail
573
574/// This class provides return value APIs for ops that are known to have a
575/// single result. ResultType is the concrete type returned by getType().
576template <typename ConcreteType>
577class OneResult : public TraitBase<ConcreteType, OneResult> {
578public:
579 Value getResult() { return this->getOperation()->getResult(0); }
580
581 /// If the operation returns a single value, then the Op can be implicitly
582 /// converted to an Value. This yields the value of the only result.
583 operator Value() { return getResult(); }
584
585 /// Replace all uses of 'this' value with the new value, updating anything
586 /// in the IR that uses 'this' to use the other value instead. When this
587 /// returns there are zero uses of 'this'.
588 void replaceAllUsesWith(Value newValue) {
589 getResult().replaceAllUsesWith(newValue);
590 }
591
592 /// Replace all uses of 'this' value with the result of 'op'.
593 void replaceAllUsesWith(Operation *op) {
594 this->getOperation()->replaceAllUsesWith(op);
595 }
596
597 static LogicalResult verifyTrait(Operation *op) {
598 return impl::verifyOneResult(op);
599 }
600};
601
602/// This trait is used for return value APIs for ops that are known to have a
603/// specific type other than `Type`. This allows the "getType()" member to be
604/// more specific for an op. This should be used in conjunction with OneResult,
605/// and occur in the trait list before OneResult.
606template <typename ResultType>
607class OneTypedResult {
608public:
609 /// This class provides return value APIs for ops that are known to have a
610 /// single result. ResultType is the concrete type returned by getType().
611 template <typename ConcreteType>
612 class Impl
613 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
614 public:
615 ResultType getType() {
616 auto resultTy = this->getOperation()->getResult(0).getType();
617 return resultTy.template cast<ResultType>();
618 }
619 };
620};
621
622/// This class provides the API for ops that are known to have a specified
623/// number of results. This is used as a trait like this:
624///
625/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
626///
627template <unsigned N>
628class NResults {
629public:
630 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
631
632 template <typename ConcreteType>
633 class Impl
634 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
635 public:
636 static LogicalResult verifyTrait(Operation *op) {
637 return impl::verifyNResults(op, N);
638 }
639 };
640};
641
642/// This class provides the API for ops that are known to have at least a
643/// specified number of results. This is used as a trait like this:
644///
645/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
646///
647template <unsigned N>
648class AtLeastNResults {
649public:
650 template <typename ConcreteType>
651 class Impl : public detail::MultiResultTraitBase<ConcreteType,
652 AtLeastNResults<N>::Impl> {
653 public:
654 static LogicalResult verifyTrait(Operation *op) {
655 return impl::verifyAtLeastNResults(op, N);
656 }
657 };
658};
659
660/// This class provides the API for ops which have an unknown number of
661/// results.
662template <typename ConcreteType>
663class VariadicResults
664 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
665
666//===----------------------------------------------------------------------===//
667// Terminator Traits
668
669/// This class indicates that the regions associated with this op don't have
670/// terminators.
671template <typename ConcreteType>
672class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
673
674/// This class provides the API for ops that are known to be terminators.
675template <typename ConcreteType>
676class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
677public:
678 static LogicalResult verifyTrait(Operation *op) {
679 return impl::verifyIsTerminator(op);
680 }
681};
682
683/// This class provides verification for ops that are known to have zero
684/// successors.
685template <typename ConcreteType>
686class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
687public:
688 static LogicalResult verifyTrait(Operation *op) {
689 return impl::verifyZeroSuccessor(op);
690 }
691};
692
693namespace detail {
694/// Utility trait base that provides accessors for derived traits that have
695/// multiple successors.
696template <typename ConcreteType, template <typename> class TraitType>
697struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
698 using succ_iterator = Operation::succ_iterator;
699 using succ_range = SuccessorRange;
700
701 /// Return the number of successors.
702 unsigned getNumSuccessors() {
703 return this->getOperation()->getNumSuccessors();
704 }
705
706 /// Return the successor at `index`.
707 Block *getSuccessor(unsigned i) {
708 return this->getOperation()->getSuccessor(i);
709 }
710
711 /// Set the successor at `index`.
712 void setSuccessor(Block *block, unsigned i) {
713 return this->getOperation()->setSuccessor(block, i);
714 }
715
716 /// Successor iterator access.
717 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
718 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
719 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
720};
721} // namespace detail
722
723/// This class provides APIs for ops that are known to have a single successor.
724template <typename ConcreteType>
725class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
726public:
727 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
728 void setSuccessor(Block *succ) {
729 this->getOperation()->setSuccessor(succ, 0);
730 }
731
732 static LogicalResult verifyTrait(Operation *op) {
733 return impl::verifyOneSuccessor(op);
734 }
735};
736
737/// This class provides the API for ops that are known to have a specified
738/// number of successors.
739template <unsigned N>
740class NSuccessors {
741public:
742 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
743
744 template <typename ConcreteType>
745 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
746 NSuccessors<N>::Impl> {
747 public:
748 static LogicalResult verifyTrait(Operation *op) {
749 return impl::verifyNSuccessors(op, N);
750 }
751 };
752};
753
754/// This class provides APIs for ops that are known to have at least a specified
755/// number of successors.
756template <unsigned N>
757class AtLeastNSuccessors {
758public:
759 template <typename ConcreteType>
760 class Impl
761 : public detail::MultiSuccessorTraitBase<ConcreteType,
762 AtLeastNSuccessors<N>::Impl> {
763 public:
764 static LogicalResult verifyTrait(Operation *op) {
765 return impl::verifyAtLeastNSuccessors(op, N);
766 }
767 };
768};
769
770/// This class provides the API for ops which have an unknown number of
771/// successors.
772template <typename ConcreteType>
773class VariadicSuccessors
774 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
775};
776
777//===----------------------------------------------------------------------===//
778// SingleBlock
779
780/// This class provides APIs and verifiers for ops with regions having a single
781/// block.
782template <typename ConcreteType>
783struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
784public:
785 static LogicalResult verifyTrait(Operation *op) {
786 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
787 Region &region = op->getRegion(i);
788
789 // Empty regions are fine.
790 if (region.empty())
791 continue;
792
793 // Non-empty regions must contain a single basic block.
794 if (!llvm::hasSingleElement(region))
795 return op->emitOpError("expects region #")
796 << i << " to have 0 or 1 blocks";
797
798 if (!ConcreteType::template hasTrait<NoTerminator>()) {
799 Block &block = region.front();
800 if (block.empty())
801 return op->emitOpError() << "expects a non-empty block";
802 }
803 }
804 return success();
805 }
806
807 Block *getBody(unsigned idx = 0) {
808 Region &region = this->getOperation()->getRegion(idx);
809 assert(!region.empty() && "unexpected empty region")(static_cast <bool> (!region.empty() && "unexpected empty region"
) ? void (0) : __assert_fail ("!region.empty() && \"unexpected empty region\""
, "mlir/include/mlir/IR/OpDefinition.h", 809, __extension__ __PRETTY_FUNCTION__
))
;
810 return &region.front();
811 }
812 Region &getBodyRegion(unsigned idx = 0) {
813 return this->getOperation()->getRegion(idx);
814 }
815
816 //===------------------------------------------------------------------===//
817 // Single Region Utilities
818 //===------------------------------------------------------------------===//
819
820 /// The following are a set of methods only enabled when the parent
821 /// operation has a single region. Each of these methods take an additional
822 /// template parameter that represents the concrete operation so that we
823 /// can use SFINAE to disable the methods for non-single region operations.
824 template <typename OpT, typename T = void>
825 using enable_if_single_region =
826 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
827
828 template <typename OpT = ConcreteType>
829 enable_if_single_region<OpT, Block::iterator> begin() {
830 return getBody()->begin();
831 }
832 template <typename OpT = ConcreteType>
833 enable_if_single_region<OpT, Block::iterator> end() {
834 return getBody()->end();
835 }
836 template <typename OpT = ConcreteType>
837 enable_if_single_region<OpT, Operation &> front() {
838 return *begin();
839 }
840
841 /// Insert the operation into the back of the body.
842 template <typename OpT = ConcreteType>
843 enable_if_single_region<OpT> push_back(Operation *op) {
844 insert(Block::iterator(getBody()->end()), op);
845 }
846
847 /// Insert the operation at the given insertion point.
848 template <typename OpT = ConcreteType>
849 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
850 insert(Block::iterator(insertPt), op);
851 }
852 template <typename OpT = ConcreteType>
853 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
854 getBody()->getOperations().insert(insertPt, op);
855 }
856};
857
858//===----------------------------------------------------------------------===//
859// SingleBlockImplicitTerminator
860
861/// This class provides APIs and verifiers for ops with regions having a single
862/// block that must terminate with `TerminatorOpType`.
863template <typename TerminatorOpType>
864struct SingleBlockImplicitTerminator {
865 template <typename ConcreteType>
866 class Impl : public SingleBlock<ConcreteType> {
867 private:
868 using Base = SingleBlock<ConcreteType>;
869 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
870 /// cyclic header inclusion.
871 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
872 OperationState state(loc, TerminatorOpType::getOperationName());
873 TerminatorOpType::build(builder, state);
874 return Operation::create(state);
875 }
876
877 public:
878 /// The type of the operation used as the implicit terminator type.
879 using ImplicitTerminatorOpT = TerminatorOpType;
880
881 static LogicalResult verifyTrait(Operation *op) {
882 if (failed(Base::verifyTrait(op)))
883 return failure();
884 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
885 Region &region = op->getRegion(i);
886 // Empty regions are fine.
887 if (region.empty())
888 continue;
889 Operation &terminator = region.front().back();
890 if (isa<TerminatorOpType>(terminator))
891 continue;
892
893 return op->emitOpError("expects regions to end with '" +
894 TerminatorOpType::getOperationName() +
895 "', found '" +
896 terminator.getName().getStringRef() + "'")
897 .attachNote()
898 << "in custom textual format, the absence of terminator implies "
899 "'"
900 << TerminatorOpType::getOperationName() << '\'';
901 }
902
903 return success();
904 }
905
906 /// Ensure that the given region has the terminator required by this trait.
907 /// If OpBuilder is provided, use it to build the terminator and notify the
908 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
909 /// construct an OpBuilder with no listeners; this should only be used if no
910 /// OpBuilder is available at the call site, e.g., in the parser.
911 static void ensureTerminator(Region &region, Builder &builder,
912 Location loc) {
913 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
914 buildTerminator);
915 }
916 static void ensureTerminator(Region &region, OpBuilder &builder,
917 Location loc) {
918 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
919 buildTerminator);
920 }
921
922 //===------------------------------------------------------------------===//
923 // Single Region Utilities
924 //===------------------------------------------------------------------===//
925 using Base::getBody;
926
927 template <typename OpT, typename T = void>
928 using enable_if_single_region =
929 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
930
931 /// Insert the operation into the back of the body, before the terminator.
932 template <typename OpT = ConcreteType>
933 enable_if_single_region<OpT> push_back(Operation *op) {
934 insert(Block::iterator(getBody()->getTerminator()), op);
935 }
936
937 /// Insert the operation at the given insertion point. Note: The operation
938 /// is never inserted after the terminator, even if the insertion point is
939 /// end().
940 template <typename OpT = ConcreteType>
941 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
942 insert(Block::iterator(insertPt), op);
943 }
944 template <typename OpT = ConcreteType>
945 enable_if_single_region<OpT> insert(Block::iterator insertPt,
946 Operation *op) {
947 auto *body = getBody();
948 if (insertPt == body->end())
949 insertPt = Block::iterator(body->getTerminator());
950 body->getOperations().insert(insertPt, op);
951 }
952 };
953};
954
/// Detector alias naming an op's `ImplicitTerminatorOpT` member type. It is
/// ill-formed when the member is absent, which makes it suitable for
/// `llvm::is_detected`.
template <typename T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
959
960/// Support to check if an operation has the SingleBlockImplicitTerminator
961/// trait. We can't just use `hasTrait` because this class is templated on a
962/// specific terminator op.
963template <class Op, bool hasTerminator =
964 llvm::is_detected<has_implicit_terminator_t, Op>::value>
965struct hasSingleBlockImplicitTerminator {
966 static constexpr bool value = std::is_base_of<
967 typename OpTrait::SingleBlockImplicitTerminator<
968 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
969 Op>::value;
970};
971template <class Op>
972struct hasSingleBlockImplicitTerminator<Op, false> {
973 static constexpr bool value = false;
974};
975
976//===----------------------------------------------------------------------===//
977// Misc Traits
978
979/// This class provides verification for ops that are known to have the same
980/// operand shape: all operands are scalars, vectors/tensors of the same
981/// shape.
982template <typename ConcreteType>
983class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
984public:
985 static LogicalResult verifyTrait(Operation *op) {
986 return impl::verifySameOperandsShape(op);
987 }
988};
989
990/// This class provides verification for ops that are known to have the same
991/// operand and result shape: both are scalars, vectors/tensors of the same
992/// shape.
993template <typename ConcreteType>
994class SameOperandsAndResultShape
995 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
996public:
997 static LogicalResult verifyTrait(Operation *op) {
998 return impl::verifySameOperandsAndResultShape(op);
999 }
1000};
1001
1002/// This class provides verification for ops that are known to have the same
1003/// operand element type (or the type itself if it is scalar).
1004///
1005template <typename ConcreteType>
1006class SameOperandsElementType
1007 : public TraitBase<ConcreteType, SameOperandsElementType> {
1008public:
1009 static LogicalResult verifyTrait(Operation *op) {
1010 return impl::verifySameOperandsElementType(op);
1011 }
1012};
1013
1014/// This class provides verification for ops that are known to have the same
1015/// operand and result element type (or the type itself if it is scalar).
1016///
1017template <typename ConcreteType>
1018class SameOperandsAndResultElementType
1019 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1020public:
1021 static LogicalResult verifyTrait(Operation *op) {
1022 return impl::verifySameOperandsAndResultElementType(op);
1023 }
1024};
1025
1026/// This class provides verification for ops that are known to have the same
1027/// operand and result type.
1028///
1029/// Note: this trait subsumes the SameOperandsAndResultShape and
1030/// SameOperandsAndResultElementType traits.
1031template <typename ConcreteType>
1032class SameOperandsAndResultType
1033 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1034public:
1035 static LogicalResult verifyTrait(Operation *op) {
1036 return impl::verifySameOperandsAndResultType(op);
1037 }
1038};
1039
1040/// This class verifies that any results of the specified op have a boolean
1041/// type, a vector thereof, or a tensor thereof.
1042template <typename ConcreteType>
1043class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1044public:
1045 static LogicalResult verifyTrait(Operation *op) {
1046 return impl::verifyResultsAreBoolLike(op);
1047 }
1048};
1049
1050/// This class verifies that any results of the specified op have a floating
1051/// point type, a vector thereof, or a tensor thereof.
1052template <typename ConcreteType>
1053class ResultsAreFloatLike
1054 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1055public:
1056 static LogicalResult verifyTrait(Operation *op) {
1057 return impl::verifyResultsAreFloatLike(op);
1058 }
1059};
1060
1061/// This class verifies that any results of the specified op have a signless
1062/// integer or index type, a vector thereof, or a tensor thereof.
1063template <typename ConcreteType>
1064class ResultsAreSignlessIntegerLike
1065 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1066public:
1067 static LogicalResult verifyTrait(Operation *op) {
1068 return impl::verifyResultsAreSignlessIntegerLike(op);
1069 }
1070};
1071
1072/// This class adds property that the operation is commutative.
1073template <typename ConcreteType>
1074class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1075
1076/// This class adds property that the operation is an involution.
1077/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1078template <typename ConcreteType>
1079class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1080public:
1081 static LogicalResult verifyTrait(Operation *op) {
1082 static_assert(ConcreteType::template hasTrait<OneResult>(),
1083 "expected operation to produce one result");
1084 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1085 "expected operation to take one operand");
1086 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1087 "expected operation to preserve type");
1088 // Involution requires the operation to be side effect free as well
1089 // but currently this check is under a FIXME and is not actually done.
1090 return impl::verifyIsInvolution(op);
1091 }
1092
1093 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1094 return impl::foldInvolution(op);
1095 }
1096};
1097
1098/// This class adds property that the operation is idempotent.
1099/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1100/// or a binary operation "g" that satisfies g(x, x) = x.
1101template <typename ConcreteType>
1102class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1103public:
1104 static LogicalResult verifyTrait(Operation *op) {
1105 static_assert(ConcreteType::template hasTrait<OneResult>(),
1106 "expected operation to produce one result");
1107 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1108 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1109 "expected operation to take one or two operands");
1110 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1111 "expected operation to preserve type");
1112 // Idempotent requires the operation to be side effect free as well
1113 // but currently this check is under a FIXME and is not actually done.
1114 return impl::verifyIsIdempotent(op);
1115 }
1116
1117 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1118 return impl::foldIdempotent(op);
1119 }
1120};
1121
1122/// This class verifies that all operands of the specified op have a float type,
1123/// a vector thereof, or a tensor thereof.
1124template <typename ConcreteType>
1125class OperandsAreFloatLike
1126 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1127public:
1128 static LogicalResult verifyTrait(Operation *op) {
1129 return impl::verifyOperandsAreFloatLike(op);
1130 }
1131};
1132
1133/// This class verifies that all operands of the specified op have a signless
1134/// integer or index type, a vector thereof, or a tensor thereof.
1135template <typename ConcreteType>
1136class OperandsAreSignlessIntegerLike
1137 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1138public:
1139 static LogicalResult verifyTrait(Operation *op) {
1140 return impl::verifyOperandsAreSignlessIntegerLike(op);
1141 }
1142};
1143
1144/// This class verifies that all operands of the specified op have the same
1145/// type.
1146template <typename ConcreteType>
1147class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1148public:
1149 static LogicalResult verifyTrait(Operation *op) {
1150 return impl::verifySameTypeOperands(op);
1151 }
1152};
1153
1154/// This class provides the API for a sub-set of ops that are known to be
1155/// constant-like. These are non-side effecting operations with one result and
1156/// zero operands that can always be folded to a specific attribute value.
1157template <typename ConcreteType>
1158class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1159public:
1160 static LogicalResult verifyTrait(Operation *op) {
1161 static_assert(ConcreteType::template hasTrait<OneResult>(),
1162 "expected operation to produce one result");
1163 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1164 "expected operation to take zero operands");
1165 // TODO: We should verify that the operation can always be folded, but this
1166 // requires that the attributes of the op already be verified. We should add
1167 // support for verifying traits "after" the operation to enable this use
1168 // case.
1169 return success();
1170 }
1171};
1172
1173/// This class provides the API for ops that are known to be isolated from
1174/// above.
1175template <typename ConcreteType>
1176class IsIsolatedFromAbove
1177 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1178public:
1179 static LogicalResult verifyTrait(Operation *op) {
1180 return impl::verifyIsIsolatedFromAbove(op);
1181 }
1182};
1183
1184/// A trait of region holding operations that defines a new scope for polyhedral
1185/// optimization purposes. Any SSA values of 'index' type that either dominate
1186/// such an operation or are used at the top-level of such an operation
1187/// automatically become valid symbols for the polyhedral scope defined by that
1188/// operation. For more details, see `Traits.md#AffineScope`.
1189template <typename ConcreteType>
1190class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1191public:
1192 static LogicalResult verifyTrait(Operation *op) {
1193 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1194 "expected operation to have one or more regions");
1195 return success();
1196 }
1197};
1198
1199/// A trait of region holding operations that define a new scope for automatic
1200/// allocations, i.e., allocations that are freed when control is transferred
1201/// back from the operation's region. Any operations performing such allocations
1202/// (for eg. memref.alloca) will have their allocations automatically freed at
1203/// their closest enclosing operation with this trait.
1204template <typename ConcreteType>
1205class AutomaticAllocationScope
1206 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1207public:
1208 static LogicalResult verifyTrait(Operation *op) {
1209 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1210 "expected operation to have one or more regions");
1211 return success();
1212 }
1213};
1214
1215/// This class provides a verifier for ops that are expecting their parent
1216/// to be one of the given parent ops
1217template <typename... ParentOpTypes>
1218struct HasParent {
1219 template <typename ConcreteType>
1220 class Impl : public TraitBase<ConcreteType, Impl> {
1221 public:
1222 static LogicalResult verifyTrait(Operation *op) {
1223 if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1224 return success();
1225
1226 return op->emitOpError()
1227 << "expects parent op "
1228 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1229 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1230 << "'";
1231 }
1232 };
1233};
1234
1235/// A trait for operations that have an attribute specifying operand segments.
1236///
1237/// Certain operations can have multiple variadic operands and their size
1238/// relationship is not always known statically. For such cases, we need
1239/// a per-op-instance specification to divide the operands into logical groups
1240/// or segments. This can be modeled by attributes. The attribute will be named
1241/// as `operand_segment_sizes`.
1242///
1243/// This trait verifies the attribute for specifying operand segments has
1244/// the correct type (1D vector) and values (non-negative), etc.
1245template <typename ConcreteType>
1246class AttrSizedOperandSegments
1247 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1248public:
1249 static StringRef getOperandSegmentSizeAttr() {
1250 return "operand_segment_sizes";
1251 }
1252
1253 static LogicalResult verifyTrait(Operation *op) {
1254 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1255 op, getOperandSegmentSizeAttr());
1256 }
1257};
1258
1259/// Similar to AttrSizedOperandSegments but used for results.
1260template <typename ConcreteType>
1261class AttrSizedResultSegments
1262 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1263public:
1264 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1265
1266 static LogicalResult verifyTrait(Operation *op) {
1267 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1268 op, getResultSegmentSizeAttr());
1269 }
1270};
1271
1272/// This trait provides a verifier for ops that are expecting their regions to
1273/// not have any arguments
1274template <typename ConcrentType>
1275struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> {
1276 static LogicalResult verifyTrait(Operation *op) {
1277 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1278 }
1279};
1280
1281// This trait is used to flag operations that consume or produce
1282// values of `MemRef` type where those references can be 'normalized'.
1283// TODO: Right now, the operands of an operation are either all normalizable,
1284// or not. In the future, we may want to allow some of the operands to be
1285// normalizable.
1286template <typename ConcrentType>
1287struct MemRefsNormalizable
1288 : public TraitBase<ConcrentType, MemRefsNormalizable> {};
1289
1290/// This trait tags element-wise ops on vectors or tensors.
1291///
1292/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1293/// trait. In particular, broadcasting behavior is not allowed.
1294///
1295/// An `Elementwise` op must satisfy the following properties:
1296///
1297/// 1. If any result is a vector/tensor then at least one operand must also be a
1298/// vector/tensor.
1299/// 2. If any operand is a vector/tensor then there must be at least one result
1300/// and all results must be vectors/tensors.
1301/// 3. All operand and result vector/tensor types must be of the same shape. The
1302/// shape may be dynamic in which case the op's behaviour is undefined for
1303/// non-matching shapes.
1304/// 4. The operation must be elementwise on its vector/tensor operands and
1305/// results. When applied to single-element vectors/tensors, the result must
1306/// be the same per elememnt.
1307///
1308/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1309/// interface `ElementwiseTypeInterface` that describes the container types for
1310/// which the operation is elementwise.
1311///
1312/// Rationale:
1313/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1314/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1315/// generic definition of the iteration space.
1316/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1317/// the same pattern, as otherwise lots of special handling for type
1318/// mismatches would be needed.
1319/// - 4. guarantees that no error handling is needed. Higher-level dialects
1320/// should reify any needed guards or error handling code before lowering to
1321/// an `Elementwise` op.
1322template <typename ConcreteType>
1323struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1324 static LogicalResult verifyTrait(Operation *op) {
1325 return ::mlir::OpTrait::impl::verifyElementwise(op);
1326 }
1327};
1328
1329/// This trait tags `Elementwise` operatons that can be systematically
1330/// scalarized. All vector/tensor operands and results are then replaced by
1331/// scalars of the respective element type. Semantically, this is the operation
1332/// on a single element of the vector/tensor.
1333///
1334/// Rationale:
1335/// Allow to define the vector/tensor semantics of elementwise operations based
1336/// on the same op's behavior on scalars. This provides a constructive procedure
1337/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1338///
1339/// Example:
1340/// ```
1341/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
1342/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1343/// -> tensor<?xf32>
1344/// ```
1345/// can be scalarized to
1346///
1347/// ```
1348/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1349/// : (i1, f32, f32) -> f32
1350/// ```
1351template <typename ConcreteType>
1352struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1353 static LogicalResult verifyTrait(Operation *op) {
1354 static_assert(
1355 ConcreteType::template hasTrait<Elementwise>(),
1356 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1357 return success();
1358 }
1359};
1360
1361/// This trait tags `Elementwise` operatons that can be systematically
1362/// vectorized. All scalar operands and results are then replaced by vectors
1363/// with the respective element type. Semantically, this is the operation on
1364/// multiple elements simultaneously. See also `Tensorizable`.
1365///
1366/// Rationale:
1367/// Provide the reverse to `Scalarizable` which, when chained together, allows
1368/// reasoning about the relationship between the tensor and vector case.
1369/// Additionally, it permits reasoning about promoting scalars to vectors via
1370/// broadcasting in cases like `%select_scalar_pred` below.
1371template <typename ConcreteType>
1372struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1373 static LogicalResult verifyTrait(Operation *op) {
1374 static_assert(
1375 ConcreteType::template hasTrait<Elementwise>(),
1376 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1377 return success();
1378 }
1379};
1380
1381/// This trait tags `Elementwise` operatons that can be systematically
1382/// tensorized. All scalar operands and results are then replaced by tensors
1383/// with the respective element type. Semantically, this is the operation on
1384/// multiple elements simultaneously. See also `Vectorizable`.
1385///
1386/// Rationale:
1387/// Provide the reverse to `Scalarizable` which, when chained together, allows
1388/// reasoning about the relationship between the tensor and vector case.
1389/// Additionally, it permits reasoning about promoting scalars to tensors via
1390/// broadcasting in cases like `%select_scalar_pred` below.
1391///
1392/// Examples:
1393/// ```
1394/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1395/// ```
1396/// can be tensorized to
1397/// ```
1398/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1399/// -> tensor<?xf32>
1400/// ```
1401///
1402/// ```
1403/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
1404/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1405/// ```
1406/// can be tensorized to
1407/// ```
1408/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
1409/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1410/// -> tensor<?xf32>
1411/// ```
1412template <typename ConcreteType>
1413struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1414 static LogicalResult verifyTrait(Operation *op) {
1415 static_assert(
1416 ConcreteType::template hasTrait<Elementwise>(),
1417 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1418 return success();
1419 }
1420};
1421
1422/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1423/// provide an easy way for scalar operations to conveniently generalize their
1424/// behavior to vectors/tensors, and systematize conversion between these forms.
1425bool hasElementwiseMappableTraits(Operation *op);
1426
1427} // namespace OpTrait
1428
1429//===----------------------------------------------------------------------===//
1430// Internal Trait Utilities
1431//===----------------------------------------------------------------------===//
1432
1433namespace op_definition_impl {
1434//===----------------------------------------------------------------------===//
1435// Trait Existence
1436
1437/// Returns true if this given Trait ID matches the IDs of any of the provided
1438/// trait types `Traits`.
1439template <template <typename T> class... Traits>
1440static bool hasTrait(TypeID traitID) {
1441 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1442 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1443 if (traitIDs[i] == traitID)
1444 return true;
1445 return false;
1446}
1447
1448//===----------------------------------------------------------------------===//
1449// Trait Folding
1450
1451/// Trait to check if T provides a 'foldTrait' method for single result
1452/// operations.
1453template <typename T, typename... Args>
1454using has_single_result_fold_trait = decltype(T::foldTrait(
1455 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1456template <typename T>
1457using detect_has_single_result_fold_trait =
1458 llvm::is_detected<has_single_result_fold_trait, T>;
1459/// Trait to check if T provides a general 'foldTrait' method.
1460template <typename T, typename... Args>
1461using has_fold_trait =
1462 decltype(T::foldTrait(std::declval<Operation *>(),
1463 std::declval<ArrayRef<Attribute>>(),
1464 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1465template <typename T>
1466using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1467/// Trait to check if T provides any `foldTrait` method.
1468/// NOTE: This should use std::disjunction when C++17 is available.
1469template <typename T>
1470using detect_has_any_fold_trait =
1471 std::conditional_t<bool(detect_has_fold_trait<T>::value),
1472 detect_has_fold_trait<T>,
1473 detect_has_single_result_fold_trait<T>>;
1474
1475/// Returns the result of folding a trait that implements a `foldTrait` function
1476/// that is specialized for operations that have a single result.
1477template <typename Trait>
1478static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1479 LogicalResult>
1480foldTrait(Operation *op, ArrayRef<Attribute> operands,
1481 SmallVectorImpl<OpFoldResult> &results) {
1482 assert(op->hasTrait<OpTrait::OneResult>() &&(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1484, __extension__ __PRETTY_FUNCTION__
))
1483 "expected trait on non single-result operation to implement the "(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1484, __extension__ __PRETTY_FUNCTION__
))
1484 "general `foldTrait` method")(static_cast <bool> (op->hasTrait<OpTrait::OneResult
>() && "expected trait on non single-result operation to implement the "
"general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\""
, "mlir/include/mlir/IR/OpDefinition.h", 1484, __extension__ __PRETTY_FUNCTION__
))
;
1485 // If a previous trait has already been folded and replaced this operation, we
1486 // fail to fold this trait.
1487 if (!results.empty())
1488 return failure();
1489
1490 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1491 if (result.template dyn_cast<Value>() != op->getResult(0))
1492 results.push_back(result);
1493 return success();
1494 }
1495 return failure();
1496}
1497/// Returns the result of folding a trait that implements a generalized
1498/// `foldTrait` function that is supports any operation type.
1499template <typename Trait>
1500static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1501foldTrait(Operation *op, ArrayRef<Attribute> operands,
1502 SmallVectorImpl<OpFoldResult> &results) {
1503 // If a previous trait has already been folded and replaced this operation, we
1504 // fail to fold this trait.
1505 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1506}
1507
1508/// The internal implementation of `foldTraits` below that returns the result of
1509/// folding a set of trait types `Ts` that implement a `foldTrait` method.
1510template <typename... Ts>
1511static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1512 SmallVectorImpl<OpFoldResult> &results,
1513 std::tuple<Ts...> *) {
1514 bool anyFolded = false;
1515 (void)std::initializer_list<int>{
1516 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1517 return success(anyFolded);
1518}
1519
1520/// Given a tuple type containing a set of traits that contain a `foldTrait`
1521/// method, return the result of folding the given operation.
1522template <typename TraitTupleT>
1523static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
1524foldTraits(Operation *op, ArrayRef<Attribute> operands,
1525 SmallVectorImpl<OpFoldResult> &results) {
1526 return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1527}
1528/// A variant of the method above that is specialized when there are no traits
1529/// that contain a `foldTrait` method.
1530template <typename TraitTupleT>
1531static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
1532foldTraits(Operation *op, ArrayRef<Attribute> operands,
1533 SmallVectorImpl<OpFoldResult> &results) {
1534 return failure();
1535}
1536
1537//===----------------------------------------------------------------------===//
1538// Trait Verification
1539
1540/// Trait to check if T provides a `verifyTrait` method.
1541template <typename T, typename... Args>
1542using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1543template <typename T>
1544using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1545
1546/// The internal implementation of `verifyTraits` below that returns the result
1547/// of verifying the current operation with all of the provided trait types
1548/// `Ts`.
1549template <typename... Ts>
1550static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1551 LogicalResult result = success();
1552 (void)std::initializer_list<int>{
1553 (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1554 return result;
1555}
1556
1557/// Given a tuple type containing a set of traits that contain a
1558/// `verifyTrait` method, return the result of verifying the given operation.
1559template <typename TraitTupleT>
1560static LogicalResult verifyTraits(Operation *op) {
1561 return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1562}
1563} // namespace op_definition_impl
1564
1565//===----------------------------------------------------------------------===//
1566// Operation Definition classes
1567//===----------------------------------------------------------------------===//
1568
1569/// This provides public APIs that all operations should have. The template
1570/// argument 'ConcreteType' should be the concrete type by CRTP and the others
1571/// are base classes by the policy pattern.
1572template <typename ConcreteType, template <typename T> class... Traits>
1573class Op : public OpState, public Traits<ConcreteType>... {
1574public:
1575 /// Inherit getOperation from `OpState`.
1576 using OpState::getOperation;
1577
1578 /// Return if this operation contains the provided trait.
1579 template <template <typename T> class Trait>
1580 static constexpr bool hasTrait() {
1581 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1582 }
1583
1584 /// Create a deep copy of this operation.
1585 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1586
1587 /// Create a partial copy of this operation without traversing into attached
1588 /// regions. The new operation will have the same number of regions as the
1589 /// original one, but they will be left empty.
1590 ConcreteType cloneWithoutRegions() {
1591 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1592 }
1593
1594 /// Return true if this "op class" can match against the specified operation.
1595 static bool classof(Operation *op) {
1596 if (auto info = op->getRegisteredInfo())
1597 return TypeID::get<ConcreteType>() == info->getTypeID();
1598#ifndef NDEBUG
1599 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1600 llvm::report_fatal_error(
1601 "classof on '" + ConcreteType::getOperationName() +
1602 "' failed due to the operation not being registered");
1603#endif
1604 return false;
1605 }
1606 /// Provide `classof` support for other OpBase derived classes, such as
1607 /// Interfaces.
1608 template <typename T>
1609 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1610 classof(const T *op) {
1611 return classof(const_cast<T *>(op)->getOperation());
1612 }
1613
1614 /// Expose the type we are instantiated on to template machinery that may want
1615 /// to introspect traits on this operation.
1616 using ConcreteOpType = ConcreteType;
1617
1618 /// This is a public constructor. Any op can be initialized to null.
1619 explicit Op() : OpState(nullptr) {}
1620 Op(std::nullptr_t) : OpState(nullptr) {}
1621
1622 /// This is a public constructor to enable access via the llvm::cast family of
1623 /// methods. This should not be used directly.
1624 explicit Op(Operation *state) : OpState(state) {}
1625
1626 /// Methods for supporting PointerLikeTypeTraits.
1627 const void *getAsOpaquePointer() const {
1628 return static_cast<const void *>((Operation *)*this);
1629 }
1630 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1631 return ConcreteOpType(
1632 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1633 }
1634
1635 /// Attach the given models as implementations of the corresponding interfaces
1636 /// for the concrete operation.
1637 template <typename... Models>
1638 static void attachInterface(MLIRContext &context) {
1639 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
1640 ConcreteType::getOperationName(), &context);
1641 if (!info)
1642 llvm::report_fatal_error(
1643 "Attempting to attach an interface to an unregistered operation " +
1644 ConcreteType::getOperationName() + ".");
1645 info->attachInterface<Models...>();
1646 }
1647
1648private:
1649 /// Trait to check if T provides a 'fold' method for a single result op.
1650 template <typename T, typename... Args>
1651 using has_single_result_fold =
1652 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1653 template <typename T>
1654 using detect_has_single_result_fold =
1655 llvm::is_detected<has_single_result_fold, T>;
1656 /// Trait to check if T provides a general 'fold' method.
1657 template <typename T, typename... Args>
1658 using has_fold = decltype(std::declval<T>().fold(
1659 std::declval<ArrayRef<Attribute>>(),
1660 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1661 template <typename T>
1662 using detect_has_fold = llvm::is_detected<has_fold, T>;
1663 /// Trait to check if T provides a 'print' method.
1664 template <typename T, typename... Args>
1665 using has_print =
1666 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1667 template <typename T>
1668 using detect_has_print = llvm::is_detected<has_print, T>;
1669 /// A tuple type containing the traits that have a `foldTrait` function.
1670 using FoldableTraitsTupleT = typename detail::FilterTypes<
1671 op_definition_impl::detect_has_any_fold_trait,
1672 Traits<ConcreteType>...>::type;
1673 /// A tuple type containing the traits that have a verify function.
1674 using VerifiableTraitsTupleT =
1675 typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1676 Traits<ConcreteType>...>::type;
1677
1678 /// Returns an interface map containing the interfaces registered to this
1679 /// operation.
1680 static detail::InterfaceMap getInterfaceMap() {
1681 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1682 }
1683
1684 /// Return the internal implementations of each of the OperationName
1685 /// hooks.
1686 /// Implementation of `FoldHookFn` OperationName hook.
1687 static OperationName::FoldHookFn getFoldHookFn() {
1688 return getFoldHookFnImpl<ConcreteType>();
1689 }
1690 /// The internal implementation of `getFoldHookFn` above that is invoked if
1691 /// the operation is single result and defines a `fold` method.
1692 template <typename ConcreteOpT>
1693 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1694 Traits<ConcreteOpT>...>::value &&
1695 detect_has_single_result_fold<ConcreteOpT>::value,
1696 OperationName::FoldHookFn>
1697 getFoldHookFnImpl() {
1698 return [](Operation *op, ArrayRef<Attribute> operands,
1699 SmallVectorImpl<OpFoldResult> &results) {
1700 return foldSingleResultHook<ConcreteOpT>(op, operands, results);
1701 };
1702 }
1703 /// The internal implementation of `getFoldHookFn` above that is invoked if
1704 /// the operation is not single result and defines a `fold` method.
1705 template <typename ConcreteOpT>
1706 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1707 Traits<ConcreteOpT>...>::value &&
1708 detect_has_fold<ConcreteOpT>::value,
1709 OperationName::FoldHookFn>
1710 getFoldHookFnImpl() {
1711 return [](Operation *op, ArrayRef<Attribute> operands,
1712 SmallVectorImpl<OpFoldResult> &results) {
1713 return foldHook<ConcreteOpT>(op, operands, results);
1714 };
1715 }
1716 /// The internal implementation of `getFoldHookFn` above that is invoked if
1717 /// the operation does not define a `fold` method.
1718 template <typename ConcreteOpT>
1719 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1720 !detect_has_fold<ConcreteOpT>::value,
1721 OperationName::FoldHookFn>
1722 getFoldHookFnImpl() {
1723 return [](Operation *op, ArrayRef<Attribute> operands,
1724 SmallVectorImpl<OpFoldResult> &results) {
1725 // In this case, we only need to fold the traits of the operation.
1726 return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
1727 results);
1728 };
1729 }
1730 /// Return the result of folding a single result operation that defines a
1731 /// `fold` method.
1732 template <typename ConcreteOpT>
1733 static LogicalResult
1734 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1735 SmallVectorImpl<OpFoldResult> &results) {
1736 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1737
1738 // If the fold failed or was in-place, try to fold the traits of the
1739 // operation.
1740 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1741 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1742 op, operands, results)))
1743 return success();
1744 return success(static_cast<bool>(result));
1745 }
1746 results.push_back(result);
1747 return success();
1748 }
1749 /// Return the result of folding an operation that defines a `fold` method.
1750 template <typename ConcreteOpT>
1751 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1752 SmallVectorImpl<OpFoldResult> &results) {
1753 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1754
1755 // If the fold failed or was in-place, try to fold the traits of the
1756 // operation.
1757 if (failed(result) || results.empty()) {
1758 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1759 op, operands, results)))
1760 return success();
1761 }
1762 return result;
1763 }
1764
1765 /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1766 static OperationName::GetCanonicalizationPatternsFn
1767 getGetCanonicalizationPatternsFn() {
1768 return &ConcreteType::getCanonicalizationPatterns;
1769 }
1770 /// Implementation of `GetHasTraitFn`
1771 static OperationName::HasTraitFn getHasTraitFn() {
1772 return
1773 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1774 }
1775 /// Implementation of `ParseAssemblyFn` OperationName hook.
1776 static OperationName::ParseAssemblyFn getParseAssemblyFn() {
1777 return &ConcreteType::parse;
1778 }
1779 /// Implementation of `PrintAssemblyFn` OperationName hook.
1780 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1781 return getPrintAssemblyFnImpl<ConcreteType>();
1782 }
1783 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1784 /// the concrete operation does not define a `print` method.
1785 template <typename ConcreteOpT>
1786 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1787 OperationName::PrintAssemblyFn>
1788 getPrintAssemblyFnImpl() {
1789 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1790 return OpState::print(op, printer, defaultDialect);
1791 };
1792 }
1793 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1794 /// the concrete operation defines a `print` method.
1795 template <typename ConcreteOpT>
1796 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1797 OperationName::PrintAssemblyFn>
1798 getPrintAssemblyFnImpl() {
1799 return &printAssembly;
1800 }
1801 static void printAssembly(Operation *op, OpAsmPrinter &p,
1802 StringRef defaultDialect) {
1803 OpState::printOpName(op, p, defaultDialect);
1804 return cast<ConcreteType>(op).print(p);
1805 }
1806 /// Implementation of `VerifyInvariantsFn` OperationName hook.
1807 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
1808 return &verifyInvariants;
1809 }
1810
1811 static constexpr bool hasNoDataMembers() {
1812 // Checking that the derived class does not define any member by comparing
1813 // its size to an ad-hoc EmptyOp.
1814 class EmptyOp : public Op<EmptyOp, Traits...> {};
1815 return sizeof(ConcreteType) == sizeof(EmptyOp);
1816 }
1817
1818 static LogicalResult verifyInvariants(Operation *op) {
1819 static_assert(hasNoDataMembers(),
1820 "Op class shouldn't define new data members");
1821 return failure(
1822 failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1823 failed(cast<ConcreteType>(op).verify()));
1824 }
1825
1826 /// Allow access to internal implementation methods.
1827 friend RegisteredOperationName;
1828};
1829
1830/// This class represents the base of an operation interface. See the definition
1831/// of `detail::Interface` for requirements on the `Traits` type.
1832template <typename ConcreteType, typename Traits>
1833class OpInterface
1834 : public detail::Interface<ConcreteType, Operation *, Traits,
1835 Op<ConcreteType>, OpTrait::TraitBase> {
1836public:
1837 using Base = OpInterface<ConcreteType, Traits>;
1838 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1839 Op<ConcreteType>, OpTrait::TraitBase>;
1840
1841 /// Inherit the base class constructor.
1842 using InterfaceBase::InterfaceBase;
1843
1844protected:
1845 /// Returns the impl interface instance for the given operation.
1846 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1847 OperationName name = op->getName();
1848
1849 // Access the raw interface from the operation info.
1850 if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1851 if (auto *opIface = rInfo->getInterface<ConcreteType>())
1852 return opIface;
1853 // Fallback to the dialect to provide it with a chance to implement this
1854 // interface for this operation.
1855 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
1856 op->getName());
1857 }
1858 // Fallback to the dialect to provide it with a chance to implement this
1859 // interface for this operation.
1860 if (Dialect *dialect = name.getDialect())
1861 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1862 return nullptr;
1863 }
1864
1865 /// Allow access to `getInterfaceFor`.
1866 friend InterfaceBase;
1867};
1868
1869//===----------------------------------------------------------------------===//
1870// Common Operation Folders/Parsers/Printers
1871//===----------------------------------------------------------------------===//
1872
1873// These functions are out-of-line implementations of the methods in UnaryOp and
1874// BinaryOp, which avoids them being template instantiated/duplicated.
1875namespace impl {
1876ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
1877 OperationState &result);
1878
1879void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1880 Value rhs);
1881ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1882 OperationState &result);
1883
1884// Prints the given binary `op` in custom assembly form if both the two operands
1885// and the result have the same time. Otherwise, prints the generic assembly
1886// form.
1887void printOneResultOp(Operation *op, OpAsmPrinter &p);
1888} // namespace impl
1889
1890// These functions are out-of-line implementations of the methods in
1891// CastOpInterface, which avoids them being template instantiated/duplicated.
1892namespace impl {
1893/// Attempt to fold the given cast operation.
1894LogicalResult foldCastInterfaceOp(Operation *op,
1895 ArrayRef<Attribute> attrOperands,
1896 SmallVectorImpl<OpFoldResult> &foldResults);
1897/// Attempt to verify the given cast operation.
1898LogicalResult verifyCastInterfaceOp(
1899 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1900
1901// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
1902// need for them, but some older ODS code in `std` still depends on them).
1903void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1904 Type destType);
1905ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
1906void printCastOp(Operation *op, OpAsmPrinter &p);
1907// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
1908// when all uses have been updated. Also, consider adding functionality to
1909// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
1910// generically.
1911Value foldCastOp(Operation *op);
1912LogicalResult verifyCastOp(Operation *op,
1913 function_ref<bool(Type, Type)> areCastCompatible);
1914} // namespace impl
1915} // namespace mlir
1916
1917namespace llvm {
1918
1919template <typename T>
1920struct DenseMapInfo<
1921 T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
1922 static inline T getEmptyKey() {
1923 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1924 return T::getFromOpaquePointer(pointer);
1925 }
1926 static inline T getTombstoneKey() {
1927 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1928 return T::getFromOpaquePointer(pointer);
1929 }
1930 static unsigned getHashValue(T val) {
1931 return hash_value(val.getAsOpaquePointer());
1932 }
1933 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
1934};
1935
1936} // namespace llvm
1937
1938#endif

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include/mlir/Support/LogicalResult.h

1//===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_SUPPORT_LOGICALRESULT_H
10#define MLIR_SUPPORT_LOGICALRESULT_H
11
12#include "mlir/Support/LLVM.h"
13#include "llvm/ADT/Optional.h"
14
15namespace mlir {
16
17/// This class represents an efficient way to signal success or failure. It
18/// should be preferred over the use of `bool` when appropriate, as it avoids
19/// all of the ambiguity that arises in interpreting a boolean result. This
20/// class is marked as NODISCARD to ensure that the result is processed. Users
21/// may explicitly discard a result by using `(void)`, e.g.
22/// `(void)functionThatReturnsALogicalResult();`. Given the intended nature of
23/// this class, it generally shouldn't be used as the result of functions that
24/// very frequently have the result ignored. This class is intended to be used
25/// in conjunction with the utility functions below.
26struct LLVM_NODISCARD[[clang::warn_unused_result]] LogicalResult {
27public:
28 /// If isSuccess is true a `success` result is generated, otherwise a
29 /// 'failure' result is generated.
30 static LogicalResult success(bool isSuccess = true) {
31 return LogicalResult(isSuccess);
32 }
33
34 /// If isFailure is true a `failure` result is generated, otherwise a
35 /// 'success' result is generated.
36 static LogicalResult failure(bool isFailure = true) {
37 return success(!isFailure);
38 }
39
40 /// Returns true if the provided LogicalResult corresponds to a success value.
41 bool succeeded() const { return isSuccess; }
42
43 /// Returns true if the provided LogicalResult corresponds to a failure value.
44 bool failed() const { return !succeeded(); }
8
Assuming the condition is false
9
Returning zero, which participates in a condition later
24
Assuming the condition is false
25
Returning zero, which participates in a condition later
45
46private:
47 LogicalResult(bool isSuccess) : isSuccess(isSuccess) {}
48
49 /// Boolean indicating if this is a success result, if false this is a
50 /// failure result.
51 bool isSuccess;
52};
53
54/// Utility function to generate a LogicalResult. If isSuccess is true a
55/// `success` result is generated, otherwise a 'failure' result is generated.
56inline LogicalResult success(bool isSuccess = true) {
57 return LogicalResult::success(isSuccess);
58}
59
60/// Utility function to generate a LogicalResult. If isFailure is true a
61/// `failure` result is generated, otherwise a 'success' result is generated.
62inline LogicalResult failure(bool isFailure = true) {
63 return LogicalResult::failure(isFailure);
64}
65
66/// Utility function that returns true if the provided LogicalResult corresponds
67/// to a success value.
68inline bool succeeded(LogicalResult result) { return result.succeeded(); }
69
70/// Utility function that returns true if the provided LogicalResult corresponds
71/// to a failure value.
72inline bool failed(LogicalResult result) { return result.failed(); }
23
Calling 'LogicalResult::failed'
26
Returning from 'LogicalResult::failed'
27
Returning zero, which participates in a condition later
73
74/// This class provides support for representing a failure result, or a valid
75/// value of type `T`. This allows for integrating with LogicalResult, while
76/// also providing a value on the success path.
77template <typename T> class LLVM_NODISCARD[[clang::warn_unused_result]] FailureOr : public Optional<T> {
78public:
79 /// Allow constructing from a LogicalResult. The result *must* be a failure.
80 /// Success results should use a proper instance of type `T`.
81 FailureOr(LogicalResult result) {
82 assert(failed(result) &&(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'"
) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\""
, "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__
__PRETTY_FUNCTION__))
83 "success should be constructed with an instance of 'T'")(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'"
) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\""
, "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__
__PRETTY_FUNCTION__))
;
84 }
85 FailureOr() : FailureOr(failure()) {}
86 FailureOr(T &&y) : Optional<T>(std::forward<T>(y)) {}
87 FailureOr(const T &y) : Optional<T>(y) {}
88 template <typename U,
89 std::enable_if_t<std::is_constructible<T, U>::value> * = nullptr>
90 FailureOr(const FailureOr<U> &other)
91 : Optional<T>(failed(other) ? Optional<T>() : Optional<T>(*other)) {}
92
93 operator LogicalResult() const { return success(this->hasValue()); }
94
95private:
96 /// Hide the bool conversion as it easily creates confusion.
97 using Optional<T>::operator bool;
98 using Optional<T>::hasValue;
99};
100
101} // namespace mlir
102
103#endif // MLIR_SUPPORT_LOGICALRESULT_H

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include/mlir/IR/Types.h

1//===- Types.h - MLIR Type Classes ------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_IR_TYPES_H
10#define MLIR_IR_TYPES_H
11
12#include "mlir/IR/TypeSupport.h"
13#include "llvm/ADT/ArrayRef.h"
14#include "llvm/ADT/DenseMapInfo.h"
15#include "llvm/Support/PointerLikeTypeTraits.h"
16
17namespace mlir {
18/// Instances of the Type class are uniqued, have an immutable identifier and an
19/// optional mutable component. They wrap a pointer to the storage object owned
20/// by MLIRContext. Therefore, instances of Type are passed around by value.
21///
22/// Some types are "primitives" meaning they do not have any parameters, for
23/// example the Index type. Parametric types have additional information that
24/// differentiates the types of the same class, for example the Integer type has
25/// bitwidth, making i8 and i16 belong to the same kind by be different
26/// instances of the IntegerType. Type parameters are part of the unique
27/// immutable key. The mutable component of the type can be modified after the
28/// type is created, but cannot affect the identity of the type.
29///
30/// Types are constructed and uniqued via the 'detail::TypeUniquer' class.
31///
32/// Derived type classes are expected to implement several required
33/// implementation hooks:
34/// * Optional:
35/// - static LogicalResult verify(
36/// function_ref<InFlightDiagnostic()> emitError,
37/// Args... args)
38/// * This method is invoked when calling the 'TypeBase::get/getChecked'
39/// methods to ensure that the arguments passed in are valid to construct
40/// a type instance with.
41/// * This method is expected to return failure if a type cannot be
42/// constructed with 'args', success otherwise.
43/// * 'args' must correspond with the arguments passed into the
44/// 'TypeBase::get' call.
45///
46///
47/// Type storage objects inherit from TypeStorage and contain the following:
48/// - The dialect that defined the type.
49/// - Any parameters of the type.
50/// - An optional mutable component.
51/// For non-parametric types, a convenience DefaultTypeStorage is provided.
52/// Parametric storage types must derive TypeStorage and respect the following:
53/// - Define a type alias, KeyTy, to a type that uniquely identifies the
54/// instance of the type.
55/// * The key type must be constructible from the values passed into the
56/// detail::TypeUniquer::get call.
57/// * If the KeyTy does not have an llvm::DenseMapInfo specialization, the
58/// storage class must define a hashing method:
59/// 'static unsigned hashKey(const KeyTy &)'
60///
61/// - Provide a method, 'bool operator==(const KeyTy &) const', to
62/// compare the storage instance against an instance of the key type.
63///
64/// - Provide a static construction method:
65/// 'DerivedStorage *construct(TypeStorageAllocator &, const KeyTy &key)'
66/// that builds a unique instance of the derived storage. The arguments to
67/// this function are an allocator to store any uniqued data within the
68/// context and the key type for this storage.
69///
70/// - If they have a mutable component, this component must not be a part of
71// the key.
72class Type {
73public:
74 /// Utility class for implementing types.
75 template <typename ConcreteType, typename BaseType, typename StorageType,
76 template <typename T> class... Traits>
77 using TypeBase = detail::StorageUserBase<ConcreteType, BaseType, StorageType,
78 detail::TypeUniquer, Traits...>;
79
80 using ImplType = TypeStorage;
81
82 using AbstractTy = AbstractType;
83
84 constexpr Type() {}
85 /* implicit */ Type(const ImplType *impl)
86 : impl(const_cast<ImplType *>(impl)) {}
87
88 Type(const Type &other) = default;
89 Type &operator=(const Type &other) = default;
90
91 bool operator==(Type other) const { return impl == other.impl; }
92 bool operator!=(Type other) const { return !(*this == other); }
93 explicit operator bool() const { return impl; }
94
95 bool operator!() const { return impl == nullptr; }
15
Assuming the condition is false
16
Returning zero, which participates in a condition later
96
97 template <typename U> bool isa() const;
98 template <typename First, typename Second, typename... Rest> bool isa() const;
99 template <typename U> U dyn_cast() const;
100 template <typename U> U dyn_cast_or_null() const;
101 template <typename U> U cast() const;
102
103 // Support type casting Type to itself.
104 static bool classof(Type) { return true; }
105
106 /// Return a unique identifier for the concrete type. This is used to support
107 /// dynamic type casting.
108 TypeID getTypeID() { return impl->getAbstractType().getTypeID(); }
109
110 /// Return the MLIRContext in which this type was uniqued.
111 MLIRContext *getContext() const;
112
113 /// Get the dialect this type is registered to.
114 Dialect &getDialect() const { return impl->getAbstractType().getDialect(); }
115
116 // Convenience predicates. This is only for floating point types,
117 // derived types should use isa/dyn_cast.
118 bool isIndex() const;
119 bool isBF16() const;
120 bool isF16() const;
121 bool isF32() const;
122 bool isF64() const;
123 bool isF80() const;
124 bool isF128() const;
125
126 /// Return true if this is an integer type with the specified width.
127 bool isInteger(unsigned width) const;
128 /// Return true if this is a signless integer type (with the specified width).
129 bool isSignlessInteger() const;
130 bool isSignlessInteger(unsigned width) const;
131 /// Return true if this is a signed integer type (with the specified width).
132 bool isSignedInteger() const;
133 bool isSignedInteger(unsigned width) const;
134 /// Return true if this is an unsigned integer type (with the specified
135 /// width).
136 bool isUnsignedInteger() const;
137 bool isUnsignedInteger(unsigned width) const;
138
139 /// Return the bit width of an integer or a float type, assert failure on
140 /// other types.
141 unsigned getIntOrFloatBitWidth() const;
142
143 /// Return true if this is a signless integer or index type.
144 bool isSignlessIntOrIndex() const;
145 /// Return true if this is a signless integer, index, or float type.
146 bool isSignlessIntOrIndexOrFloat() const;
147 /// Return true of this is a signless integer or a float type.
148 bool isSignlessIntOrFloat() const;
149
150 /// Return true if this is an integer (of any signedness) or an index type.
151 bool isIntOrIndex() const;
152 /// Return true if this is an integer (of any signedness) or a float type.
153 bool isIntOrFloat() const;
154 /// Return true if this is an integer (of any signedness), index, or float
155 /// type.
156 bool isIntOrIndexOrFloat() const;
157
158 /// Print the current type.
159 void print(raw_ostream &os) const;
160 void dump() const;
161
162 friend ::llvm::hash_code hash_value(Type arg);
163
164 /// Methods for supporting PointerLikeTypeTraits.
165 const void *getAsOpaquePointer() const {
166 return static_cast<const void *>(impl);
167 }
168 static Type getFromOpaquePointer(const void *pointer) {
169 return Type(reinterpret_cast<ImplType *>(const_cast<void *>(pointer)));
170 }
171
172 /// Returns true if the type was registered with a particular trait.
173 template <template <typename T> class Trait>
174 bool hasTrait() {
175 return getAbstractType().hasTrait<Trait>();
176 }
177
178 /// Return the abstract type descriptor for this type.
179 const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
180
181protected:
182 ImplType *impl{nullptr};
183};
184
185inline raw_ostream &operator<<(raw_ostream &os, Type type) {
186 type.print(os);
187 return os;
188}
189
190//===----------------------------------------------------------------------===//
191// TypeTraitBase
192//===----------------------------------------------------------------------===//
193
194namespace TypeTrait {
195/// This class represents the base of a type trait.
196template <typename ConcreteType, template <typename> class TraitType>
197using TraitBase = detail::StorageUserTraitBase<ConcreteType, TraitType>;
198} // namespace TypeTrait
199
200//===----------------------------------------------------------------------===//
201// TypeInterface
202//===----------------------------------------------------------------------===//
203
204/// This class represents the base of a type interface. See the definition of
205/// `detail::Interface` for requirements on the `Traits` type.
206template <typename ConcreteType, typename Traits>
207class TypeInterface : public detail::Interface<ConcreteType, Type, Traits, Type,
208 TypeTrait::TraitBase> {
209public:
210 using Base = TypeInterface<ConcreteType, Traits>;
211 using InterfaceBase =
212 detail::Interface<ConcreteType, Type, Traits, Type, TypeTrait::TraitBase>;
213 using InterfaceBase::InterfaceBase;
214
215private:
216 /// Returns the impl interface instance for the given type.
217 static typename InterfaceBase::Concept *getInterfaceFor(Type type) {
218 return type.getAbstractType().getInterface<ConcreteType>();
219 }
220
221 /// Allow access to 'getInterfaceFor'.
222 friend InterfaceBase;
223};
224
225//===----------------------------------------------------------------------===//
226// Type Utils
227//===----------------------------------------------------------------------===//
228
229// Make Type hashable.
230inline ::llvm::hash_code hash_value(Type arg) {
231 return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
232}
233
234template <typename U> bool Type::isa() const {
235 assert(impl && "isa<> used on a null type.")(static_cast <bool> (impl && "isa<> used on a null type."
) ? void (0) : __assert_fail ("impl && \"isa<> used on a null type.\""
, "mlir/include/mlir/IR/Types.h", 235, __extension__ __PRETTY_FUNCTION__
))
;
236 return U::classof(*this);
237}
238
239template <typename First, typename Second, typename... Rest>
240bool Type::isa() const {
241 return isa<First>() || isa<Second, Rest...>();
242}
243
244template <typename U> U Type::dyn_cast() const {
245 return isa<U>() ? U(impl) : U(nullptr);
246}
247template <typename U> U Type::dyn_cast_or_null() const {
248 return (impl && isa<U>()) ? U(impl) : U(nullptr);
249}
250template <typename U> U Type::cast() const {
251 assert(isa<U>())(static_cast <bool> (isa<U>()) ? void (0) : __assert_fail
("isa<U>()", "mlir/include/mlir/IR/Types.h", 251, __extension__
__PRETTY_FUNCTION__))
;
252 return U(impl);
253}
254
255} // namespace mlir
256
257namespace llvm {
258
259// Type hash just like pointers.
260template <> struct DenseMapInfo<mlir::Type> {
261 static mlir::Type getEmptyKey() {
262 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
263 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
264 }
265 static mlir::Type getTombstoneKey() {
266 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
267 return mlir::Type(static_cast<mlir::Type::ImplType *>(pointer));
268 }
269 static unsigned getHashValue(mlir::Type val) { return mlir::hash_value(val); }
270 static bool isEqual(mlir::Type LHS, mlir::Type RHS) { return LHS == RHS; }
271};
272template <typename T>
273struct DenseMapInfo<T, std::enable_if_t<std::is_base_of<mlir::Type, T>::value>>
274 : public DenseMapInfo<mlir::Type> {
275 static T getEmptyKey() {
276 const void *pointer = llvm::DenseMapInfo<const void *>::getEmptyKey();
277 return T::getFromOpaquePointer(pointer);
278 }
279 static T getTombstoneKey() {
280 const void *pointer = llvm::DenseMapInfo<const void *>::getTombstoneKey();
281 return T::getFromOpaquePointer(pointer);
282 }
283};
284
285/// We align TypeStorage by 8, so allow LLVM to steal the low bits.
286template <> struct PointerLikeTypeTraits<mlir::Type> {
287public:
288 static inline void *getAsVoidPointer(mlir::Type I) {
289 return const_cast<void *>(I.getAsOpaquePointer());
290 }
291 static inline mlir::Type getFromVoidPointer(void *P) {
292 return mlir::Type::getFromOpaquePointer(P);
293 }
294 static constexpr int NumLowBitsAvailable = 3;
295};
296
297} // namespace llvm
298
299#endif // MLIR_IR_TYPES_H

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23
24class Builder;
25
26//===----------------------------------------------------------------------===//
27// AsmPrinter
28//===----------------------------------------------------------------------===//
29
30/// This base class exposes generic asm printer hooks, usable across the various
31/// derived printers.
32class AsmPrinter {
33public:
34 /// This class contains the internal default implementation of the base
35 /// printer methods.
36 class Impl;
37
38 /// Initialize the printer with the given internal implementation.
39 AsmPrinter(Impl &impl) : impl(&impl) {}
40 virtual ~AsmPrinter();
41
42 /// Return the raw output stream used by this printer.
43 virtual raw_ostream &getStream() const;
44
45 /// Print the given floating point value in a stabilized form that can be
46 /// roundtripped through the IR. This is the companion to the 'parseFloat'
47 /// hook on the AsmParser.
48 virtual void printFloat(const APFloat &value);
49
50 virtual void printType(Type type);
51 virtual void printAttribute(Attribute attr);
52
53 /// Trait to check if `AttrType` provides a `print` method.
54 template <typename AttrOrType>
55 using has_print_method =
56 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
57 template <typename AttrOrType>
58 using detect_has_print_method =
59 llvm::is_detected<has_print_method, AttrOrType>;
60
61 /// Print the provided attribute in the context of an operation custom
62 /// printer/parser: this will invoke directly the print method on the
63 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
64 template <typename AttrOrType,
65 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
66 *sfinae = nullptr>
67 void printStrippedAttrOrType(AttrOrType attrOrType) {
68 if (succeeded(printAlias(attrOrType)))
69 return;
70 attrOrType.print(*this);
71 }
72
73 /// SFINAE for printing the provided attribute in the context of an operation
74 /// custom printer in the case where the attribute does not define a print
75 /// method.
76 template <typename AttrOrType,
77 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
78 *sfinae = nullptr>
79 void printStrippedAttrOrType(AttrOrType attrOrType) {
80 *this << attrOrType;
81 }
82
83 /// Print the given attribute without its type. The corresponding parser must
84 /// provide a valid type for the attribute.
85 virtual void printAttributeWithoutType(Attribute attr);
86
87 /// Print the given string as a keyword, or a quoted and escaped string if it
88 /// has any special or non-printable characters in it.
89 virtual void printKeywordOrString(StringRef keyword);
90
91 /// Print the given string as a symbol reference, i.e. a form representable by
92 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
93 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
94 /// special or non-printable characters in it.
95 virtual void printSymbolName(StringRef symbolRef);
96
97 /// Print an optional arrow followed by a type list.
98 template <typename TypeRange>
99 void printOptionalArrowTypeList(TypeRange &&types) {
100 if (types.begin() != types.end())
101 printArrowTypeList(types);
102 }
103 template <typename TypeRange>
104 void printArrowTypeList(TypeRange &&types) {
105 auto &os = getStream() << " -> ";
106
107 bool wrapped = !llvm::hasSingleElement(types) ||
108 (*types.begin()).template isa<FunctionType>();
109 if (wrapped)
110 os << '(';
111 llvm::interleaveComma(types, *this);
112 if (wrapped)
113 os << ')';
114 }
115
116 /// Print the two given type ranges in a functional form.
117 template <typename InputRangeT, typename ResultRangeT>
118 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
119 auto &os = getStream();
120 os << '(';
121 llvm::interleaveComma(inputs, *this);
122 os << ')';
123 printArrowTypeList(results);
124 }
125
126protected:
127 /// Initialize the printer with no internal implementation. In this case, all
128 /// virtual methods of this class must be overriden.
129 AsmPrinter() {}
130
131private:
132 AsmPrinter(const AsmPrinter &) = delete;
133 void operator=(const AsmPrinter &) = delete;
134
135 /// Print the alias for the given attribute, return failure if no alias could
136 /// be printed.
137 virtual LogicalResult printAlias(Attribute attr);
138
139 /// Print the alias for the given type, return failure if no alias could
140 /// be printed.
141 virtual LogicalResult printAlias(Type type);
142
143 /// The internal implementation of the printer.
144 Impl *impl{nullptr};
145};
146
147template <typename AsmPrinterT>
148inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
149 AsmPrinterT &>
150operator<<(AsmPrinterT &p, Type type) {
151 p.printType(type);
152 return p;
153}
154
155template <typename AsmPrinterT>
156inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
157 AsmPrinterT &>
158operator<<(AsmPrinterT &p, Attribute attr) {
159 p.printAttribute(attr);
160 return p;
161}
162
163template <typename AsmPrinterT>
164inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
165 AsmPrinterT &>
166operator<<(AsmPrinterT &p, const APFloat &value) {
167 p.printFloat(value);
168 return p;
169}
170template <typename AsmPrinterT>
171inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
172 AsmPrinterT &>
173operator<<(AsmPrinterT &p, float value) {
174 return p << APFloat(value);
175}
176template <typename AsmPrinterT>
177inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
178 AsmPrinterT &>
179operator<<(AsmPrinterT &p, double value) {
180 return p << APFloat(value);
181}
182
183// Support printing anything that isn't convertible to one of the other
184// streamable types, even if it isn't exactly one of them. For example, we want
185// to print FunctionType with the Type version above, not have it match this.
186template <
187 typename AsmPrinterT, typename T,
188 typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
189 !std::is_convertible<T &, Type &>::value &&
190 !std::is_convertible<T &, Attribute &>::value &&
191 !std::is_convertible<T &, ValueRange>::value &&
192 !std::is_convertible<T &, APFloat &>::value &&
193 !llvm::is_one_of<T, bool, float, double>::value,
194 T>::type * = nullptr>
195inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
196 AsmPrinterT &>
197operator<<(AsmPrinterT &p, const T &other) {
198 p.getStream() << other;
199 return p;
200}
201
202template <typename AsmPrinterT>
203inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
204 AsmPrinterT &>
205operator<<(AsmPrinterT &p, bool value) {
206 return p << (value ? StringRef("true") : "false");
207}
208
209template <typename AsmPrinterT, typename ValueRangeT>
210inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
211 AsmPrinterT &>
212operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
213 llvm::interleaveComma(types, p);
214 return p;
215}
216template <typename AsmPrinterT>
217inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
218 AsmPrinterT &>
219operator<<(AsmPrinterT &p, const TypeRange &types) {
220 llvm::interleaveComma(types, p);
221 return p;
222}
223template <typename AsmPrinterT, typename ElementT>
224inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
225 AsmPrinterT &>
226operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
227 llvm::interleaveComma(types, p);
228 return p;
229}
230
231//===----------------------------------------------------------------------===//
232// OpAsmPrinter
233//===----------------------------------------------------------------------===//
234
235/// This is a pure-virtual base class that exposes the asmprinter hooks
236/// necessary to implement a custom print() method.
237class OpAsmPrinter : public AsmPrinter {
238public:
239 using AsmPrinter::AsmPrinter;
240 ~OpAsmPrinter() override;
241
242 /// Print a newline and indent the printer to the start of the current
243 /// operation.
244 virtual void printNewline() = 0;
245
246 /// Print a block argument in the usual format of:
247 /// %ssaName : type {attr1=42} loc("here")
248 /// where location printing is controlled by the standard internal option.
249 /// You may pass omitType=true to not print a type, and pass an empty
250 /// attribute list if you don't care for attributes.
251 virtual void printRegionArgument(BlockArgument arg,
252 ArrayRef<NamedAttribute> argAttrs = {},
253 bool omitType = false) = 0;
254
255 /// Print implementations for various things an operation contains.
256 virtual void printOperand(Value value) = 0;
257 virtual void printOperand(Value value, raw_ostream &os) = 0;
258
259 /// Print a comma separated list of operands.
260 template <typename ContainerType>
261 void printOperands(const ContainerType &container) {
262 printOperands(container.begin(), container.end());
263 }
264
265 /// Print a comma separated list of operands.
266 template <typename IteratorType>
267 void printOperands(IteratorType it, IteratorType end) {
268 if (it == end)
269 return;
270 printOperand(*it);
271 for (++it; it != end; ++it) {
272 getStream() << ", ";
273 printOperand(*it);
274 }
275 }
276
277 /// Print the given successor.
278 virtual void printSuccessor(Block *successor) = 0;
279
280 /// Print the successor and its operands.
281 virtual void printSuccessorAndUseList(Block *successor,
282 ValueRange succOperands) = 0;
283
284 /// If the specified operation has attributes, print out an attribute
285 /// dictionary with their values. elidedAttrs allows the client to ignore
286 /// specific well known attributes, commonly used if the attribute value is
287 /// printed some other way (like as a fixed operand).
288 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
289 ArrayRef<StringRef> elidedAttrs = {}) = 0;
290
291 /// If the specified operation has attributes, print out an attribute
292 /// dictionary prefixed with 'attributes'.
293 virtual void
294 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
295 ArrayRef<StringRef> elidedAttrs = {}) = 0;
296
297 /// Print the entire operation with the default generic assembly form.
298 /// If `printOpName` is true, then the operation name is printed (the default)
299 /// otherwise it is omitted and the print will start with the operand list.
300 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
301
302 /// Prints a region.
303 /// If 'printEntryBlockArgs' is false, the arguments of the
304 /// block are not printed. If 'printBlockTerminator' is false, the terminator
305 /// operation of the block is not printed. If printEmptyBlock is true, then
306 /// the block header is printed even if the block is empty.
307 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
308 bool printBlockTerminators = true,
309 bool printEmptyBlock = false) = 0;
310
311 /// Renumber the arguments for the specified region to the same names as the
312 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
313 /// operations. If any entry in namesToUse is null, the corresponding
314 /// argument name is left alone.
315 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
316
317 /// Prints an affine map of SSA ids, where SSA id names are used in place
318 /// of dims/symbols.
319 /// Operand values must come from single-result sources, and be valid
320 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
321 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
322 ValueRange operands) = 0;
323
324 /// Prints an affine expression of SSA ids with SSA id names used instead of
325 /// dims and symbols.
326 /// Operand values must come from single-result sources, and be valid
327 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
328 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
329 ValueRange symOperands) = 0;
330
331 /// Print the complete type of an operation in functional form.
332 void printFunctionalType(Operation *op);
333 using AsmPrinter::printFunctionalType;
334};
335
336// Make the implementations convenient to use.
337inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
338 p.printOperand(value);
339 return p;
340}
341
342template <typename T,
343 typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
344 !std::is_convertible<T &, Value &>::value,
345 T>::type * = nullptr>
346inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
347 p.printOperands(values);
348 return p;
349}
350
351inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
352 p.printSuccessor(value);
353 return p;
354}
355
356//===----------------------------------------------------------------------===//
357// AsmParser
358//===----------------------------------------------------------------------===//
359
360/// This base class exposes generic asm parser hooks, usable across the various
361/// derived parsers.
362class AsmParser {
363public:
364 AsmParser() = default;
365 virtual ~AsmParser();
366
367 MLIRContext *getContext() const;
368
369 /// Return the location of the original name token.
370 virtual llvm::SMLoc getNameLoc() const = 0;
371
372 //===--------------------------------------------------------------------===//
373 // Utilities
374 //===--------------------------------------------------------------------===//
375
376 /// Emit a diagnostic at the specified location and return failure.
377 virtual InFlightDiagnostic emitError(llvm::SMLoc loc,
378 const Twine &message = {}) = 0;
379
380 /// Return a builder which provides useful access to MLIRContext, global
381 /// objects like types and attributes.
382 virtual Builder &getBuilder() const = 0;
383
384 /// Get the location of the next token and store it into the argument. This
385 /// always succeeds.
386 virtual llvm::SMLoc getCurrentLocation() = 0;
387 ParseResult getCurrentLocation(llvm::SMLoc *loc) {
388 *loc = getCurrentLocation();
389 return success();
390 }
391
392 /// Re-encode the given source location as an MLIR location and return it.
393 /// Note: This method should only be used when a `Location` is necessary, as
394 /// the encoding process is not efficient.
395 virtual Location getEncodedSourceLoc(llvm::SMLoc loc) = 0;
396
397 //===--------------------------------------------------------------------===//
398 // Token Parsing
399 //===--------------------------------------------------------------------===//
400
401 /// Parse a '->' token.
402 virtual ParseResult parseArrow() = 0;
403
404 /// Parse a '->' token if present
405 virtual ParseResult parseOptionalArrow() = 0;
406
407 /// Parse a `{` token.
408 virtual ParseResult parseLBrace() = 0;
409
410 /// Parse a `{` token if present.
411 virtual ParseResult parseOptionalLBrace() = 0;
412
413 /// Parse a `}` token.
414 virtual ParseResult parseRBrace() = 0;
415
416 /// Parse a `}` token if present.
417 virtual ParseResult parseOptionalRBrace() = 0;
418
419 /// Parse a `:` token.
420 virtual ParseResult parseColon() = 0;
421
422 /// Parse a `:` token if present.
423 virtual ParseResult parseOptionalColon() = 0;
424
425 /// Parse a `,` token.
426 virtual ParseResult parseComma() = 0;
427
428 /// Parse a `,` token if present.
429 virtual ParseResult parseOptionalComma() = 0;
430
431 /// Parse a `=` token.
432 virtual ParseResult parseEqual() = 0;
433
434 /// Parse a `=` token if present.
435 virtual ParseResult parseOptionalEqual() = 0;
436
437 /// Parse a '<' token.
438 virtual ParseResult parseLess() = 0;
439
440 /// Parse a '<' token if present.
441 virtual ParseResult parseOptionalLess() = 0;
442
443 /// Parse a '>' token.
444 virtual ParseResult parseGreater() = 0;
445
446 /// Parse a '>' token if present.
447 virtual ParseResult parseOptionalGreater() = 0;
448
449 /// Parse a '?' token.
450 virtual ParseResult parseQuestion() = 0;
451
452 /// Parse a '?' token if present.
453 virtual ParseResult parseOptionalQuestion() = 0;
454
455 /// Parse a '+' token.
456 virtual ParseResult parsePlus() = 0;
457
458 /// Parse a '+' token if present.
459 virtual ParseResult parseOptionalPlus() = 0;
460
461 /// Parse a '*' token.
462 virtual ParseResult parseStar() = 0;
463
464 /// Parse a '*' token if present.
465 virtual ParseResult parseOptionalStar() = 0;
466
467 /// Parse a quoted string token.
468 ParseResult parseString(std::string *string) {
469 auto loc = getCurrentLocation();
470 if (parseOptionalString(string))
471 return emitError(loc, "expected string");
472 return success();
473 }
474
475 /// Parse a quoted string token if present.
476 virtual ParseResult parseOptionalString(std::string *string) = 0;
477
478 /// Parse a given keyword.
479 ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") {
480 auto loc = getCurrentLocation();
481 if (parseOptionalKeyword(keyword))
482 return emitError(loc, "expected '") << keyword << "'" << msg;
483 return success();
484 }
485
486 /// Parse a keyword into 'keyword'.
487 ParseResult parseKeyword(StringRef *keyword) {
488 auto loc = getCurrentLocation();
489 if (parseOptionalKeyword(keyword))
490 return emitError(loc, "expected valid keyword");
491 return success();
492 }
493
494 /// Parse the given keyword if present.
495 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
496
497 /// Parse a keyword, if present, into 'keyword'.
498 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
499
500 /// Parse a keyword, if present, and if one of the 'allowedValues',
501 /// into 'keyword'
502 virtual ParseResult
503 parseOptionalKeyword(StringRef *keyword,
504 ArrayRef<StringRef> allowedValues) = 0;
505
506 /// Parse a keyword or a quoted string.
507 ParseResult parseKeywordOrString(std::string *result) {
508 if (failed(parseOptionalKeywordOrString(result)))
509 return emitError(getCurrentLocation())
510 << "expected valid keyword or string";
511 return success();
512 }
513
514 /// Parse an optional keyword or string.
515 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
516
517 /// Parse a `(` token.
518 virtual ParseResult parseLParen() = 0;
519
520 /// Parse a `(` token if present.
521 virtual ParseResult parseOptionalLParen() = 0;
522
523 /// Parse a `)` token.
524 virtual ParseResult parseRParen() = 0;
525
526 /// Parse a `)` token if present.
527 virtual ParseResult parseOptionalRParen() = 0;
528
529 /// Parse a `[` token.
530 virtual ParseResult parseLSquare() = 0;
531
532 /// Parse a `[` token if present.
533 virtual ParseResult parseOptionalLSquare() = 0;
534
535 /// Parse a `]` token.
536 virtual ParseResult parseRSquare() = 0;
537
538 /// Parse a `]` token if present.
539 virtual ParseResult parseOptionalRSquare() = 0;
540
541 /// Parse a `...` token if present;
542 virtual ParseResult parseOptionalEllipsis() = 0;
543
544 /// Parse a floating point value from the stream.
545 virtual ParseResult parseFloat(double &result) = 0;
546
547 /// Parse an integer value from the stream.
548 template <typename IntT>
549 ParseResult parseInteger(IntT &result) {
550 auto loc = getCurrentLocation();
551 OptionalParseResult parseResult = parseOptionalInteger(result);
31
Calling 'AsmParser::parseOptionalInteger'
35
Returning from 'AsmParser::parseOptionalInteger'
552 if (!parseResult.hasValue())
36
Taking false branch
553 return emitError(loc, "expected integer value");
554 return *parseResult;
37
Returning without writing to 'result'
555 }
556
557 /// Parse an optional integer value from the stream.
558 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
559
560 template <typename IntT>
561 OptionalParseResult parseOptionalInteger(IntT &result) {
562 auto loc = getCurrentLocation();
563
564 // Parse the unsigned variant.
565 APInt uintResult;
566 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
567 if (!parseResult.hasValue() || failed(*parseResult))
32
Assuming the condition is false
33
Taking true branch
568 return parseResult;
34
Returning without writing to 'result'
569
570 // Try to convert to the provided integer type. sextOrTrunc is correct even
571 // for unsigned types because parseOptionalInteger ensures the sign bit is
572 // zero for non-negated integers.
573 result =
574 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
575 if (APInt(uintResult.getBitWidth(), result) != uintResult)
576 return emitError(loc, "integer value too large");
577 return success();
578 }
579
580 /// These are the supported delimiters around operand lists and region
581 /// argument lists, used by parseOperandList and parseRegionArgumentList.
582 enum class Delimiter {
583 /// Zero or more operands with no delimiters.
584 None,
585 /// Parens surrounding zero or more operands.
586 Paren,
587 /// Square brackets surrounding zero or more operands.
588 Square,
589 /// <> brackets surrounding zero or more operands.
590 LessGreater,
591 /// {} brackets surrounding zero or more operands.
592 Braces,
593 /// Parens supporting zero or more operands, or nothing.
594 OptionalParen,
595 /// Square brackets supporting zero or more ops, or nothing.
596 OptionalSquare,
597 /// <> brackets supporting zero or more ops, or nothing.
598 OptionalLessGreater,
599 /// {} brackets surrounding zero or more operands, or nothing.
600 OptionalBraces,
601 };
602
603 /// Parse a list of comma-separated items with an optional delimiter. If a
604 /// delimiter is provided, then an empty list is allowed. If not, then at
605 /// least one element will be parsed.
606 ///
607 /// contextMessage is an optional message appended to "expected '('" sorts of
608 /// diagnostics when parsing the delimeters.
609 virtual ParseResult
610 parseCommaSeparatedList(Delimiter delimiter,
611 function_ref<ParseResult()> parseElementFn,
612 StringRef contextMessage = StringRef()) = 0;
613
614 /// Parse a comma separated list of elements that must have at least one entry
615 /// in it.
616 ParseResult
617 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
618 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
619 }
620
621 //===--------------------------------------------------------------------===//
622 // Attribute/Type Parsing
623 //===--------------------------------------------------------------------===//
624
625 /// Invoke the `getChecked` method of the given Attribute or Type class, using
626 /// the provided location to emit errors in the case of failure. Note that
627 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
628 /// context parameter.
629 template <typename T, typename... ParamsT>
630 T getChecked(llvm::SMLoc loc, ParamsT &&... params) {
631 return T::getChecked([&] { return emitError(loc); },
632 std::forward<ParamsT>(params)...);
633 }
634 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
635 /// errors.
636 template <typename T, typename... ParamsT>
637 T getChecked(ParamsT &&... params) {
638 return T::getChecked([&] { return emitError(getNameLoc()); },
639 std::forward<ParamsT>(params)...);
640 }
641
642 //===--------------------------------------------------------------------===//
643 // Attribute Parsing
644 //===--------------------------------------------------------------------===//
645
646 /// Parse an arbitrary attribute of a given type and return it in result.
647 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
648
649 /// Parse a custom attribute with the provided callback, unless the next
650 /// token is `#`, in which case the generic parser is invoked.
651 virtual ParseResult parseCustomAttributeWithFallback(
652 Attribute &result, Type type,
653 function_ref<ParseResult(Attribute &result, Type type)>
654 parseAttribute) = 0;
655
656 /// Parse an attribute of a specific kind and type.
657 template <typename AttrType>
658 ParseResult parseAttribute(AttrType &result, Type type = {}) {
659 llvm::SMLoc loc = getCurrentLocation();
660
661 // Parse any kind of attribute.
662 Attribute attr;
663 if (parseAttribute(attr, type))
664 return failure();
665
666 // Check for the right kind of attribute.
667 if (!(result = attr.dyn_cast<AttrType>()))
668 return emitError(loc, "invalid kind of attribute specified");
669
670 return success();
671 }
672
673 /// Parse an arbitrary attribute and return it in result. This also adds the
674 /// attribute to the specified attribute list with the specified name.
675 ParseResult parseAttribute(Attribute &result, StringRef attrName,
676 NamedAttrList &attrs) {
677 return parseAttribute(result, Type(), attrName, attrs);
678 }
679
680 /// Parse an attribute of a specific kind and type.
681 template <typename AttrType>
682 ParseResult parseAttribute(AttrType &result, StringRef attrName,
683 NamedAttrList &attrs) {
684 return parseAttribute(result, Type(), attrName, attrs);
685 }
686
687 /// Parse an arbitrary attribute of a given type and populate it in `result`.
688 /// This also adds the attribute to the specified attribute list with the
689 /// specified name.
690 template <typename AttrType>
691 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
692 NamedAttrList &attrs) {
693 llvm::SMLoc loc = getCurrentLocation();
694
695 // Parse any kind of attribute.
696 Attribute attr;
697 if (parseAttribute(attr, type))
698 return failure();
699
700 // Check for the right kind of attribute.
701 result = attr.dyn_cast<AttrType>();
702 if (!result)
703 return emitError(loc, "invalid kind of attribute specified");
704
705 attrs.append(attrName, result);
706 return success();
707 }
708
709 /// Trait to check if `AttrType` provides a `parse` method.
710 template <typename AttrType>
711 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
712 std::declval<Type>()));
713 template <typename AttrType>
714 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
715
716 /// Parse a custom attribute of a given type unless the next token is `#`, in
717 /// which case the generic parser is invoked. The parsed attribute is
718 /// populated in `result` and also added to the specified attribute list with
719 /// the specified name.
720 template <typename AttrType>
721 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
722 parseCustomAttributeWithFallback(AttrType &result, Type type,
723 StringRef attrName, NamedAttrList &attrs) {
724 llvm::SMLoc loc = getCurrentLocation();
725
726 // Parse any kind of attribute.
727 Attribute attr;
728 if (parseCustomAttributeWithFallback(
729 attr, type, [&](Attribute &result, Type type) -> ParseResult {
730 result = AttrType::parse(*this, type);
731 if (!result)
732 return failure();
733 return success();
734 }))
735 return failure();
736
737 // Check for the right kind of attribute.
738 result = attr.dyn_cast<AttrType>();
739 if (!result)
740 return emitError(loc, "invalid kind of attribute specified");
741
742 attrs.append(attrName, result);
743 return success();
744 }
745
746 /// SFINAE parsing method for Attribute that don't implement a parse method.
747 template <typename AttrType>
748 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
749 parseCustomAttributeWithFallback(AttrType &result, Type type,
750 StringRef attrName, NamedAttrList &attrs) {
751 return parseAttribute(result, type, attrName, attrs);
752 }
753
754 /// Parse a custom attribute of a given type unless the next token is `#`, in
755 /// which case the generic parser is invoked. The parsed attribute is
756 /// populated in `result`.
757 template <typename AttrType>
758 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
759 parseCustomAttributeWithFallback(AttrType &result) {
760 llvm::SMLoc loc = getCurrentLocation();
761
762 // Parse any kind of attribute.
763 Attribute attr;
764 if (parseCustomAttributeWithFallback(
765 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
766 result = AttrType::parse(*this, type);
767 return success(!!result);
768 }))
769 return failure();
770
771 // Check for the right kind of attribute.
772 result = attr.dyn_cast<AttrType>();
773 if (!result)
774 return emitError(loc, "invalid kind of attribute specified");
775 return success();
776 }
777
778 /// SFINAE parsing method for Attribute that don't implement a parse method.
779 template <typename AttrType>
780 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
781 parseCustomAttributeWithFallback(AttrType &result) {
782 return parseAttribute(result);
783 }
784
785 /// Parse an arbitrary optional attribute of a given type and return it in
786 /// result.
787 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
788 Type type = {}) = 0;
789
790 /// Parse an optional array attribute and return it in result.
791 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
792 Type type = {}) = 0;
793
794 /// Parse an optional string attribute and return it in result.
795 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
796 Type type = {}) = 0;
797
798 /// Parse an optional attribute of a specific type and add it to the list with
799 /// the specified name.
800 template <typename AttrType>
801 OptionalParseResult parseOptionalAttribute(AttrType &result,
802 StringRef attrName,
803 NamedAttrList &attrs) {
804 return parseOptionalAttribute(result, Type(), attrName, attrs);
805 }
806
807 /// Parse an optional attribute of a specific type and add it to the list with
808 /// the specified name.
809 template <typename AttrType>
810 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
811 StringRef attrName,
812 NamedAttrList &attrs) {
813 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
814 if (parseResult.hasValue() && succeeded(*parseResult))
815 attrs.append(attrName, result);
816 return parseResult;
817 }
818
819 /// Parse a named dictionary into 'result' if it is present.
820 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
821
822 /// Parse a named dictionary into 'result' if the `attributes` keyword is
823 /// present.
824 virtual ParseResult
825 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
826
827 /// Parse an affine map instance into 'map'.
828 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
829
830 /// Parse an integer set instance into 'set'.
831 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
832
833 //===--------------------------------------------------------------------===//
834 // Identifier Parsing
835 //===--------------------------------------------------------------------===//
836
837 /// Parse an @-identifier and store it (without the '@' symbol) in a string
838 /// attribute named 'attrName'.
839 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
840 NamedAttrList &attrs) {
841 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
842 return emitError(getCurrentLocation())
843 << "expected valid '@'-identifier for symbol name";
844 return success();
845 }
846
847 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
848 /// string attribute named 'attrName'.
849 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
850 StringRef attrName,
851 NamedAttrList &attrs) = 0;
852
853 //===--------------------------------------------------------------------===//
854 // Type Parsing
855 //===--------------------------------------------------------------------===//
856
857 /// Parse a type.
858 virtual ParseResult parseType(Type &result) = 0;
859
860 /// Parse a custom type with the provided callback, unless the next
861 /// token is `#`, in which case the generic parser is invoked.
862 virtual ParseResult parseCustomTypeWithFallback(
863 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
864
865 /// Parse an optional type.
866 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
867
868 /// Parse a type of a specific type.
869 template <typename TypeT>
870 ParseResult parseType(TypeT &result) {
871 llvm::SMLoc loc = getCurrentLocation();
872
873 // Parse any kind of type.
874 Type type;
875 if (parseType(type))
876 return failure();
877
878 // Check for the right kind of type.
879 result = type.dyn_cast<TypeT>();
880 if (!result)
881 return emitError(loc, "invalid kind of type specified");
882
883 return success();
884 }
885
886 /// Trait to check if `TypeT` provides a `parse` method.
887 template <typename TypeT>
888 using type_has_parse_method =
889 decltype(TypeT::parse(std::declval<AsmParser &>()));
890 template <typename TypeT>
891 using detect_type_has_parse_method =
892 llvm::is_detected<type_has_parse_method, TypeT>;
893
894 /// Parse a custom Type of a given type unless the next token is `#`, in
895 /// which case the generic parser is invoked. The parsed Type is
896 /// populated in `result`.
897 template <typename TypeT>
898 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
899 parseCustomTypeWithFallback(TypeT &result) {
900 llvm::SMLoc loc = getCurrentLocation();
901
902 // Parse any kind of Type.
903 Type type;
904 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
905 result = TypeT::parse(*this);
906 return success(!!result);
907 }))
908 return failure();
909
910 // Check for the right kind of Type.
911 result = type.dyn_cast<TypeT>();
912 if (!result)
913 return emitError(loc, "invalid kind of Type specified");
914 return success();
915 }
916
917 /// SFINAE parsing method for Type that don't implement a parse method.
918 template <typename TypeT>
919 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
920 parseCustomTypeWithFallback(TypeT &result) {
921 return parseType(result);
922 }
923
924 /// Parse a type list.
925 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
926 do {
927 Type type;
928 if (parseType(type))
929 return failure();
930 result.push_back(type);
931 } while (succeeded(parseOptionalComma()));
932 return success();
933 }
934
935 /// Parse an arrow followed by a type list.
936 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
937
938 /// Parse an optional arrow followed by a type list.
939 virtual ParseResult
940 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
941
942 /// Parse a colon followed by a type.
943 virtual ParseResult parseColonType(Type &result) = 0;
944
945 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
946 template <typename TypeType>
947 ParseResult parseColonType(TypeType &result) {
948 llvm::SMLoc loc = getCurrentLocation();
949
950 // Parse any kind of type.
951 Type type;
952 if (parseColonType(type))
953 return failure();
954
955 // Check for the right kind of type.
956 result = type.dyn_cast<TypeType>();
957 if (!result)
958 return emitError(loc, "invalid kind of type specified");
959
960 return success();
961 }
962
963 /// Parse a colon followed by a type list, which must have at least one type.
964 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
965
966 /// Parse an optional colon followed by a type list, which if present must
967 /// have at least one type.
968 virtual ParseResult
969 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
970
971 /// Parse a keyword followed by a type.
972 ParseResult parseKeywordType(const char *keyword, Type &result) {
973 return failure(parseKeyword(keyword) || parseType(result));
974 }
975
976 /// Add the specified type to the end of the specified type list and return
977 /// success. This is a helper designed to allow parse methods to be simple
978 /// and chain through || operators.
979 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
980 result.push_back(type);
981 return success();
982 }
983
984 /// Add the specified types to the end of the specified type list and return
985 /// success. This is a helper designed to allow parse methods to be simple
986 /// and chain through || operators.
987 ParseResult addTypesToList(ArrayRef<Type> types,
988 SmallVectorImpl<Type> &result) {
989 result.append(types.begin(), types.end());
990 return success();
991 }
992
993 /// Parse a 'x' separated dimension list. This populates the dimension list,
994 /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on
995 /// `?` otherwise.
996 ///
997 /// dimension-list ::= (dimension `x`)*
998 /// dimension ::= `?` | integer
999 ///
1000 /// When `allowDynamic` is not set, this is used to parse:
1001 ///
1002 /// static-dimension-list ::= (integer `x`)*
1003 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1004 bool allowDynamic = true) = 0;
1005
1006 /// Parse an 'x' token in a dimension list, handling the case where the x is
1007 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1008 /// next token.
1009 virtual ParseResult parseXInDimensionList() = 0;
1010
1011private:
1012 AsmParser(const AsmParser &) = delete;
1013 void operator=(const AsmParser &) = delete;
1014};
1015
1016//===----------------------------------------------------------------------===//
1017// OpAsmParser
1018//===----------------------------------------------------------------------===//
1019
1020/// The OpAsmParser has methods for interacting with the asm parser: parsing
1021/// things from it, emitting errors etc. It has an intentionally high-level API
1022/// that is designed to reduce/constrain syntax innovation in individual
1023/// operations.
1024///
1025/// For example, consider an op like this:
1026///
1027/// %x = load %p[%1, %2] : memref<...>
1028///
1029/// The "%x = load" tokens are already parsed and therefore invisible to the
1030/// custom op parser. This can be supported by calling `parseOperandList` to
1031/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1032/// parse the indices, then calling `parseColonTypeList` to parse the result
1033/// type.
1034///
1035class OpAsmParser : public AsmParser {
1036public:
1037 using AsmParser::AsmParser;
1038 ~OpAsmParser() override;
1039
1040 /// Parse a loc(...) specifier if present, filling in result if so.
1041 /// Location for BlockArgument and Operation may be deferred with an alias, in
1042 /// which case an OpaqueLoc is set and will be resolved when parsing
1043 /// completes.
1044 virtual ParseResult
1045 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1046
1047 /// Return the name of the specified result in the specified syntax, as well
1048 /// as the sub-element in the name. It returns an empty string and ~0U for
1049 /// invalid result numbers. For example, in this operation:
1050 ///
1051 /// %x, %y:2, %z = foo.op
1052 ///
1053 /// getResultName(0) == {"x", 0 }
1054 /// getResultName(1) == {"y", 0 }
1055 /// getResultName(2) == {"y", 1 }
1056 /// getResultName(3) == {"z", 0 }
1057 /// getResultName(4) == {"", ~0U }
1058 virtual std::pair<StringRef, unsigned>
1059 getResultName(unsigned resultNo) const = 0;
1060
1061 /// Return the number of declared SSA results. This returns 4 for the foo.op
1062 /// example in the comment for `getResultName`.
1063 virtual size_t getNumResults() const = 0;
1064
1065 // These methods emit an error and return failure or success. This allows
1066 // these to be chained together into a linear sequence of || expressions in
1067 // many cases.
1068
1069 /// Parse an operation in its generic form.
1070 /// The parsed operation is parsed in the current context and inserted in the
1071 /// provided block and insertion point. The results produced by this operation
1072 /// aren't mapped to any named value in the parser. Returns nullptr on
1073 /// failure.
1074 virtual Operation *parseGenericOperation(Block *insertBlock,
1075 Block::iterator insertPt) = 0;
1076
1077 /// Parse the name of an operation, in the custom form. On success, return a
1078 /// an object of type 'OperationName'. Otherwise, failure is returned.
1079 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1080
1081 //===--------------------------------------------------------------------===//
1082 // Operand Parsing
1083 //===--------------------------------------------------------------------===//
1084
1085 /// This is the representation of an operand reference.
1086 struct OperandType {
1087 llvm::SMLoc location; // Location of the token.
1088 StringRef name; // Value name, e.g. %42 or %abc
1089 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1090 };
1091
1092 /// Parse different components, viz., use-info of operand(s), successor(s),
1093 /// region(s), attribute(s) and function-type, of the generic form of an
1094 /// operation instance and populate the input operation-state 'result' with
1095 /// those components. If any of the components is explicitly provided, then
1096 /// skip parsing that component.
1097 virtual ParseResult parseGenericOperationAfterOpName(
1098 OperationState &result,
1099 Optional<ArrayRef<OperandType>> parsedOperandType = llvm::None,
1100 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1101 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1102 llvm::None,
1103 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1104 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1105
1106 /// Parse a single operand.
1107 virtual ParseResult parseOperand(OperandType &result) = 0;
1108
1109 /// Parse a single operand if present.
1110 virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0;
1111
1112 /// Parse zero or more SSA comma-separated operand references with a specified
1113 /// surrounding delimiter, and an optional required operand count.
1114 virtual ParseResult
1115 parseOperandList(SmallVectorImpl<OperandType> &result,
1116 int requiredOperandCount = -1,
1117 Delimiter delimiter = Delimiter::None) = 0;
1118 ParseResult parseOperandList(SmallVectorImpl<OperandType> &result,
1119 Delimiter delimiter) {
1120 return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter);
1121 }
1122
1123 /// Parse zero or more trailing SSA comma-separated trailing operand
1124 /// references with a specified surrounding delimiter, and an optional
1125 /// required operand count. A leading comma is expected before the operands.
1126 virtual ParseResult
1127 parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
1128 int requiredOperandCount = -1,
1129 Delimiter delimiter = Delimiter::None) = 0;
1130 ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result,
1131 Delimiter delimiter) {
1132 return parseTrailingOperandList(result, /*requiredOperandCount=*/-1,
1133 delimiter);
1134 }
1135
1136 /// Resolve an operand to an SSA value, emitting an error on failure.
1137 virtual ParseResult resolveOperand(const OperandType &operand, Type type,
1138 SmallVectorImpl<Value> &result) = 0;
1139
1140 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1141 /// appending the results to the list on success. This method should be used
1142 /// when all operands have the same type.
1143 ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type,
1144 SmallVectorImpl<Value> &result) {
1145 for (auto elt : operands)
1146 if (resolveOperand(elt, type, result))
1147 return failure();
1148 return success();
1149 }
1150
1151 /// Resolve a list of operands and a list of operand types to SSA values,
1152 /// emitting an error and returning failure, or appending the results
1153 /// to the list on success.
1154 ParseResult resolveOperands(ArrayRef<OperandType> operands,
1155 ArrayRef<Type> types, llvm::SMLoc loc,
1156 SmallVectorImpl<Value> &result) {
1157 if (operands.size() != types.size())
1158 return emitError(loc)
1159 << operands.size() << " operands present, but expected "
1160 << types.size();
1161
1162 for (unsigned i = 0, e = operands.size(); i != e; ++i)
1163 if (resolveOperand(operands[i], types[i], result))
1164 return failure();
1165 return success();
1166 }
1167 template <typename Operands>
1168 ParseResult resolveOperands(Operands &&operands, Type type, llvm::SMLoc loc,
1169 SmallVectorImpl<Value> &result) {
1170 return resolveOperands(std::forward<Operands>(operands),
1171 ArrayRef<Type>(type), loc, result);
1172 }
1173 template <typename Operands, typename Types>
1174 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1175 resolveOperands(Operands &&operands, Types &&types, llvm::SMLoc loc,
1176 SmallVectorImpl<Value> &result) {
1177 size_t operandSize = std::distance(operands.begin(), operands.end());
1178 size_t typeSize = std::distance(types.begin(), types.end());
1179 if (operandSize != typeSize)
1180 return emitError(loc)
1181 << operandSize << " operands present, but expected " << typeSize;
1182
1183 for (auto it : llvm::zip(operands, types))
1184 if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
1185 return failure();
1186 return success();
1187 }
1188
1189 /// Parses an affine map attribute where dims and symbols are SSA operands.
1190 /// Operand values must come from single-result sources, and be valid
1191 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1192 virtual ParseResult
1193 parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map,
1194 StringRef attrName, NamedAttrList &attrs,
1195 Delimiter delimiter = Delimiter::Square) = 0;
1196
1197 /// Parses an affine expression where dims and symbols are SSA operands.
1198 /// Operand values must come from single-result sources, and be valid
1199 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1200 virtual ParseResult
1201 parseAffineExprOfSSAIds(SmallVectorImpl<OperandType> &dimOperands,
1202 SmallVectorImpl<OperandType> &symbOperands,
1203 AffineExpr &expr) = 0;
1204
1205 //===--------------------------------------------------------------------===//
1206 // Region Parsing
1207 //===--------------------------------------------------------------------===//
1208
1209 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1210 /// moved to the op regions after the op is created. The first block of the
1211 /// region takes 'arguments' of types 'argTypes'. If 'enableNameShadowing' is
1212 /// set to true, the argument names are allowed to shadow the names of other
1213 /// existing SSA values defined above the region scope. 'enableNameShadowing'
1214 /// can only be set to true for regions attached to operations that are
1215 /// 'IsolatedFromAbove.
1216 virtual ParseResult parseRegion(Region &region,
1217 ArrayRef<OperandType> arguments = {},
1218 ArrayRef<Type> argTypes = {},
1219 bool enableNameShadowing = false) = 0;
1220
1221 /// Parses a region if present.
1222 virtual OptionalParseResult
1223 parseOptionalRegion(Region &region, ArrayRef<OperandType> arguments = {},
1224 ArrayRef<Type> argTypes = {},
1225 bool enableNameShadowing = false) = 0;
1226
1227 /// Parses a region if present. If the region is present, a new region is
1228 /// allocated and placed in `region`. If no region is present or on failure,
1229 /// `region` remains untouched.
1230 virtual OptionalParseResult parseOptionalRegion(
1231 std::unique_ptr<Region> &region, ArrayRef<OperandType> arguments = {},
1232 ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0;
1233
1234 /// Parse a region argument, this argument is resolved when calling
1235 /// 'parseRegion'.
1236 virtual ParseResult parseRegionArgument(OperandType &argument) = 0;
1237
1238 /// Parse zero or more region arguments with a specified surrounding
1239 /// delimiter, and an optional required argument count. Region arguments
1240 /// define new values; so this also checks if values with the same names have
1241 /// not been defined yet.
1242 virtual ParseResult
1243 parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
1244 int requiredOperandCount = -1,
1245 Delimiter delimiter = Delimiter::None) = 0;
1246 virtual ParseResult
1247 parseRegionArgumentList(SmallVectorImpl<OperandType> &result,
1248 Delimiter delimiter) {
1249 return parseRegionArgumentList(result, /*requiredOperandCount=*/-1,
1250 delimiter);
1251 }
1252
1253 /// Parse a region argument if present.
1254 virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0;
1255
1256 //===--------------------------------------------------------------------===//
1257 // Successor Parsing
1258 //===--------------------------------------------------------------------===//
1259
1260 /// Parse a single operation successor.
1261 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1262
1263 /// Parse an optional operation successor.
1264 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1265
1266 /// Parse a single operation successor and its operand list.
1267 virtual ParseResult
1268 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1269
1270 //===--------------------------------------------------------------------===//
// Assignment List Parsing
1272 //===--------------------------------------------------------------------===//
1273
1274 /// Parse a list of assignments of the form
1275 /// (%x1 = %y1, %x2 = %y2, ...)
1276 ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs,
1277 SmallVectorImpl<OperandType> &rhs) {
1278 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1279 if (!result.hasValue())
1280 return emitError(getCurrentLocation(), "expected '('");
1281 return result.getValue();
1282 }
1283
1284 virtual OptionalParseResult
1285 parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs,
1286 SmallVectorImpl<OperandType> &rhs) = 0;
1287
1288 /// Parse a list of assignments of the form
1289 /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...)
1290 ParseResult parseAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
1291 SmallVectorImpl<OperandType> &rhs,
1292 SmallVectorImpl<Type> &types) {
1293 OptionalParseResult result =
1294 parseOptionalAssignmentListWithTypes(lhs, rhs, types);
1295 if (!result.hasValue())
1296 return emitError(getCurrentLocation(), "expected '('");
1297 return result.getValue();
1298 }
1299
1300 virtual OptionalParseResult
1301 parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs,
1302 SmallVectorImpl<OperandType> &rhs,
1303 SmallVectorImpl<Type> &types) = 0;
1304
1305private:
1306 /// Parse either an operand list or a region argument list depending on
1307 /// whether isOperandList is true.
1308 ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result,
1309 bool isOperandList,
1310 int requiredOperandCount,
1311 Delimiter delimiter);
1312};
1313
1314//===--------------------------------------------------------------------===//
1315// Dialect OpAsm interface.
1316//===--------------------------------------------------------------------===//
1317
1318/// A functor used to set the name of the start of a result group of an
1319/// operation. See 'getAsmResultNames' below for more details.
1320using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1321
1322class OpAsmDialectInterface
1323 : public DialectInterface::Base<OpAsmDialectInterface> {
1324public:
1325 /// Holds the result of `getAlias` hook call.
1326 enum class AliasResult {
1327 /// The object (type or attribute) is not supported by the hook
1328 /// and an alias was not provided.
1329 NoAlias,
1330 /// An alias was provided, but it might be overriden by other hook.
1331 OverridableAlias,
1332 /// An alias was provided and it should be used
1333 /// (no other hooks will be checked).
1334 FinalAlias
1335 };
1336
1337 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1338
1339 /// Hooks for getting an alias identifier alias for a given symbol, that is
1340 /// not necessarily a part of this dialect. The identifier is used in place of
1341 /// the symbol when printing textual IR. These aliases must not contain `.` or
1342 /// end with a numeric digit([0-9]+).
1343 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1344 return AliasResult::NoAlias;
1345 }
1346 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1347 return AliasResult::NoAlias;
1348 }
1349
1350 /// Get a special name to use when printing the given operation. See
1351 /// OpAsmInterface.td#getAsmResultNames for usage details and documentation.
1352 virtual void getAsmResultNames(Operation *op,
1353 OpAsmSetValueNameFn setNameFn) const {}
1354};
1355} // namespace mlir
1356
1357//===--------------------------------------------------------------------===//
1358// Operation OpAsm interface.
1359//===--------------------------------------------------------------------===//
1360
1361/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1362#include "mlir/IR/OpAsmInterface.h.inc"
1363
1364#endif