File: | build-llvm/tools/clang/stage2-bins/tools/mlir/test/lib/Dialect/Test/TestTypeDefs.cpp.inc |
Warning: | line 452, column 10 3rd function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- TestTypes.cpp - MLIR Test Dialect Types ------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file contains types defined by the TestDialect for testing various | |||
10 | // features of MLIR. | |||
11 | // | |||
12 | //===----------------------------------------------------------------------===// | |||
13 | ||||
14 | #include "TestTypes.h" | |||
15 | #include "TestDialect.h" | |||
16 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h" | |||
17 | #include "mlir/IR/Builders.h" | |||
18 | #include "mlir/IR/DialectImplementation.h" | |||
19 | #include "mlir/IR/Types.h" | |||
20 | #include "llvm/ADT/Hashing.h" | |||
21 | #include "llvm/ADT/SetVector.h" | |||
22 | #include "llvm/ADT/TypeSwitch.h" | |||
23 | ||||
24 | using namespace mlir; | |||
25 | using namespace test; | |||
26 | ||||
27 | // Custom parser for SignednessSemantics. | |||
28 | static ParseResult | |||
29 | parseSignedness(AsmParser &parser, | |||
30 | TestIntegerType::SignednessSemantics &result) { | |||
31 | StringRef signStr; | |||
32 | auto loc = parser.getCurrentLocation(); | |||
33 | if (parser.parseKeyword(&signStr)) | |||
34 | return failure(); | |||
35 | if (signStr.equals_insensitive("u") || signStr.equals_insensitive("unsigned")) | |||
36 | result = TestIntegerType::SignednessSemantics::Unsigned; | |||
37 | else if (signStr.equals_insensitive("s") || | |||
38 | signStr.equals_insensitive("signed")) | |||
39 | result = TestIntegerType::SignednessSemantics::Signed; | |||
40 | else if (signStr.equals_insensitive("n") || | |||
41 | signStr.equals_insensitive("none")) | |||
42 | result = TestIntegerType::SignednessSemantics::Signless; | |||
43 | else | |||
44 | return parser.emitError(loc, "expected signed, unsigned, or none"); | |||
45 | return success(); | |||
46 | } | |||
47 | ||||
48 | // Custom printer for SignednessSemantics. | |||
49 | static void printSignedness(AsmPrinter &printer, | |||
50 | const TestIntegerType::SignednessSemantics &ss) { | |||
51 | switch (ss) { | |||
52 | case TestIntegerType::SignednessSemantics::Unsigned: | |||
53 | printer << "unsigned"; | |||
54 | break; | |||
55 | case TestIntegerType::SignednessSemantics::Signed: | |||
56 | printer << "signed"; | |||
57 | break; | |||
58 | case TestIntegerType::SignednessSemantics::Signless: | |||
59 | printer << "none"; | |||
60 | break; | |||
61 | } | |||
62 | } | |||
63 | ||||
// These helpers do not need to be in the header file, but they must be
// declared in the `test` namespace. Presumably that is FieldInfo's namespace,
// so argument-dependent lookup can find them when FieldInfo values are
// compared or hashed as part of a storage key (verify against TestTypes.h).
// Declare them here, then define them immediately below. Separating the
// declaration and definition adheres to the LLVM coding standards.
namespace test {
// FieldInfo is used as part of a parameter, so equality comparison is
// compulsory.
static bool operator==(const FieldInfo &a, const FieldInfo &b);
// FieldInfo is used as part of a parameter, so a hash will be computed.
static llvm::hash_code hash_value(const FieldInfo &fi); // NOLINT
} // namespace test
74 | ||||
75 | // FieldInfo is used as part of a parameter, so equality comparison is | |||
76 | // compulsory. | |||
77 | static bool test::operator==(const FieldInfo &a, const FieldInfo &b) { | |||
78 | return a.name == b.name && a.type == b.type; | |||
79 | } | |||
80 | ||||
81 | // FieldInfo is used as part of a parameter, so a hash will be computed. | |||
82 | static llvm::hash_code test::hash_value(const FieldInfo &fi) { // NOLINT | |||
83 | return llvm::hash_combine(fi.name, fi.type); | |||
84 | } | |||
85 | ||||
86 | //===----------------------------------------------------------------------===// | |||
87 | // CompoundAType | |||
88 | //===----------------------------------------------------------------------===// | |||
89 | ||||
90 | Type CompoundAType::parse(AsmParser &parser) { | |||
91 | int widthOfSomething; | |||
92 | Type oneType; | |||
93 | SmallVector<int, 4> arrayOfInts; | |||
94 | if (parser.parseLess() || parser.parseInteger(widthOfSomething) || | |||
95 | parser.parseComma() || parser.parseType(oneType) || parser.parseComma() || | |||
96 | parser.parseLSquare()) | |||
97 | return Type(); | |||
98 | ||||
99 | int i; | |||
100 | while (!*parser.parseOptionalInteger(i)) { | |||
101 | arrayOfInts.push_back(i); | |||
102 | if (parser.parseOptionalComma()) | |||
103 | break; | |||
104 | } | |||
105 | ||||
106 | if (parser.parseRSquare() || parser.parseGreater()) | |||
107 | return Type(); | |||
108 | ||||
109 | return get(parser.getContext(), widthOfSomething, oneType, arrayOfInts); | |||
110 | } | |||
111 | void CompoundAType::print(AsmPrinter &printer) const { | |||
112 | printer << "<" << getWidthOfSomething() << ", " << getOneType() << ", ["; | |||
113 | auto intArray = getArrayOfInts(); | |||
114 | llvm::interleaveComma(intArray, printer); | |||
115 | printer << "]>"; | |||
116 | } | |||
117 | ||||
118 | //===----------------------------------------------------------------------===// | |||
119 | // TestIntegerType | |||
120 | //===----------------------------------------------------------------------===// | |||
121 | ||||
122 | // Example type validity checker. | |||
123 | LogicalResult | |||
124 | TestIntegerType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
125 | unsigned width, | |||
126 | TestIntegerType::SignednessSemantics ss) { | |||
127 | if (width > 8) | |||
128 | return failure(); | |||
129 | return success(); | |||
130 | } | |||
131 | ||||
132 | //===----------------------------------------------------------------------===// | |||
133 | // TestType | |||
134 | //===----------------------------------------------------------------------===// | |||
135 | ||||
136 | void TestType::printTypeC(Location loc) const { | |||
137 | emitRemark(loc) << *this << " - TestC"; | |||
138 | } | |||
139 | ||||
140 | //===----------------------------------------------------------------------===// | |||
141 | // TestTypeWithLayout | |||
142 | //===----------------------------------------------------------------------===// | |||
143 | ||||
144 | Type TestTypeWithLayoutType::parse(AsmParser &parser) { | |||
145 | unsigned val; | |||
146 | if (parser.parseLess() || parser.parseInteger(val) || parser.parseGreater()) | |||
147 | return Type(); | |||
148 | return TestTypeWithLayoutType::get(parser.getContext(), val); | |||
149 | } | |||
150 | ||||
151 | void TestTypeWithLayoutType::print(AsmPrinter &printer) const { | |||
152 | printer << "<" << getKey() << ">"; | |||
153 | } | |||
154 | ||||
155 | unsigned | |||
156 | TestTypeWithLayoutType::getTypeSizeInBits(const DataLayout &dataLayout, | |||
157 | DataLayoutEntryListRef params) const { | |||
158 | return extractKind(params, "size"); | |||
159 | } | |||
160 | ||||
161 | unsigned | |||
162 | TestTypeWithLayoutType::getABIAlignment(const DataLayout &dataLayout, | |||
163 | DataLayoutEntryListRef params) const { | |||
164 | return extractKind(params, "alignment"); | |||
165 | } | |||
166 | ||||
167 | unsigned TestTypeWithLayoutType::getPreferredAlignment( | |||
168 | const DataLayout &dataLayout, DataLayoutEntryListRef params) const { | |||
169 | return extractKind(params, "preferred"); | |||
170 | } | |||
171 | ||||
172 | bool TestTypeWithLayoutType::areCompatible( | |||
173 | DataLayoutEntryListRef oldLayout, DataLayoutEntryListRef newLayout) const { | |||
174 | unsigned old = extractKind(oldLayout, "alignment"); | |||
175 | return old == 1 || extractKind(newLayout, "alignment") <= old; | |||
176 | } | |||
177 | ||||
178 | LogicalResult | |||
179 | TestTypeWithLayoutType::verifyEntries(DataLayoutEntryListRef params, | |||
180 | Location loc) const { | |||
181 | for (DataLayoutEntryInterface entry : params) { | |||
182 | // This is for testing purposes only, so assert well-formedness. | |||
183 | assert(entry.isTypeEntry() && "unexpected identifier entry")(static_cast <bool> (entry.isTypeEntry() && "unexpected identifier entry" ) ? void (0) : __assert_fail ("entry.isTypeEntry() && \"unexpected identifier entry\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 183, __extension__ __PRETTY_FUNCTION__)); | |||
184 | assert(entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() &&(static_cast <bool> (entry.getKey().get<Type>().isa <TestTypeWithLayoutType>() && "wrong type passed in" ) ? void (0) : __assert_fail ("entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() && \"wrong type passed in\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 185, __extension__ __PRETTY_FUNCTION__)) | |||
185 | "wrong type passed in")(static_cast <bool> (entry.getKey().get<Type>().isa <TestTypeWithLayoutType>() && "wrong type passed in" ) ? void (0) : __assert_fail ("entry.getKey().get<Type>().isa<TestTypeWithLayoutType>() && \"wrong type passed in\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 185, __extension__ __PRETTY_FUNCTION__)); | |||
186 | auto array = entry.getValue().dyn_cast<ArrayAttr>(); | |||
187 | assert(array && array.getValue().size() == 2 &&(static_cast <bool> (array && array.getValue(). size() == 2 && "expected array of two elements") ? void (0) : __assert_fail ("array && array.getValue().size() == 2 && \"expected array of two elements\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 188, __extension__ __PRETTY_FUNCTION__)) | |||
188 | "expected array of two elements")(static_cast <bool> (array && array.getValue(). size() == 2 && "expected array of two elements") ? void (0) : __assert_fail ("array && array.getValue().size() == 2 && \"expected array of two elements\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 188, __extension__ __PRETTY_FUNCTION__)); | |||
189 | auto kind = array.getValue().front().dyn_cast<StringAttr>(); | |||
190 | (void)kind; | |||
191 | assert(kind &&(static_cast <bool> (kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind") ? void (0) : __assert_fail ("kind && (kind.getValue() == \"size\" || kind.getValue() == \"alignment\" || kind.getValue() == \"preferred\") && \"unexpected kind\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 194, __extension__ __PRETTY_FUNCTION__)) | |||
192 | (kind.getValue() == "size" || kind.getValue() == "alignment" ||(static_cast <bool> (kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind") ? void (0) : __assert_fail ("kind && (kind.getValue() == \"size\" || kind.getValue() == \"alignment\" || kind.getValue() == \"preferred\") && \"unexpected kind\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 194, __extension__ __PRETTY_FUNCTION__)) | |||
193 | kind.getValue() == "preferred") &&(static_cast <bool> (kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind") ? void (0) : __assert_fail ("kind && (kind.getValue() == \"size\" || kind.getValue() == \"alignment\" || kind.getValue() == \"preferred\") && \"unexpected kind\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 194, __extension__ __PRETTY_FUNCTION__)) | |||
194 | "unexpected kind")(static_cast <bool> (kind && (kind.getValue() == "size" || kind.getValue() == "alignment" || kind.getValue() == "preferred") && "unexpected kind") ? void (0) : __assert_fail ("kind && (kind.getValue() == \"size\" || kind.getValue() == \"alignment\" || kind.getValue() == \"preferred\") && \"unexpected kind\"" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 194, __extension__ __PRETTY_FUNCTION__)); | |||
195 | assert(array.getValue().back().isa<IntegerAttr>())(static_cast <bool> (array.getValue().back().isa<IntegerAttr >()) ? void (0) : __assert_fail ("array.getValue().back().isa<IntegerAttr>()" , "mlir/test/lib/Dialect/Test/TestTypes.cpp", 195, __extension__ __PRETTY_FUNCTION__)); | |||
196 | } | |||
197 | return success(); | |||
198 | } | |||
199 | ||||
200 | unsigned TestTypeWithLayoutType::extractKind(DataLayoutEntryListRef params, | |||
201 | StringRef expectedKind) const { | |||
202 | for (DataLayoutEntryInterface entry : params) { | |||
203 | ArrayRef<Attribute> pair = entry.getValue().cast<ArrayAttr>().getValue(); | |||
204 | StringRef kind = pair.front().cast<StringAttr>().getValue(); | |||
205 | if (kind == expectedKind) | |||
206 | return pair.back().cast<IntegerAttr>().getValue().getZExtValue(); | |||
207 | } | |||
208 | return 1; | |||
209 | } | |||
210 | ||||
211 | //===----------------------------------------------------------------------===// | |||
212 | // Tablegen Generated Definitions | |||
213 | //===----------------------------------------------------------------------===// | |||
214 | ||||
215 | #define GET_TYPEDEF_CLASSES | |||
216 | #include "TestTypeDefs.cpp.inc" | |||
217 | ||||
218 | //===----------------------------------------------------------------------===// | |||
219 | // TestDialect | |||
220 | //===----------------------------------------------------------------------===// | |||
221 | ||||
222 | namespace { | |||
223 | ||||
224 | struct PtrElementModel | |||
225 | : public LLVM::PointerElementTypeInterface::ExternalModel<PtrElementModel, | |||
226 | SimpleAType> {}; | |||
227 | } // namespace | |||
228 | ||||
/// Registers the test dialect's types: the hand-written TestRecursiveType
/// followed by every generated type named in TestTypeDefs.cpp.inc. It then
/// attaches the PtrElementModel external interface model to SimpleAType.
void TestDialect::registerTypes() {
  // The GET_TYPEDEF_LIST section of the .inc expands to a comma-separated
  // list of type classes, spliced directly into this template argument list.
  addTypes<TestRecursiveType,
#define GET_TYPEDEF_LIST
#include "TestTypeDefs.cpp.inc"
           >();
  SimpleAType::attachInterface<PtrElementModel>(*getContext());
}
236 | ||||
237 | static Type parseTestType(AsmParser &parser, SetVector<Type> &stack) { | |||
238 | StringRef typeTag; | |||
239 | if (failed(parser.parseKeyword(&typeTag))) | |||
240 | return Type(); | |||
241 | ||||
242 | { | |||
243 | Type genType; | |||
244 | auto parseResult = generatedTypeParser(parser, typeTag, genType); | |||
245 | if (parseResult.hasValue()) | |||
246 | return genType; | |||
247 | } | |||
248 | ||||
249 | if (typeTag != "test_rec") { | |||
250 | parser.emitError(parser.getNameLoc()) << "unknown type!"; | |||
251 | return Type(); | |||
252 | } | |||
253 | ||||
254 | StringRef name; | |||
255 | if (parser.parseLess() || parser.parseKeyword(&name)) | |||
256 | return Type(); | |||
257 | auto rec = TestRecursiveType::get(parser.getContext(), name); | |||
258 | ||||
259 | // If this type already has been parsed above in the stack, expect just the | |||
260 | // name. | |||
261 | if (stack.contains(rec)) { | |||
262 | if (failed(parser.parseGreater())) | |||
263 | return Type(); | |||
264 | return rec; | |||
265 | } | |||
266 | ||||
267 | // Otherwise, parse the body and update the type. | |||
268 | if (failed(parser.parseComma())) | |||
269 | return Type(); | |||
270 | stack.insert(rec); | |||
271 | Type subtype = parseTestType(parser, stack); | |||
272 | stack.pop_back(); | |||
273 | if (!subtype || failed(parser.parseGreater()) || failed(rec.setBody(subtype))) | |||
274 | return Type(); | |||
275 | ||||
276 | return rec; | |||
277 | } | |||
278 | ||||
279 | Type TestDialect::parseType(DialectAsmParser &parser) const { | |||
280 | SetVector<Type> stack; | |||
281 | return parseTestType(parser, stack); | |||
| ||||
282 | } | |||
283 | ||||
284 | static void printTestType(Type type, AsmPrinter &printer, | |||
285 | SetVector<Type> &stack) { | |||
286 | if (succeeded(generatedTypePrinter(type, printer))) | |||
287 | return; | |||
288 | ||||
289 | auto rec = type.cast<TestRecursiveType>(); | |||
290 | printer << "test_rec<" << rec.getName(); | |||
291 | if (!stack.contains(rec)) { | |||
292 | printer << ", "; | |||
293 | stack.insert(rec); | |||
294 | printTestType(rec.getBody(), printer, stack); | |||
295 | stack.pop_back(); | |||
296 | } | |||
297 | printer << ">"; | |||
298 | } | |||
299 | ||||
300 | void TestDialect::printType(Type type, DialectAsmPrinter &printer) const { | |||
301 | SetVector<Type> stack; | |||
302 | printTestType(type, printer, stack); | |||
303 | } |
1 | /*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ | |||
2 | |* *| | |||
3 | |* TypeDef Definitions *| | |||
4 | |* *| | |||
5 | |* Automatically generated file, do not edit! *| | |||
6 | |* *| | |||
7 | \*===----------------------------------------------------------------------===*/ | |||
8 | ||||
9 | #ifdef GET_TYPEDEF_LIST | |||
10 | #undef GET_TYPEDEF_LIST | |||
11 | ||||
12 | ::test::CompoundNestedInnerType, | |||
13 | ::test::CompoundNestedOuterType, | |||
14 | ::test::CompoundNestedOuterQualType, | |||
15 | ::test::CompoundAType, | |||
16 | ::test::TestIntegerType, | |||
17 | ::test::SimpleAType, | |||
18 | ::test::StructType, | |||
19 | ::test::TestMemRefElementTypeType, | |||
20 | ::test::TestType, | |||
21 | ::test::TestTypeNoParserType, | |||
22 | ::test::TestStructTypeCaptureAllType, | |||
23 | ::test::TestTypeWithFormatType, | |||
24 | ::test::TestTypeWithLayoutType, | |||
25 | ::test::TestTypeWithTraitType | |||
26 | ||||
27 | #endif // GET_TYPEDEF_LIST | |||
28 | ||||
29 | #ifdef GET_TYPEDEF_CLASSES | |||
30 | #undef GET_TYPEDEF_CLASSES | |||
31 | ||||
// Generated by mlir-tblgen: edit the ODS type definitions, not this code.
//
/// Dispatches on `mnemonic` to the parser of the matching generated type.
/// Returns an empty OptionalParseResult when the mnemonic belongs to none of
/// the generated types, so the caller can try hand-written types. Otherwise
/// it stores the parsed type in `value` and returns success iff that type is
/// non-null. Types without parameters are built directly with `get`.
static ::mlir::OptionalParseResult generatedTypeParser(::mlir::AsmParser &parser, ::llvm::StringRef mnemonic, ::mlir::Type &value) {
  if (mnemonic == ::test::CompoundNestedInnerType::getMnemonic()) {
    value = ::test::CompoundNestedInnerType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::CompoundNestedOuterType::getMnemonic()) {
    value = ::test::CompoundNestedOuterType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::CompoundNestedOuterQualType::getMnemonic()) {
    value = ::test::CompoundNestedOuterQualType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::CompoundAType::getMnemonic()) {
    value = ::test::CompoundAType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestIntegerType::getMnemonic()) {
    value = ::test::TestIntegerType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::SimpleAType::getMnemonic()) {
    value = ::test::SimpleAType::get(parser.getContext());
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::StructType::getMnemonic()) {
    value = ::test::StructType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestMemRefElementTypeType::getMnemonic()) {
    value = ::test::TestMemRefElementTypeType::get(parser.getContext());
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestType::getMnemonic()) {
    value = ::test::TestType::get(parser.getContext());
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestTypeNoParserType::getMnemonic()) {
    value = ::test::TestTypeNoParserType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestStructTypeCaptureAllType::getMnemonic()) {
    value = ::test::TestStructTypeCaptureAllType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestTypeWithFormatType::getMnemonic()) {
    value = ::test::TestTypeWithFormatType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestTypeWithLayoutType::getMnemonic()) {
    value = ::test::TestTypeWithLayoutType::parse(parser);
    return ::mlir::success(!!value);
  }
  if (mnemonic == ::test::TestTypeWithTraitType::getMnemonic()) {
    value = ::test::TestTypeWithTraitType::get(parser.getContext());
    return ::mlir::success(!!value);
  }
  // Not a generated type: no parse was attempted.
  return {};
}
91 | ||||
// Generated by mlir-tblgen: edit the ODS type definitions, not this code.
//
/// Prints `def` as its mnemonic, followed by the type's own print() output
/// for types that have one. Returns failure when `def` is not one of the
/// generated types, so the caller can fall back to hand-written printing.
static ::mlir::LogicalResult generatedTypePrinter(::mlir::Type def, ::mlir::AsmPrinter &printer) {
  return ::llvm::TypeSwitch<::mlir::Type, ::mlir::LogicalResult>(def) .Case<::test::CompoundNestedInnerType>([&](auto t) {
      printer << ::test::CompoundNestedInnerType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::CompoundNestedOuterType>([&](auto t) {
      printer << ::test::CompoundNestedOuterType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::CompoundNestedOuterQualType>([&](auto t) {
      printer << ::test::CompoundNestedOuterQualType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::CompoundAType>([&](auto t) {
      printer << ::test::CompoundAType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::TestIntegerType>([&](auto t) {
      printer << ::test::TestIntegerType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    // Parameterless type: only the mnemonic is printed.
    .Case<::test::SimpleAType>([&](auto t) {
      printer << ::test::SimpleAType::getMnemonic();
      return ::mlir::success();
    })
    .Case<::test::StructType>([&](auto t) {
      printer << ::test::StructType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    // Parameterless type: only the mnemonic is printed.
    .Case<::test::TestMemRefElementTypeType>([&](auto t) {
      printer << ::test::TestMemRefElementTypeType::getMnemonic();
      return ::mlir::success();
    })
    // Parameterless type: only the mnemonic is printed.
    .Case<::test::TestType>([&](auto t) {
      printer << ::test::TestType::getMnemonic();
      return ::mlir::success();
    })
    .Case<::test::TestTypeNoParserType>([&](auto t) {
      printer << ::test::TestTypeNoParserType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::TestStructTypeCaptureAllType>([&](auto t) {
      printer << ::test::TestStructTypeCaptureAllType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::TestTypeWithFormatType>([&](auto t) {
      printer << ::test::TestTypeWithFormatType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    .Case<::test::TestTypeWithLayoutType>([&](auto t) {
      printer << ::test::TestTypeWithLayoutType::getMnemonic();
      t.print(printer);
      return ::mlir::success();
    })
    // Parameterless type: only the mnemonic is printed.
    .Case<::test::TestTypeWithTraitType>([&](auto t) {
      printer << ::test::TestTypeWithTraitType::getMnemonic();
      return ::mlir::success();
    })
    // Not a generated type.
    .Default([](auto) { return ::mlir::failure(); });
}
161 | ||||
// CompoundNestedInnerType — generated by mlir-tblgen; edit the ODS
// definition, not this code.
namespace test {
namespace detail {
/// Uniqued storage for CompoundNestedInnerType. The key is the tuple
/// (some_int, cmpdA); equality and hashing both use every key element.
struct CompoundNestedInnerTypeStorage : public ::mlir::TypeStorage {
  using KeyTy = std::tuple<int, ::test::CompoundAType>;
  CompoundNestedInnerTypeStorage(int some_int, ::test::CompoundAType cmpdA) : some_int(some_int), cmpdA(cmpdA) {}

  /// Element-wise comparison of this storage against a lookup key.
  bool operator==(const KeyTy &tblgenKey) const {
    return (some_int == std::get<0>(tblgenKey)) && (cmpdA == std::get<1>(tblgenKey));
  }

  /// Hashes a lookup key by combining all of its elements.
  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
    return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey));
  }

  /// Placement-constructs a new storage object in `allocator` from a key.
  static CompoundNestedInnerTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) {
    auto some_int = std::get<0>(tblgenKey);
    auto cmpdA = std::get<1>(tblgenKey);
    return new (allocator.allocate<CompoundNestedInnerTypeStorage>()) CompoundNestedInnerTypeStorage(some_int, cmpdA);
  }

  int some_int;
  ::test::CompoundAType cmpdA;
};
} // namespace detail
/// Returns the uniqued CompoundNestedInnerType for (some_int, cmpdA).
CompoundNestedInnerType CompoundNestedInnerType::get(::mlir::MLIRContext *context, int some_int, ::test::CompoundAType cmpdA) {
  return Base::get(context, some_int, cmpdA);
}

/// Parses `<` some_int cmpdA `>`, using FieldParser for each parameter.
/// Returns a null type on failure. An explicit error is emitted when a
/// parameter fails to parse.
::mlir::Type CompoundNestedInnerType::parse(::mlir::AsmParser &parser) {
  ::mlir::FailureOr<int> _result_some_int;
  ::mlir::FailureOr<::test::CompoundAType> _result_cmpdA;
  // `loc` is not otherwise used in this function.
  ::llvm::SMLoc loc = parser.getCurrentLocation();
  (void) loc;
  // Parse literal '<'
  if (parser.parseLess())
    return {};

  // Parse variable 'some_int'
  _result_some_int = ::mlir::FieldParser<int>::parse(parser);
  if (failed(_result_some_int)) {
    parser.emitError(parser.getCurrentLocation(), "failed to parse CompoundNestedInnerType parameter 'some_int' which is to be a `int`");
    return {};
  }

  // Parse variable 'cmpdA'
  _result_cmpdA = ::mlir::FieldParser<::test::CompoundAType>::parse(parser);
  if (failed(_result_cmpdA)) {
    parser.emitError(parser.getCurrentLocation(), "failed to parse CompoundNestedInnerType parameter 'cmpdA' which is to be a `::test::CompoundAType`");
    return {};
  }
  // Parse literal '>'
  if (parser.parseGreater())
    return {};
  return CompoundNestedInnerType::get(parser.getContext(),
      _result_some_int.getValue(),
      _result_cmpdA.getValue());
}

/// Prints `<` some_int ` ` cmpdA `>`, with both parameters in stripped form.
void CompoundNestedInnerType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";
  printer.printStrippedAttrOrType(getSome_int());
  printer << ' ';
  printer.printStrippedAttrOrType(getCmpdA());
  printer << ">";
}

/// Returns the `some_int` parameter.
int CompoundNestedInnerType::getSome_int() const {
  return getImpl()->some_int;
}

/// Returns the `cmpdA` parameter.
::test::CompoundAType CompoundNestedInnerType::getCmpdA() const {
  return getImpl()->cmpdA;
}

} // namespace test
// Defines the exported TypeID for CompoundNestedInnerType.
DEFINE_EXPLICIT_TYPE_ID(::test::CompoundNestedInnerType)
// CompoundNestedOuterType — generated by mlir-tblgen; edit the ODS definition,
// not this code.
namespace test {
namespace detail {
/// Uniqued storage for CompoundNestedOuterType; the key is the single
/// parameter `inner`.
struct CompoundNestedOuterTypeStorage : public ::mlir::TypeStorage {
  using KeyTy = std::tuple<::test::CompoundNestedInnerType>;
  CompoundNestedOuterTypeStorage(::test::CompoundNestedInnerType inner) : inner(inner) {}

  /// Compares this storage against a lookup key.
  bool operator==(const KeyTy &tblgenKey) const {
    return (inner == std::get<0>(tblgenKey));
  }

  /// Hashes a lookup key.
  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
    return ::llvm::hash_combine(std::get<0>(tblgenKey));
  }

  /// Placement-constructs a new storage object in `allocator` from a key.
  static CompoundNestedOuterTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) {
    auto inner = std::get<0>(tblgenKey);
    return new (allocator.allocate<CompoundNestedOuterTypeStorage>()) CompoundNestedOuterTypeStorage(inner);
  }

  ::test::CompoundNestedInnerType inner;
};
} // namespace detail
/// Returns the uniqued CompoundNestedOuterType wrapping `inner`.
CompoundNestedOuterType CompoundNestedOuterType::get(::mlir::MLIRContext *context, ::test::CompoundNestedInnerType inner) {
  return Base::get(context, inner);
}

/// Parses `<` `i` inner `>`, using FieldParser for `inner`. Returns a null
/// type on failure.
::mlir::Type CompoundNestedOuterType::parse(::mlir::AsmParser &parser) {
  ::mlir::FailureOr<::test::CompoundNestedInnerType> _result_inner;
  // `loc` is not otherwise used in this function.
  ::llvm::SMLoc loc = parser.getCurrentLocation();
  (void) loc;
  // Parse literal '<'
  if (parser.parseLess())
    return {};
  // Parse literal 'i'
  if (parser.parseKeyword("i"))
    return {};

  // Parse variable 'inner'
  _result_inner = ::mlir::FieldParser<::test::CompoundNestedInnerType>::parse(parser);
  if (failed(_result_inner)) {
    parser.emitError(parser.getCurrentLocation(), "failed to parse CompoundNestedOuterType parameter 'inner' which is to be a `::test::CompoundNestedInnerType`");
    return {};
  }
  // Parse literal '>'
  if (parser.parseGreater())
    return {};
  return CompoundNestedOuterType::get(parser.getContext(),
      _result_inner.getValue());
}

/// Prints `<i ` inner `>`, with `inner` in stripped form.
void CompoundNestedOuterType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << "i";
  printer << ' ';
  printer.printStrippedAttrOrType(getInner());
  printer << ">";
}

/// Returns the `inner` parameter.
::test::CompoundNestedInnerType CompoundNestedOuterType::getInner() const {
  return getImpl()->inner;
}

} // namespace test
// Defines the exported TypeID for CompoundNestedOuterType.
DEFINE_EXPLICIT_TYPE_ID(::test::CompoundNestedOuterType)
// CompoundNestedOuterQualType — generated by mlir-tblgen; edit the ODS
// definition, not this code.
namespace test {
namespace detail {
/// Uniqued storage for CompoundNestedOuterQualType; the key is the single
/// parameter `inner`.
struct CompoundNestedOuterQualTypeStorage : public ::mlir::TypeStorage {
  using KeyTy = std::tuple<::test::CompoundNestedInnerType>;
  CompoundNestedOuterQualTypeStorage(::test::CompoundNestedInnerType inner) : inner(inner) {}

  /// Compares this storage against a lookup key.
  bool operator==(const KeyTy &tblgenKey) const {
    return (inner == std::get<0>(tblgenKey));
  }

  /// Hashes a lookup key.
  static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) {
    return ::llvm::hash_combine(std::get<0>(tblgenKey));
  }

  /// Placement-constructs a new storage object in `allocator` from a key.
  static CompoundNestedOuterQualTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) {
    auto inner = std::get<0>(tblgenKey);
    return new (allocator.allocate<CompoundNestedOuterQualTypeStorage>()) CompoundNestedOuterQualTypeStorage(inner);
  }

  ::test::CompoundNestedInnerType inner;
};
} // namespace detail
/// Returns the uniqued CompoundNestedOuterQualType wrapping `inner`.
CompoundNestedOuterQualType CompoundNestedOuterQualType::get(::mlir::MLIRContext *context, ::test::CompoundNestedInnerType inner) {
  return Base::get(context, inner);
}

/// Parses `<` `i` inner `>`, using FieldParser for `inner`. Returns a null
/// type on failure.
::mlir::Type CompoundNestedOuterQualType::parse(::mlir::AsmParser &parser) {
  ::mlir::FailureOr<::test::CompoundNestedInnerType> _result_inner;
  // `loc` is not otherwise used in this function.
  ::llvm::SMLoc loc = parser.getCurrentLocation();
  (void) loc;
  // Parse literal '<'
  if (parser.parseLess())
    return {};
  // Parse literal 'i'
  if (parser.parseKeyword("i"))
    return {};

  // Parse variable 'inner'
  _result_inner = ::mlir::FieldParser<::test::CompoundNestedInnerType>::parse(parser);
  if (failed(_result_inner)) {
    // NOTE(review): this message says "CompoundNestedOuterTypeQual", not the
    // C++ class name CompoundNestedOuterQualType. It presumably comes from the
    // ODS def name; confirm in TestTypeDefs.td before relying on it in tests.
    parser.emitError(parser.getCurrentLocation(), "failed to parse CompoundNestedOuterTypeQual parameter 'inner' which is to be a `::test::CompoundNestedInnerType`");
    return {};
  }
  // Parse literal '>'
  if (parser.parseGreater())
    return {};
  return CompoundNestedOuterQualType::get(parser.getContext(),
      _result_inner.getValue());
}

/// Prints `<i ` inner `>`. Unlike CompoundNestedOuterType::print, `inner` is
/// printed with `printer <<` rather than printStrippedAttrOrType.
void CompoundNestedOuterQualType::print(::mlir::AsmPrinter &printer) const {
  printer << "<";
  printer << "i";
  printer << ' ';
  printer << getInner();
  printer << ">";
}

/// Returns the `inner` parameter.
::test::CompoundNestedInnerType CompoundNestedOuterQualType::getInner() const {
  return getImpl()->inner;
}

} // namespace test
// Defines the exported TypeID for CompoundNestedOuterQualType.
DEFINE_EXPLICIT_TYPE_ID(::test::CompoundNestedOuterQualType)
366 | namespace test { | |||
367 | namespace detail { | |||
368 | struct CompoundATypeStorage : public ::mlir::TypeStorage { | |||
369 | using KeyTy = std::tuple<int, ::mlir::Type, ::llvm::ArrayRef<int>>; | |||
370 | CompoundATypeStorage(int widthOfSomething, ::mlir::Type oneType, ::llvm::ArrayRef<int> arrayOfInts) : widthOfSomething(widthOfSomething), oneType(oneType), arrayOfInts(arrayOfInts) {} | |||
371 | ||||
372 | bool operator==(const KeyTy &tblgenKey) const { | |||
373 | return (widthOfSomething == std::get<0>(tblgenKey)) && (oneType == std::get<1>(tblgenKey)) && (arrayOfInts == std::get<2>(tblgenKey)); | |||
374 | } | |||
375 | ||||
376 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
377 | return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), std::get<2>(tblgenKey)); | |||
378 | } | |||
379 | ||||
380 | static CompoundATypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
381 | auto widthOfSomething = std::get<0>(tblgenKey); | |||
382 | auto oneType = std::get<1>(tblgenKey); | |||
383 | auto arrayOfInts = std::get<2>(tblgenKey); | |||
384 | arrayOfInts = allocator.copyInto(arrayOfInts); | |||
385 | return new (allocator.allocate<CompoundATypeStorage>()) CompoundATypeStorage(widthOfSomething, oneType, arrayOfInts); | |||
386 | } | |||
387 | ||||
388 | int widthOfSomething; | |||
389 | ::mlir::Type oneType; | |||
390 | ::llvm::ArrayRef<int> arrayOfInts; | |||
391 | }; | |||
392 | } // namespace detail | |||
393 | CompoundAType CompoundAType::get(::mlir::MLIRContext *context, int widthOfSomething, ::mlir::Type oneType, ::llvm::ArrayRef<int> arrayOfInts) { | |||
394 | return Base::get(context, widthOfSomething, oneType, arrayOfInts); | |||
395 | } | |||
396 | ||||
397 | int CompoundAType::getWidthOfSomething() const { | |||
398 | return getImpl()->widthOfSomething; | |||
399 | } | |||
400 | ||||
401 | ::mlir::Type CompoundAType::getOneType() const { | |||
402 | return getImpl()->oneType; | |||
403 | } | |||
404 | ||||
405 | ::llvm::ArrayRef<int> CompoundAType::getArrayOfInts() const { | |||
406 | return getImpl()->arrayOfInts; | |||
407 | } | |||
408 | ||||
409 | } // namespace test | |||
410 | DEFINE_EXPLICIT_TYPE_ID(::test::CompoundAType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::CompoundAType>() { static TypeID::Storage instance; return TypeID(&instance); } } } | |||
411 | namespace test { | |||
412 | namespace detail { | |||
413 | struct TestIntegerTypeStorage : public ::mlir::TypeStorage { | |||
414 | using KeyTy = std::tuple<unsigned, ::test::TestIntegerType::SignednessSemantics>; | |||
415 | TestIntegerTypeStorage(unsigned width, ::test::TestIntegerType::SignednessSemantics signedness) : width(width), signedness(signedness) {} | |||
416 | ||||
417 | bool operator==(const KeyTy &tblgenKey) const { | |||
418 | return (width == std::get<0>(tblgenKey)) && (signedness == std::get<1>(tblgenKey)); | |||
419 | } | |||
420 | ||||
421 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
422 | return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey)); | |||
423 | } | |||
424 | ||||
425 | static TestIntegerTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
426 | auto width = std::get<0>(tblgenKey); | |||
427 | auto signedness = std::get<1>(tblgenKey); | |||
428 | return new (allocator.allocate<TestIntegerTypeStorage>()) TestIntegerTypeStorage(width, signedness); | |||
429 | } | |||
430 | ||||
431 | unsigned width; | |||
432 | ::test::TestIntegerType::SignednessSemantics signedness; | |||
433 | }; | |||
434 | } // namespace detail | |||
435 | TestIntegerType TestIntegerType::get(::mlir::MLIRContext *context, unsigned width, SignednessSemantics signedness) { | |||
436 | return Base::get(context, width, signedness); | |||
437 | } | |||
438 | ||||
439 | TestIntegerType TestIntegerType::getChecked(::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError, ::mlir::MLIRContext *context, unsigned width, SignednessSemantics signedness) { | |||
440 | return Base::getChecked(emitError, context, width, signedness); | |||
441 | } | |||
442 | ||||
443 | ::mlir::Type TestIntegerType::parse(::mlir::AsmParser &parser) { | |||
444 | if (parser.parseLess()) return Type(); | |||
445 | SignednessSemantics signedness; | |||
446 | if (parseSignedness(parser, signedness)) return mlir::Type(); | |||
447 | if (parser.parseComma()) return Type(); | |||
448 | int width; | |||
449 | if (parser.parseInteger(width)) return Type(); | |||
450 | if (parser.parseGreater()) return Type(); | |||
451 | ::mlir::Location loc = parser.getEncodedSourceLoc(parser.getNameLoc()); | |||
452 | return getChecked(loc, loc.getContext(), width, signedness); | |||
| ||||
453 | } | |||
454 | ||||
455 | void TestIntegerType::print(::mlir::AsmPrinter &printer) const { | |||
456 | printer << "<"; | |||
457 | printSignedness(printer, getImpl()->signedness); | |||
458 | printer << ", " << getImpl()->width << ">"; | |||
459 | } | |||
460 | ||||
461 | unsigned TestIntegerType::getWidth() const { | |||
462 | return getImpl()->width; | |||
463 | } | |||
464 | ||||
465 | ::test::TestIntegerType::SignednessSemantics TestIntegerType::getSignedness() const { | |||
466 | return getImpl()->signedness; | |||
467 | } | |||
468 | ||||
469 | } // namespace test | |||
470 | DEFINE_EXPLICIT_TYPE_ID(::test::TestIntegerType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::TestIntegerType>() { static TypeID::Storage instance; return TypeID(&instance); } } } | |||
// SimpleAType has no generated storage or accessors here, so this namespace
// section is empty; only its explicit TypeID definition follows.
namespace test {
} // namespace test
DEFINE_EXPLICIT_TYPE_ID(::test::SimpleAType)
474 | namespace test { | |||
475 | namespace detail { | |||
476 | struct StructTypeStorage : public ::mlir::TypeStorage { | |||
477 | using KeyTy = std::tuple<::llvm::ArrayRef<::test::FieldInfo>>; | |||
478 | StructTypeStorage(::llvm::ArrayRef<::test::FieldInfo> fields) : fields(fields) {} | |||
479 | ||||
480 | bool operator==(const KeyTy &tblgenKey) const { | |||
481 | return (fields == std::get<0>(tblgenKey)); | |||
482 | } | |||
483 | ||||
484 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
485 | return ::llvm::hash_combine(std::get<0>(tblgenKey)); | |||
486 | } | |||
487 | ||||
488 | static StructTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
489 | auto fields = std::get<0>(tblgenKey); | |||
490 | ||||
491 | llvm::SmallVector<::test::FieldInfo, 4> tmpFields; | |||
492 | for (size_t i = 0, e = fields.size(); i < e; ++i) | |||
493 | tmpFields.push_back(fields[i].allocateInto(allocator)); | |||
494 | fields = allocator.copyInto(ArrayRef<::test::FieldInfo>(tmpFields)); | |||
495 | ||||
496 | return new (allocator.allocate<StructTypeStorage>()) StructTypeStorage(fields); | |||
497 | } | |||
498 | ||||
499 | ::llvm::ArrayRef<::test::FieldInfo> fields; | |||
500 | }; | |||
501 | } // namespace detail | |||
502 | StructType StructType::get(::mlir::MLIRContext *context, ::llvm::ArrayRef<::test::FieldInfo> fields) { | |||
503 | return Base::get(context, fields); | |||
504 | } | |||
505 | ||||
506 | ::mlir::Type StructType::parse(::mlir::AsmParser &parser) { | |||
507 | llvm::SmallVector<FieldInfo, 4> parameters; | |||
508 | if (parser.parseLess()) return Type(); | |||
509 | while (mlir::succeeded(parser.parseOptionalLBrace())) { | |||
510 | llvm::StringRef name; | |||
511 | if (parser.parseKeyword(&name)) return Type(); | |||
512 | if (parser.parseComma()) return Type(); | |||
513 | Type type; | |||
514 | if (parser.parseType(type)) return Type(); | |||
515 | if (parser.parseRBrace()) return Type(); | |||
516 | parameters.push_back(FieldInfo {name, type}); | |||
517 | if (parser.parseOptionalComma()) break; | |||
518 | } | |||
519 | if (parser.parseGreater()) return Type(); | |||
520 | return get(parser.getContext(), parameters); | |||
521 | } | |||
522 | ||||
523 | void StructType::print(::mlir::AsmPrinter &printer) const { | |||
524 | printer << "<"; | |||
525 | for (size_t i=0, e = getImpl()->fields.size(); i < e; i++) { | |||
526 | const auto& field = getImpl()->fields[i]; | |||
527 | printer << "{" << field.name << "," << field.type << "}"; | |||
528 | if (i < getImpl()->fields.size() - 1) | |||
529 | printer << ","; | |||
530 | } | |||
531 | printer << ">"; | |||
532 | } | |||
533 | ||||
534 | ::llvm::ArrayRef<::test::FieldInfo> StructType::getFields() const { | |||
535 | return getImpl()->fields; | |||
536 | } | |||
537 | ||||
538 | } // namespace test | |||
539 | DEFINE_EXPLICIT_TYPE_ID(::test::StructType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::StructType>() { static TypeID::Storage instance; return TypeID (&instance); } } } | |||
// TestMemRefElementTypeType has no generated storage or accessors here; only
// its explicit TypeID definition is emitted.
namespace test {
} // namespace test
DEFINE_EXPLICIT_TYPE_ID(::test::TestMemRefElementTypeType)
// Likewise for TestType: empty section followed by its explicit TypeID.
namespace test {
} // namespace test
DEFINE_EXPLICIT_TYPE_ID(::test::TestType)
546 | namespace test { | |||
547 | namespace detail { | |||
548 | struct TestTypeNoParserTypeStorage : public ::mlir::TypeStorage { | |||
549 | using KeyTy = std::tuple<uint32_t, ::llvm::ArrayRef<int64_t>, ::llvm::StringRef, ::test::CustomParam>; | |||
550 | TestTypeNoParserTypeStorage(uint32_t one, ::llvm::ArrayRef<int64_t> two, ::llvm::StringRef three, ::test::CustomParam four) : one(one), two(two), three(three), four(four) {} | |||
551 | ||||
552 | bool operator==(const KeyTy &tblgenKey) const { | |||
553 | return (one == std::get<0>(tblgenKey)) && (two == std::get<1>(tblgenKey)) && (three == std::get<2>(tblgenKey)) && (four == std::get<3>(tblgenKey)); | |||
554 | } | |||
555 | ||||
556 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
557 | return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), std::get<2>(tblgenKey), std::get<3>(tblgenKey)); | |||
558 | } | |||
559 | ||||
560 | static TestTypeNoParserTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
561 | auto one = std::get<0>(tblgenKey); | |||
562 | auto two = std::get<1>(tblgenKey); | |||
563 | auto three = std::get<2>(tblgenKey); | |||
564 | auto four = std::get<3>(tblgenKey); | |||
565 | two = allocator.copyInto(two); | |||
566 | three = allocator.copyInto(three); | |||
567 | return new (allocator.allocate<TestTypeNoParserTypeStorage>()) TestTypeNoParserTypeStorage(one, two, three, four); | |||
568 | } | |||
569 | ||||
570 | uint32_t one; | |||
571 | ::llvm::ArrayRef<int64_t> two; | |||
572 | ::llvm::StringRef three; | |||
573 | ::test::CustomParam four; | |||
574 | }; | |||
575 | } // namespace detail | |||
576 | TestTypeNoParserType TestTypeNoParserType::get(::mlir::MLIRContext *context, uint32_t one, ::llvm::ArrayRef<int64_t> two, ::llvm::StringRef three, ::test::CustomParam four) { | |||
577 | return Base::get(context, one, two, three, four); | |||
578 | } | |||
579 | ||||
580 | ::mlir::Type TestTypeNoParserType::parse(::mlir::AsmParser &parser) { | |||
581 | ::mlir::FailureOr<uint32_t> _result_one; | |||
582 | ::mlir::FailureOr<::llvm::SmallVector<int64_t>> _result_two; | |||
583 | ::mlir::FailureOr<std::string> _result_three; | |||
584 | ::mlir::FailureOr<::test::CustomParam> _result_four; | |||
585 | ::llvm::SMLoc loc = parser.getCurrentLocation(); | |||
586 | (void) loc; | |||
587 | // Parse literal '<' | |||
588 | if (parser.parseLess()) | |||
589 | return {}; | |||
590 | ||||
591 | // Parse variable 'one' | |||
592 | _result_one = ::mlir::FieldParser<uint32_t>::parse(parser); | |||
593 | if (failed(_result_one)) { | |||
594 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeNoParser parameter 'one' which is to be a `uint32_t`"); | |||
595 | return {}; | |||
596 | } | |||
597 | // Parse literal ',' | |||
598 | if (parser.parseComma()) | |||
599 | return {}; | |||
600 | // Parse literal '[' | |||
601 | if (parser.parseLSquare()) | |||
602 | return {}; | |||
603 | ||||
604 | // Parse variable 'two' | |||
605 | _result_two = ::mlir::FieldParser<::llvm::SmallVector<int64_t>>::parse(parser); | |||
606 | if (failed(_result_two)) { | |||
607 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeNoParser parameter 'two' which is to be a `::llvm::ArrayRef<int64_t>`"); | |||
608 | return {}; | |||
609 | } | |||
610 | // Parse literal ']' | |||
611 | if (parser.parseRSquare()) | |||
612 | return {}; | |||
613 | // Parse literal ',' | |||
614 | if (parser.parseComma()) | |||
615 | return {}; | |||
616 | ||||
617 | // Parse variable 'three' | |||
618 | _result_three = ::mlir::FieldParser<std::string>::parse(parser); | |||
619 | if (failed(_result_three)) { | |||
620 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeNoParser parameter 'three' which is to be a `::llvm::StringRef`"); | |||
621 | return {}; | |||
622 | } | |||
623 | // Parse literal ',' | |||
624 | if (parser.parseComma()) | |||
625 | return {}; | |||
626 | ||||
627 | // Parse variable 'four' | |||
628 | _result_four = ::mlir::FieldParser<::test::CustomParam>::parse(parser); | |||
629 | if (failed(_result_four)) { | |||
630 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeNoParser parameter 'four' which is to be a `::test::CustomParam`"); | |||
631 | return {}; | |||
632 | } | |||
633 | // Parse literal '>' | |||
634 | if (parser.parseGreater()) | |||
635 | return {}; | |||
636 | return TestTypeNoParserType::get(parser.getContext(), | |||
637 | _result_one.getValue(), | |||
638 | _result_two.getValue(), | |||
639 | _result_three.getValue(), | |||
640 | _result_four.getValue()); | |||
641 | } | |||
642 | ||||
643 | void TestTypeNoParserType::print(::mlir::AsmPrinter &printer) const { | |||
644 | printer << "<"; | |||
645 | printer.printStrippedAttrOrType(getOne()); | |||
646 | printer << ","; | |||
647 | printer << ' ' << "["; | |||
648 | printer.printStrippedAttrOrType(getTwo()); | |||
649 | printer << "]"; | |||
650 | printer << ","; | |||
651 | printer << ' '; | |||
652 | printer << '"' << getThree() << '"';; | |||
653 | printer << ","; | |||
654 | printer << ' '; | |||
655 | printer.printStrippedAttrOrType(getFour()); | |||
656 | printer << ">"; | |||
657 | } | |||
658 | ||||
659 | uint32_t TestTypeNoParserType::getOne() const { | |||
660 | return getImpl()->one; | |||
661 | } | |||
662 | ||||
663 | ::llvm::ArrayRef<int64_t> TestTypeNoParserType::getTwo() const { | |||
664 | return getImpl()->two; | |||
665 | } | |||
666 | ||||
667 | ::llvm::StringRef TestTypeNoParserType::getThree() const { | |||
668 | return getImpl()->three; | |||
669 | } | |||
670 | ||||
671 | ::test::CustomParam TestTypeNoParserType::getFour() const { | |||
672 | return getImpl()->four; | |||
673 | } | |||
674 | ||||
675 | } // namespace test | |||
676 | DEFINE_EXPLICIT_TYPE_ID(::test::TestTypeNoParserType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::TestTypeNoParserType>() { static TypeID::Storage instance ; return TypeID(&instance); } } } | |||
677 | namespace test { | |||
678 | namespace detail { | |||
679 | struct TestStructTypeCaptureAllTypeStorage : public ::mlir::TypeStorage { | |||
680 | using KeyTy = std::tuple<int, int, int, int>; | |||
681 | TestStructTypeCaptureAllTypeStorage(int v0, int v1, int v2, int v3) : v0(v0), v1(v1), v2(v2), v3(v3) {} | |||
682 | ||||
683 | bool operator==(const KeyTy &tblgenKey) const { | |||
684 | return (v0 == std::get<0>(tblgenKey)) && (v1 == std::get<1>(tblgenKey)) && (v2 == std::get<2>(tblgenKey)) && (v3 == std::get<3>(tblgenKey)); | |||
685 | } | |||
686 | ||||
687 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
688 | return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), std::get<2>(tblgenKey), std::get<3>(tblgenKey)); | |||
689 | } | |||
690 | ||||
691 | static TestStructTypeCaptureAllTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
692 | auto v0 = std::get<0>(tblgenKey); | |||
693 | auto v1 = std::get<1>(tblgenKey); | |||
694 | auto v2 = std::get<2>(tblgenKey); | |||
695 | auto v3 = std::get<3>(tblgenKey); | |||
696 | return new (allocator.allocate<TestStructTypeCaptureAllTypeStorage>()) TestStructTypeCaptureAllTypeStorage(v0, v1, v2, v3); | |||
697 | } | |||
698 | ||||
699 | int v0; | |||
700 | int v1; | |||
701 | int v2; | |||
702 | int v3; | |||
703 | }; | |||
704 | } // namespace detail | |||
705 | TestStructTypeCaptureAllType TestStructTypeCaptureAllType::get(::mlir::MLIRContext *context, int v0, int v1, int v2, int v3) { | |||
706 | return Base::get(context, v0, v1, v2, v3); | |||
707 | } | |||
708 | ||||
709 | ::mlir::Type TestStructTypeCaptureAllType::parse(::mlir::AsmParser &parser) { | |||
710 | ::mlir::FailureOr<int> _result_v0; | |||
711 | ::mlir::FailureOr<int> _result_v1; | |||
712 | ::mlir::FailureOr<int> _result_v2; | |||
713 | ::mlir::FailureOr<int> _result_v3; | |||
714 | ::llvm::SMLoc loc = parser.getCurrentLocation(); | |||
715 | (void) loc; | |||
716 | // Parse literal '<' | |||
717 | if (parser.parseLess()) | |||
718 | return {}; | |||
719 | // Parse parameter struct | |||
720 | bool _seen_v0 = false; | |||
721 | bool _seen_v1 = false; | |||
722 | bool _seen_v2 = false; | |||
723 | bool _seen_v3 = false; | |||
724 | for (unsigned _index = 0; _index < 4; ++_index) { | |||
725 | StringRef _paramKey; | |||
726 | if (parser.parseKeyword(&_paramKey)) { | |||
727 | parser.emitError(parser.getCurrentLocation(), | |||
728 | "expected a parameter name in struct"); | |||
729 | return {}; | |||
730 | } | |||
731 | // Parse literal '=' | |||
732 | if (parser.parseEqual()) | |||
733 | return {}; | |||
734 | if (!_seen_v0 && _paramKey == "v0") { | |||
735 | _seen_v0 = true; | |||
736 | ||||
737 | // Parse variable 'v0' | |||
738 | _result_v0 = ::mlir::FieldParser<int>::parse(parser); | |||
739 | if (failed(_result_v0)) { | |||
740 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeStructCaptureAll parameter 'v0' which is to be a `int`"); | |||
741 | return {}; | |||
742 | } | |||
743 | } else if (!_seen_v1 && _paramKey == "v1") { | |||
744 | _seen_v1 = true; | |||
745 | ||||
746 | // Parse variable 'v1' | |||
747 | _result_v1 = ::mlir::FieldParser<int>::parse(parser); | |||
748 | if (failed(_result_v1)) { | |||
749 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeStructCaptureAll parameter 'v1' which is to be a `int`"); | |||
750 | return {}; | |||
751 | } | |||
752 | } else if (!_seen_v2 && _paramKey == "v2") { | |||
753 | _seen_v2 = true; | |||
754 | ||||
755 | // Parse variable 'v2' | |||
756 | _result_v2 = ::mlir::FieldParser<int>::parse(parser); | |||
757 | if (failed(_result_v2)) { | |||
758 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeStructCaptureAll parameter 'v2' which is to be a `int`"); | |||
759 | return {}; | |||
760 | } | |||
761 | } else if (!_seen_v3 && _paramKey == "v3") { | |||
762 | _seen_v3 = true; | |||
763 | ||||
764 | // Parse variable 'v3' | |||
765 | _result_v3 = ::mlir::FieldParser<int>::parse(parser); | |||
766 | if (failed(_result_v3)) { | |||
767 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeStructCaptureAll parameter 'v3' which is to be a `int`"); | |||
768 | return {}; | |||
769 | } | |||
770 | } else { | |||
771 | parser.emitError(parser.getCurrentLocation(), "duplicate or unknown struct parameter name: ") << _paramKey; | |||
772 | return {}; | |||
773 | } | |||
774 | if ((_index != 4 - 1) && parser.parseComma()) | |||
775 | return {}; | |||
776 | } | |||
777 | // Parse literal '>' | |||
778 | if (parser.parseGreater()) | |||
779 | return {}; | |||
780 | return TestStructTypeCaptureAllType::get(parser.getContext(), | |||
781 | _result_v0.getValue(), | |||
782 | _result_v1.getValue(), | |||
783 | _result_v2.getValue(), | |||
784 | _result_v3.getValue()); | |||
785 | } | |||
786 | ||||
787 | void TestStructTypeCaptureAllType::print(::mlir::AsmPrinter &printer) const { | |||
788 | printer << "<"; | |||
789 | printer << "v0"; | |||
790 | printer << ' ' << "="; | |||
791 | printer << ' '; | |||
792 | printer.printStrippedAttrOrType(getV0()); | |||
793 | printer << ","; | |||
794 | printer << ' ' << "v1"; | |||
795 | printer << ' ' << "="; | |||
796 | printer << ' '; | |||
797 | printer.printStrippedAttrOrType(getV1()); | |||
798 | printer << ","; | |||
799 | printer << ' ' << "v2"; | |||
800 | printer << ' ' << "="; | |||
801 | printer << ' '; | |||
802 | printer.printStrippedAttrOrType(getV2()); | |||
803 | printer << ","; | |||
804 | printer << ' ' << "v3"; | |||
805 | printer << ' ' << "="; | |||
806 | printer << ' '; | |||
807 | printer.printStrippedAttrOrType(getV3()); | |||
808 | printer << ">"; | |||
809 | } | |||
810 | ||||
811 | int TestStructTypeCaptureAllType::getV0() const { | |||
812 | return getImpl()->v0; | |||
813 | } | |||
814 | ||||
815 | int TestStructTypeCaptureAllType::getV1() const { | |||
816 | return getImpl()->v1; | |||
817 | } | |||
818 | ||||
819 | int TestStructTypeCaptureAllType::getV2() const { | |||
820 | return getImpl()->v2; | |||
821 | } | |||
822 | ||||
823 | int TestStructTypeCaptureAllType::getV3() const { | |||
824 | return getImpl()->v3; | |||
825 | } | |||
826 | ||||
827 | } // namespace test | |||
828 | DEFINE_EXPLICIT_TYPE_ID(::test::TestStructTypeCaptureAllType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::TestStructTypeCaptureAllType>() { static TypeID::Storage instance; return TypeID(&instance); } } } | |||
829 | namespace test { | |||
830 | namespace detail { | |||
831 | struct TestTypeWithFormatTypeStorage : public ::mlir::TypeStorage { | |||
832 | using KeyTy = std::tuple<int64_t, std::string, ::mlir::Attribute>; | |||
833 | TestTypeWithFormatTypeStorage(int64_t one, std::string two, ::mlir::Attribute three) : one(one), two(two), three(three) {} | |||
834 | ||||
835 | bool operator==(const KeyTy &tblgenKey) const { | |||
836 | return (one == std::get<0>(tblgenKey)) && (two == std::get<1>(tblgenKey)) && (three == std::get<2>(tblgenKey)); | |||
837 | } | |||
838 | ||||
839 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
840 | return ::llvm::hash_combine(std::get<0>(tblgenKey), std::get<1>(tblgenKey), std::get<2>(tblgenKey)); | |||
841 | } | |||
842 | ||||
843 | static TestTypeWithFormatTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
844 | auto one = std::get<0>(tblgenKey); | |||
845 | auto two = std::get<1>(tblgenKey); | |||
846 | auto three = std::get<2>(tblgenKey); | |||
847 | return new (allocator.allocate<TestTypeWithFormatTypeStorage>()) TestTypeWithFormatTypeStorage(one, two, three); | |||
848 | } | |||
849 | ||||
850 | int64_t one; | |||
851 | std::string two; | |||
852 | ::mlir::Attribute three; | |||
853 | }; | |||
854 | } // namespace detail | |||
855 | TestTypeWithFormatType TestTypeWithFormatType::get(::mlir::MLIRContext *context, int64_t one, std::string two, ::mlir::Attribute three) { | |||
856 | return Base::get(context, one, two, three); | |||
857 | } | |||
858 | ||||
859 | ::mlir::Type TestTypeWithFormatType::parse(::mlir::AsmParser &parser) { | |||
860 | ::mlir::FailureOr<int64_t> _result_one; | |||
861 | ::mlir::FailureOr<std::string> _result_two; | |||
862 | ::mlir::FailureOr<::mlir::Attribute> _result_three; | |||
863 | ::llvm::SMLoc loc = parser.getCurrentLocation(); | |||
864 | (void) loc; | |||
865 | // Parse literal '<' | |||
866 | if (parser.parseLess()) | |||
867 | return {}; | |||
868 | ||||
869 | // Parse variable 'one' | |||
870 | _result_one = ::mlir::FieldParser<int64_t>::parse(parser); | |||
871 | if (failed(_result_one)) { | |||
872 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeWithFormat parameter 'one' which is to be a `int64_t`"); | |||
873 | return {}; | |||
874 | } | |||
875 | // Parse literal ',' | |||
876 | if (parser.parseComma()) | |||
877 | return {}; | |||
878 | // Parse parameter struct | |||
879 | bool _seen_three = false; | |||
880 | bool _seen_two = false; | |||
881 | for (unsigned _index = 0; _index < 2; ++_index) { | |||
882 | StringRef _paramKey; | |||
883 | if (parser.parseKeyword(&_paramKey)) { | |||
884 | parser.emitError(parser.getCurrentLocation(), | |||
885 | "expected a parameter name in struct"); | |||
886 | return {}; | |||
887 | } | |||
888 | // Parse literal '=' | |||
889 | if (parser.parseEqual()) | |||
890 | return {}; | |||
891 | if (!_seen_three && _paramKey == "three") { | |||
892 | _seen_three = true; | |||
893 | ||||
894 | // Parse variable 'three' | |||
895 | _result_three = ::mlir::FieldParser<::mlir::Attribute>::parse(parser); | |||
896 | if (failed(_result_three)) { | |||
897 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeWithFormat parameter 'three' which is to be a `::mlir::Attribute`"); | |||
898 | return {}; | |||
899 | } | |||
900 | } else if (!_seen_two && _paramKey == "two") { | |||
901 | _seen_two = true; | |||
902 | ||||
903 | // Parse variable 'two' | |||
904 | _result_two = ::mlir::FieldParser<std::string>::parse(parser); | |||
905 | if (failed(_result_two)) { | |||
906 | parser.emitError(parser.getCurrentLocation(), "failed to parse TestTypeWithFormat parameter 'two' which is to be a `std::string`"); | |||
907 | return {}; | |||
908 | } | |||
909 | } else { | |||
910 | parser.emitError(parser.getCurrentLocation(), "duplicate or unknown struct parameter name: ") << _paramKey; | |||
911 | return {}; | |||
912 | } | |||
913 | if ((_index != 2 - 1) && parser.parseComma()) | |||
914 | return {}; | |||
915 | } | |||
916 | // Parse literal '>' | |||
917 | if (parser.parseGreater()) | |||
918 | return {}; | |||
919 | return TestTypeWithFormatType::get(parser.getContext(), | |||
920 | _result_one.getValue(), | |||
921 | _result_two.getValue(), | |||
922 | _result_three.getValue()); | |||
923 | } | |||
924 | ||||
925 | void TestTypeWithFormatType::print(::mlir::AsmPrinter &printer) const { | |||
926 | printer << "<"; | |||
927 | printer.printStrippedAttrOrType(getOne()); | |||
928 | printer << ","; | |||
929 | printer << ' ' << "three"; | |||
930 | printer << ' ' << "="; | |||
931 | printer << ' '; | |||
932 | printer.printStrippedAttrOrType(getThree()); | |||
933 | printer << ","; | |||
934 | printer << ' ' << "two"; | |||
935 | printer << ' ' << "="; | |||
936 | printer << ' '; | |||
937 | printer << '"' << getTwo() << '"'; | |||
938 | printer << ">"; | |||
939 | } | |||
940 | ||||
941 | int64_t TestTypeWithFormatType::getOne() const { | |||
942 | return getImpl()->one; | |||
943 | } | |||
944 | ||||
945 | llvm::StringRef TestTypeWithFormatType::getTwo() const { | |||
946 | return getImpl()->two; | |||
947 | } | |||
948 | ||||
949 | ::mlir::Attribute TestTypeWithFormatType::getThree() const { | |||
950 | return getImpl()->three; | |||
951 | } | |||
952 | ||||
953 | } // namespace test | |||
954 | DEFINE_EXPLICIT_TYPE_ID(::test::TestTypeWithFormatType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::TestTypeWithFormatType>() { static TypeID::Storage instance ; return TypeID(&instance); } } } | |||
955 | namespace test { | |||
956 | namespace detail { | |||
957 | struct TestTypeWithLayoutTypeStorage : public ::mlir::TypeStorage { | |||
958 | using KeyTy = std::tuple<unsigned>; | |||
959 | TestTypeWithLayoutTypeStorage(unsigned key) : key(key) {} | |||
960 | ||||
961 | bool operator==(const KeyTy &tblgenKey) const { | |||
962 | return (key == std::get<0>(tblgenKey)); | |||
963 | } | |||
964 | ||||
965 | static ::llvm::hash_code hashKey(const KeyTy &tblgenKey) { | |||
966 | return ::llvm::hash_combine(std::get<0>(tblgenKey)); | |||
967 | } | |||
968 | ||||
969 | static TestTypeWithLayoutTypeStorage *construct(::mlir::TypeStorageAllocator &allocator, const KeyTy &tblgenKey) { | |||
970 | auto key = std::get<0>(tblgenKey); | |||
971 | return new (allocator.allocate<TestTypeWithLayoutTypeStorage>()) TestTypeWithLayoutTypeStorage(key); | |||
972 | } | |||
973 | ||||
974 | unsigned key; | |||
975 | }; | |||
976 | } // namespace detail | |||
977 | TestTypeWithLayoutType TestTypeWithLayoutType::get(::mlir::MLIRContext *context, unsigned key) { | |||
978 | return Base::get(context, key); | |||
979 | } | |||
980 | ||||
981 | unsigned TestTypeWithLayoutType::getKey() const { | |||
982 | return getImpl()->key; | |||
983 | } | |||
984 | ||||
985 | } // namespace test | |||
986 | DEFINE_EXPLICIT_TYPE_ID(::test::TestTypeWithLayoutType)namespace mlir { namespace detail { template <> __attribute__ ((visibility("default"))) TypeID TypeIDExported::get< ::test ::TestTypeWithLayoutType>() { static TypeID::Storage instance ; return TypeID(&instance); } } } | |||
// TestTypeWithTraitType has no generated storage or accessors here; only its
// explicit TypeID definition is emitted.
namespace test {
} // namespace test
DEFINE_EXPLICIT_TYPE_ID(::test::TestTypeWithTraitType)
990 | ||||
991 | #endif // GET_TYPEDEF_CLASSES | |||
992 |
1 | //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file implements helper classes for implementing the "Op" types. This |
10 | // includes the Op type, which is the base class for Op class definitions, |
11 | // as well as number of traits in the OpTrait namespace that provide a |
12 | // declarative way to specify properties of Ops. |
13 | // |
14 | // The purpose of these types is to allow light-weight implementation of |
15 | // concrete ops (like DimOp) with very little boilerplate. |
16 | // |
17 | //===----------------------------------------------------------------------===// |
18 | |
19 | #ifndef MLIR_IR_OPDEFINITION_H |
20 | #define MLIR_IR_OPDEFINITION_H |
21 | |
22 | #include "mlir/IR/Dialect.h" |
23 | #include "mlir/IR/Operation.h" |
24 | #include "llvm/Support/PointerLikeTypeTraits.h" |
25 | |
26 | #include <type_traits> |
27 | |
28 | namespace mlir { |
29 | class Builder; |
30 | class OpBuilder; |
31 | |
32 | /// This class represents success/failure for operation parsing. It is |
33 | /// essentially a simple wrapper class around LogicalResult that allows for |
34 | /// explicit conversion to bool. This allows for the parser to chain together |
35 | /// parse rules without the clutter of "failed/succeeded". |
36 | class ParseResult : public LogicalResult { |
37 | public: |
38 | ParseResult(LogicalResult result = success()) : LogicalResult(result) {} |
39 | |
40 | // Allow diagnostics emitted during parsing to be converted to failure. |
41 | ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {} |
42 | ParseResult(const Diagnostic &) : LogicalResult(failure()) {} |
43 | |
44 | /// Failure is true in a boolean context. |
45 | explicit operator bool() const { return failed(); } |
46 | }; |
47 | /// This class implements `Optional` functionality for ParseResult. We don't |
48 | /// directly use Optional here, because it provides an implicit conversion |
49 | /// to 'bool' which we want to avoid. This class is used to implement tri-state |
50 | /// 'parseOptional' functions that may have a failure mode when parsing that |
51 | /// shouldn't be attributed to "not present". |
52 | class OptionalParseResult { |
53 | public: |
54 | OptionalParseResult() = default; |
55 | OptionalParseResult(LogicalResult result) : impl(result) {} |
56 | OptionalParseResult(ParseResult result) : impl(result) {} |
57 | OptionalParseResult(const InFlightDiagnostic &) |
58 | : OptionalParseResult(failure()) {} |
59 | OptionalParseResult(llvm::NoneType) : impl(llvm::None) {} |
60 | |
61 | /// Returns true if we contain a valid ParseResult value. |
62 | bool hasValue() const { return impl.hasValue(); } |
63 | |
64 | /// Access the internal ParseResult value. |
65 | ParseResult getValue() const { return impl.getValue(); } |
66 | ParseResult operator*() const { return getValue(); } |
67 | |
68 | private: |
69 | Optional<ParseResult> impl; |
70 | }; |
71 | |
72 | // These functions are out-of-line utilities, which avoids them being template |
73 | // instantiated/duplicated. |
74 | namespace impl { |
75 | /// Insert an operation, generated by `buildTerminatorOp`, at the end of the |
76 | /// region's only block if it does not have a terminator already. If the region |
77 | /// is empty, insert a new block first. `buildTerminatorOp` should return the |
78 | /// terminator operation to insert. |
79 | void ensureRegionTerminator( |
80 | Region ®ion, OpBuilder &builder, Location loc, |
81 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
82 | void ensureRegionTerminator( |
83 | Region ®ion, Builder &builder, Location loc, |
84 | function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp); |
85 | |
86 | } // namespace impl |
87 | |
88 | /// This is the concrete base class that holds the operation pointer and has |
89 | /// non-generic methods that only depend on State (to avoid having them |
90 | /// instantiated on template types that don't affect them. |
91 | /// |
92 | /// This also has the fallback implementations of customization hooks for when |
93 | /// they aren't customized. |
94 | class OpState { |
95 | public: |
96 | /// Ops are pointer-like, so we allow conversion to bool. |
97 | explicit operator bool() { return getOperation() != nullptr; } |
98 | |
99 | /// This implicitly converts to Operation*. |
100 | operator Operation *() const { return state; } |
101 | |
102 | /// Shortcut of `->` to access a member of Operation. |
103 | Operation *operator->() const { return state; } |
104 | |
105 | /// Return the operation that this refers to. |
106 | Operation *getOperation() { return state; } |
107 | |
108 | /// Return the context this operation belongs to. |
109 | MLIRContext *getContext() { return getOperation()->getContext(); } |
110 | |
111 | /// Print the operation to the given stream. |
112 | void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { |
113 | state->print(os, flags); |
114 | } |
115 | void print(raw_ostream &os, AsmState &asmState, |
116 | OpPrintingFlags flags = llvm::None) { |
117 | state->print(os, asmState, flags); |
118 | } |
119 | |
120 | /// Dump this operation. |
121 | void dump() { state->dump(); } |
122 | |
123 | /// The source location the operation was defined or derived from. |
124 | Location getLoc() { return state->getLoc(); } |
125 | |
126 | /// Return true if there are no users of any results of this operation. |
127 | bool use_empty() { return state->use_empty(); } |
128 | |
129 | /// Remove this operation from its parent block and delete it. |
130 | void erase() { state->erase(); } |
131 | |
132 | /// Emit an error with the op name prefixed, like "'dim' op " which is |
133 | /// convenient for verifiers. |
134 | InFlightDiagnostic emitOpError(const Twine &message = {}); |
135 | |
136 | /// Emit an error about fatal conditions with this operation, reporting up to |
137 | /// any diagnostic handlers that may be listening. |
138 | InFlightDiagnostic emitError(const Twine &message = {}); |
139 | |
140 | /// Emit a warning about this operation, reporting up to any diagnostic |
141 | /// handlers that may be listening. |
142 | InFlightDiagnostic emitWarning(const Twine &message = {}); |
143 | |
144 | /// Emit a remark about this operation, reporting up to any diagnostic |
145 | /// handlers that may be listening. |
146 | InFlightDiagnostic emitRemark(const Twine &message = {}); |
147 | |
148 | /// Walk the operation by calling the callback for each nested operation |
149 | /// (including this one), block or region, depending on the callback provided. |
150 | /// Regions, blocks and operations at the same nesting level are visited in |
151 | /// lexicographical order. The walk order for enclosing regions, blocks and |
152 | /// operations with respect to their nested ones is specified by 'Order' |
153 | /// (post-order by default). A callback on a block or operation is allowed to |
154 | /// erase that block or operation if either: |
155 | /// * the walk is in post-order, or |
156 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
157 | /// See Operation::walk for more details. |
158 | template <WalkOrder Order = WalkOrder::PostOrder, typename FnT, |
159 | typename RetT = detail::walkResultType<FnT>> |
160 | typename std::enable_if< |
161 | llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type |
162 | walk(FnT &&callback) { |
163 | return state->walk<Order>(std::forward<FnT>(callback)); |
164 | } |
165 | |
166 | /// Generic walker with a stage aware callback. Walk the operation by calling |
167 | /// the callback for each nested operation (including this one) N+1 times, |
168 | /// where N is the number of regions attached to that operation. |
169 | /// |
170 | /// The callback method can take any of the following forms: |
171 | /// void(Operation *, const WalkStage &) : Walk all operation opaquely |
172 | /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...}); |
173 | /// void(OpT, const WalkStage &) : Walk all operations of the given derived |
174 | /// type. |
175 | /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...}); |
176 | /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations, |
177 | /// but allow for interruption/skipping. |
178 | /// * op.walk([](... op, const WalkStage &stage) { |
179 | /// // Skip the walk of this op based on some invariant. |
180 | /// if (some_invariant) |
181 | /// return WalkResult::skip(); |
182 | /// // Interrupt, i.e cancel, the walk based on some invariant. |
183 | /// if (another_invariant) |
184 | /// return WalkResult::interrupt(); |
185 | /// return WalkResult::advance(); |
186 | /// }); |
187 | template <typename FnT, typename RetT = detail::walkResultType<FnT>> |
188 | typename std::enable_if< |
189 | llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type |
190 | walk(FnT &&callback) { |
191 | return state->walk(std::forward<FnT>(callback)); |
192 | } |
193 | |
194 | // These are default implementations of customization hooks. |
195 | public: |
196 | /// This hook returns any canonicalization pattern rewrites that the operation |
197 | /// supports, for use by the canonicalization pass. |
198 | static void getCanonicalizationPatterns(RewritePatternSet &results, |
199 | MLIRContext *context) {} |
200 | |
201 | protected: |
202 | /// If the concrete type didn't implement a custom verifier hook, just fall |
203 | /// back to this one which accepts everything. |
204 | LogicalResult verifyInvariants() { return success(); } |
205 | |
206 | /// Parse the custom form of an operation. Unless overridden, this method will |
207 | /// first try to get an operation parser from the op's dialect. Otherwise the |
208 | /// custom assembly form of an op is always rejected. Op implementations |
209 | /// should implement this to return failure. On success, they should fill in |
210 | /// result with the fields to use. |
211 | static ParseResult parse(OpAsmParser &parser, OperationState &result); |
212 | |
213 | /// Print the operation. Unless overridden, this method will first try to get |
214 | /// an operation printer from the dialect. Otherwise, it prints the operation |
215 | /// in generic form. |
216 | static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); |
217 | |
218 | /// Print an operation name, eliding the dialect prefix if necessary. |
219 | static void printOpName(Operation *op, OpAsmPrinter &p, |
220 | StringRef defaultDialect); |
221 | |
222 | /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, |
223 | /// so we can cast it away here. |
224 | explicit OpState(Operation *state) : state(state) {} |
225 | |
226 | private: |
227 | Operation *state; |
228 | |
229 | /// Allow access to internal hook implementation methods. |
230 | friend RegisteredOperationName; |
231 | }; |
232 | |
233 | // Allow comparing operators. |
234 | inline bool operator==(OpState lhs, OpState rhs) { |
235 | return lhs.getOperation() == rhs.getOperation(); |
236 | } |
237 | inline bool operator!=(OpState lhs, OpState rhs) { |
238 | return lhs.getOperation() != rhs.getOperation(); |
239 | } |
240 | |
241 | raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr); |
242 | |
243 | /// This class represents a single result from folding an operation. |
244 | class OpFoldResult : public PointerUnion<Attribute, Value> { |
245 | using PointerUnion<Attribute, Value>::PointerUnion; |
246 | |
247 | public: |
248 | void dump() { llvm::errs() << *this << "\n"; } |
249 | }; |
250 | |
251 | /// Allow printing to a stream. |
252 | inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { |
253 | if (Value value = ofr.dyn_cast<Value>()) |
254 | value.print(os); |
255 | else |
256 | ofr.dyn_cast<Attribute>().print(os); |
257 | return os; |
258 | } |
259 | |
260 | /// Allow printing to a stream. |
261 | inline raw_ostream &operator<<(raw_ostream &os, OpState op) { |
262 | op.print(os, OpPrintingFlags().useLocalScope()); |
263 | return os; |
264 | } |
265 | |
266 | //===----------------------------------------------------------------------===// |
267 | // Operation Trait Types |
268 | //===----------------------------------------------------------------------===// |
269 | |
270 | namespace OpTrait { |
271 | |
272 | // These functions are out-of-line implementations of the methods in the |
273 | // corresponding trait classes. This avoids them being template |
274 | // instantiated/duplicated. |
275 | namespace impl { |
276 | OpFoldResult foldIdempotent(Operation *op); |
277 | OpFoldResult foldInvolution(Operation *op); |
278 | LogicalResult verifyZeroOperands(Operation *op); |
279 | LogicalResult verifyOneOperand(Operation *op); |
280 | LogicalResult verifyNOperands(Operation *op, unsigned numOperands); |
281 | LogicalResult verifyIsIdempotent(Operation *op); |
282 | LogicalResult verifyIsInvolution(Operation *op); |
283 | LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands); |
284 | LogicalResult verifyOperandsAreFloatLike(Operation *op); |
285 | LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op); |
286 | LogicalResult verifySameTypeOperands(Operation *op); |
287 | LogicalResult verifyZeroRegion(Operation *op); |
288 | LogicalResult verifyOneRegion(Operation *op); |
289 | LogicalResult verifyNRegions(Operation *op, unsigned numRegions); |
290 | LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions); |
291 | LogicalResult verifyZeroResult(Operation *op); |
292 | LogicalResult verifyOneResult(Operation *op); |
293 | LogicalResult verifyNResults(Operation *op, unsigned numOperands); |
294 | LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands); |
295 | LogicalResult verifySameOperandsShape(Operation *op); |
296 | LogicalResult verifySameOperandsAndResultShape(Operation *op); |
297 | LogicalResult verifySameOperandsElementType(Operation *op); |
298 | LogicalResult verifySameOperandsAndResultElementType(Operation *op); |
299 | LogicalResult verifySameOperandsAndResultType(Operation *op); |
300 | LogicalResult verifyResultsAreBoolLike(Operation *op); |
301 | LogicalResult verifyResultsAreFloatLike(Operation *op); |
302 | LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op); |
303 | LogicalResult verifyIsTerminator(Operation *op); |
304 | LogicalResult verifyZeroSuccessor(Operation *op); |
305 | LogicalResult verifyOneSuccessor(Operation *op); |
306 | LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors); |
307 | LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors); |
308 | LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName, |
309 | StringRef valueGroupName, |
310 | size_t expectedCount); |
311 | LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName); |
312 | LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName); |
313 | LogicalResult verifyNoRegionArguments(Operation *op); |
314 | LogicalResult verifyElementwise(Operation *op); |
315 | LogicalResult verifyIsIsolatedFromAbove(Operation *op); |
316 | } // namespace impl |
317 | |
318 | /// Helper class for implementing traits. Clients are not expected to interact |
319 | /// with this directly, so its members are all protected. |
320 | template <typename ConcreteType, template <typename> class TraitType> |
321 | class TraitBase { |
322 | protected: |
323 | /// Return the ultimate Operation being worked on. |
324 | Operation *getOperation() { |
325 | // We have to cast up to the trait type, then to the concrete type, then to |
326 | // the BaseState class in explicit hops because the concrete type will |
327 | // multiply derive from the (content free) TraitBase class, and we need to |
328 | // be able to disambiguate the path for the C++ compiler. |
329 | auto *trait = static_cast<TraitType<ConcreteType> *>(this); |
330 | auto *concrete = static_cast<ConcreteType *>(trait); |
331 | auto *base = static_cast<OpState *>(concrete); |
332 | return base->getOperation(); |
333 | } |
334 | }; |
335 | |
336 | //===----------------------------------------------------------------------===// |
337 | // Operand Traits |
338 | |
339 | namespace detail { |
340 | /// Utility trait base that provides accessors for derived traits that have |
341 | /// multiple operands. |
342 | template <typename ConcreteType, template <typename> class TraitType> |
343 | struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> { |
344 | using operand_iterator = Operation::operand_iterator; |
345 | using operand_range = Operation::operand_range; |
346 | using operand_type_iterator = Operation::operand_type_iterator; |
347 | using operand_type_range = Operation::operand_type_range; |
348 | |
349 | /// Return the number of operands. |
350 | unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } |
351 | |
352 | /// Return the operand at index 'i'. |
353 | Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } |
354 | |
355 | /// Set the operand at index 'i' to 'value'. |
356 | void setOperand(unsigned i, Value value) { |
357 | this->getOperation()->setOperand(i, value); |
358 | } |
359 | |
360 | /// Operand iterator access. |
361 | operand_iterator operand_begin() { |
362 | return this->getOperation()->operand_begin(); |
363 | } |
364 | operand_iterator operand_end() { return this->getOperation()->operand_end(); } |
365 | operand_range getOperands() { return this->getOperation()->getOperands(); } |
366 | |
367 | /// Operand type access. |
368 | operand_type_iterator operand_type_begin() { |
369 | return this->getOperation()->operand_type_begin(); |
370 | } |
371 | operand_type_iterator operand_type_end() { |
372 | return this->getOperation()->operand_type_end(); |
373 | } |
374 | operand_type_range getOperandTypes() { |
375 | return this->getOperation()->getOperandTypes(); |
376 | } |
377 | }; |
378 | } // namespace detail |
379 | |
380 | /// This class provides the API for ops that are known to have no |
381 | /// SSA operand. |
382 | template <typename ConcreteType> |
383 | class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> { |
384 | public: |
385 | static LogicalResult verifyTrait(Operation *op) { |
386 | return impl::verifyZeroOperands(op); |
387 | } |
388 | |
389 | private: |
390 | // Disable these. |
391 | void getOperand() {} |
392 | void setOperand() {} |
393 | }; |
394 | |
395 | /// This class provides the API for ops that are known to have exactly one |
396 | /// SSA operand. |
397 | template <typename ConcreteType> |
398 | class OneOperand : public TraitBase<ConcreteType, OneOperand> { |
399 | public: |
400 | Value getOperand() { return this->getOperation()->getOperand(0); } |
401 | |
402 | void setOperand(Value value) { this->getOperation()->setOperand(0, value); } |
403 | |
404 | static LogicalResult verifyTrait(Operation *op) { |
405 | return impl::verifyOneOperand(op); |
406 | } |
407 | }; |
408 | |
409 | /// This class provides the API for ops that are known to have a specified |
410 | /// number of operands. This is used as a trait like this: |
411 | /// |
412 | /// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> { |
413 | /// |
414 | template <unsigned N> |
415 | class NOperands { |
416 | public: |
417 | static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2"); |
418 | |
419 | template <typename ConcreteType> |
420 | class Impl |
421 | : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> { |
422 | public: |
423 | static LogicalResult verifyTrait(Operation *op) { |
424 | return impl::verifyNOperands(op, N); |
425 | } |
426 | }; |
427 | }; |
428 | |
429 | /// This class provides the API for ops that are known to have a at least a |
430 | /// specified number of operands. This is used as a trait like this: |
431 | /// |
432 | /// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> { |
433 | /// |
434 | template <unsigned N> |
435 | class AtLeastNOperands { |
436 | public: |
437 | template <typename ConcreteType> |
438 | class Impl : public detail::MultiOperandTraitBase<ConcreteType, |
439 | AtLeastNOperands<N>::Impl> { |
440 | public: |
441 | static LogicalResult verifyTrait(Operation *op) { |
442 | return impl::verifyAtLeastNOperands(op, N); |
443 | } |
444 | }; |
445 | }; |
446 | |
447 | /// This class provides the API for ops which have an unknown number of |
448 | /// SSA operands. |
449 | template <typename ConcreteType> |
450 | class VariadicOperands |
451 | : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {}; |
452 | |
453 | //===----------------------------------------------------------------------===// |
454 | // Region Traits |
455 | |
456 | /// This class provides verification for ops that are known to have zero |
457 | /// regions. |
458 | template <typename ConcreteType> |
459 | class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> { |
460 | public: |
461 | static LogicalResult verifyTrait(Operation *op) { |
462 | return impl::verifyZeroRegion(op); |
463 | } |
464 | }; |
465 | |
466 | namespace detail { |
467 | /// Utility trait base that provides accessors for derived traits that have |
468 | /// multiple regions. |
469 | template <typename ConcreteType, template <typename> class TraitType> |
470 | struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> { |
471 | using region_iterator = MutableArrayRef<Region>; |
472 | using region_range = RegionRange; |
473 | |
474 | /// Return the number of regions. |
475 | unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } |
476 | |
477 | /// Return the region at `index`. |
478 | Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } |
479 | |
480 | /// Region iterator access. |
481 | region_iterator region_begin() { |
482 | return this->getOperation()->region_begin(); |
483 | } |
484 | region_iterator region_end() { return this->getOperation()->region_end(); } |
485 | region_range getRegions() { return this->getOperation()->getRegions(); } |
486 | }; |
487 | } // namespace detail |
488 | |
489 | /// This class provides APIs for ops that are known to have a single region. |
490 | template <typename ConcreteType> |
491 | class OneRegion : public TraitBase<ConcreteType, OneRegion> { |
492 | public: |
493 | Region &getRegion() { return this->getOperation()->getRegion(0); } |
494 | |
495 | /// Returns a range of operations within the region of this operation. |
496 | auto getOps() { return getRegion().getOps(); } |
497 | template <typename OpT> |
498 | auto getOps() { |
499 | return getRegion().template getOps<OpT>(); |
500 | } |
501 | |
502 | static LogicalResult verifyTrait(Operation *op) { |
503 | return impl::verifyOneRegion(op); |
504 | } |
505 | }; |
506 | |
507 | /// This class provides the API for ops that are known to have a specified |
508 | /// number of regions. |
509 | template <unsigned N> |
510 | class NRegions { |
511 | public: |
512 | static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2"); |
513 | |
514 | template <typename ConcreteType> |
515 | class Impl |
516 | : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> { |
517 | public: |
518 | static LogicalResult verifyTrait(Operation *op) { |
519 | return impl::verifyNRegions(op, N); |
520 | } |
521 | }; |
522 | }; |
523 | |
524 | /// This class provides APIs for ops that are known to have at least a specified |
525 | /// number of regions. |
526 | template <unsigned N> |
527 | class AtLeastNRegions { |
528 | public: |
529 | template <typename ConcreteType> |
530 | class Impl : public detail::MultiRegionTraitBase<ConcreteType, |
531 | AtLeastNRegions<N>::Impl> { |
532 | public: |
533 | static LogicalResult verifyTrait(Operation *op) { |
534 | return impl::verifyAtLeastNRegions(op, N); |
535 | } |
536 | }; |
537 | }; |
538 | |
539 | /// This class provides the API for ops which have an unknown number of |
540 | /// regions. |
541 | template <typename ConcreteType> |
542 | class VariadicRegions |
543 | : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {}; |
544 | |
545 | //===----------------------------------------------------------------------===// |
546 | // Result Traits |
547 | |
548 | /// This class provides return value APIs for ops that are known to have |
549 | /// zero results. |
550 | template <typename ConcreteType> |
551 | class ZeroResult : public TraitBase<ConcreteType, ZeroResult> { |
552 | public: |
553 | static LogicalResult verifyTrait(Operation *op) { |
554 | return impl::verifyZeroResult(op); |
555 | } |
556 | }; |
557 | |
558 | namespace detail { |
559 | /// Utility trait base that provides accessors for derived traits that have |
560 | /// multiple results. |
561 | template <typename ConcreteType, template <typename> class TraitType> |
562 | struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> { |
563 | using result_iterator = Operation::result_iterator; |
564 | using result_range = Operation::result_range; |
565 | using result_type_iterator = Operation::result_type_iterator; |
566 | using result_type_range = Operation::result_type_range; |
567 | |
568 | /// Return the number of results. |
569 | unsigned getNumResults() { return this->getOperation()->getNumResults(); } |
570 | |
571 | /// Return the result at index 'i'. |
572 | Value getResult(unsigned i) { return this->getOperation()->getResult(i); } |
573 | |
574 | /// Replace all uses of results of this operation with the provided 'values'. |
575 | /// 'values' may correspond to an existing operation, or a range of 'Value'. |
576 | template <typename ValuesT> |
577 | void replaceAllUsesWith(ValuesT &&values) { |
578 | this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); |
579 | } |
580 | |
581 | /// Return the type of the `i`-th result. |
582 | Type getType(unsigned i) { return getResult(i).getType(); } |
583 | |
584 | /// Result iterator access. |
585 | result_iterator result_begin() { |
586 | return this->getOperation()->result_begin(); |
587 | } |
588 | result_iterator result_end() { return this->getOperation()->result_end(); } |
589 | result_range getResults() { return this->getOperation()->getResults(); } |
590 | |
591 | /// Result type access. |
592 | result_type_iterator result_type_begin() { |
593 | return this->getOperation()->result_type_begin(); |
594 | } |
595 | result_type_iterator result_type_end() { |
596 | return this->getOperation()->result_type_end(); |
597 | } |
598 | result_type_range getResultTypes() { |
599 | return this->getOperation()->getResultTypes(); |
600 | } |
601 | }; |
602 | } // namespace detail |
603 | |
604 | /// This class provides return value APIs for ops that are known to have a |
605 | /// single result. ResultType is the concrete type returned by getType(). |
606 | template <typename ConcreteType> |
607 | class OneResult : public TraitBase<ConcreteType, OneResult> { |
608 | public: |
609 | Value getResult() { return this->getOperation()->getResult(0); } |
610 | |
611 | /// If the operation returns a single value, then the Op can be implicitly |
612 | /// converted to an Value. This yields the value of the only result. |
613 | operator Value() { return getResult(); } |
614 | |
615 | /// Replace all uses of 'this' value with the new value, updating anything |
616 | /// in the IR that uses 'this' to use the other value instead. When this |
617 | /// returns there are zero uses of 'this'. |
618 | void replaceAllUsesWith(Value newValue) { |
619 | getResult().replaceAllUsesWith(newValue); |
620 | } |
621 | |
622 | /// Replace all uses of 'this' value with the result of 'op'. |
623 | void replaceAllUsesWith(Operation *op) { |
624 | this->getOperation()->replaceAllUsesWith(op); |
625 | } |
626 | |
627 | static LogicalResult verifyTrait(Operation *op) { |
628 | return impl::verifyOneResult(op); |
629 | } |
630 | }; |
631 | |
632 | /// This trait is used for return value APIs for ops that are known to have a |
633 | /// specific type other than `Type`. This allows the "getType()" member to be |
634 | /// more specific for an op. This should be used in conjunction with OneResult, |
635 | /// and occur in the trait list before OneResult. |
636 | template <typename ResultType> |
637 | class OneTypedResult { |
638 | public: |
639 | /// This class provides return value APIs for ops that are known to have a |
640 | /// single result. ResultType is the concrete type returned by getType(). |
641 | template <typename ConcreteType> |
642 | class Impl |
643 | : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { |
644 | public: |
645 | ResultType getType() { |
646 | auto resultTy = this->getOperation()->getResult(0).getType(); |
647 | return resultTy.template cast<ResultType>(); |
648 | } |
649 | }; |
650 | }; |
651 | |
652 | /// This class provides the API for ops that are known to have a specified |
653 | /// number of results. This is used as a trait like this: |
654 | /// |
655 | /// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> { |
656 | /// |
657 | template <unsigned N> |
658 | class NResults { |
659 | public: |
660 | static_assert(N > 1, "use ZeroResult/OneResult for N < 2"); |
661 | |
662 | template <typename ConcreteType> |
663 | class Impl |
664 | : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> { |
665 | public: |
666 | static LogicalResult verifyTrait(Operation *op) { |
667 | return impl::verifyNResults(op, N); |
668 | } |
669 | }; |
670 | }; |
671 | |
672 | /// This class provides the API for ops that are known to have at least a |
673 | /// specified number of results. This is used as a trait like this: |
674 | /// |
675 | /// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> { |
676 | /// |
677 | template <unsigned N> |
678 | class AtLeastNResults { |
679 | public: |
680 | template <typename ConcreteType> |
681 | class Impl : public detail::MultiResultTraitBase<ConcreteType, |
682 | AtLeastNResults<N>::Impl> { |
683 | public: |
684 | static LogicalResult verifyTrait(Operation *op) { |
685 | return impl::verifyAtLeastNResults(op, N); |
686 | } |
687 | }; |
688 | }; |
689 | |
690 | /// This class provides the API for ops which have an unknown number of |
691 | /// results. |
692 | template <typename ConcreteType> |
693 | class VariadicResults |
694 | : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {}; |
695 | |
696 | //===----------------------------------------------------------------------===// |
697 | // Terminator Traits |
698 | |
699 | /// This class indicates that the regions associated with this op don't have |
700 | /// terminators. |
701 | template <typename ConcreteType> |
702 | class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {}; |
703 | |
704 | /// This class provides the API for ops that are known to be terminators. |
705 | template <typename ConcreteType> |
706 | class IsTerminator : public TraitBase<ConcreteType, IsTerminator> { |
707 | public: |
708 | static LogicalResult verifyTrait(Operation *op) { |
709 | return impl::verifyIsTerminator(op); |
710 | } |
711 | }; |
712 | |
713 | /// This class provides verification for ops that are known to have zero |
714 | /// successors. |
715 | template <typename ConcreteType> |
716 | class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> { |
717 | public: |
718 | static LogicalResult verifyTrait(Operation *op) { |
719 | return impl::verifyZeroSuccessor(op); |
720 | } |
721 | }; |
722 | |
723 | namespace detail { |
724 | /// Utility trait base that provides accessors for derived traits that have |
725 | /// multiple successors. |
726 | template <typename ConcreteType, template <typename> class TraitType> |
727 | struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> { |
728 | using succ_iterator = Operation::succ_iterator; |
729 | using succ_range = SuccessorRange; |
730 | |
731 | /// Return the number of successors. |
732 | unsigned getNumSuccessors() { |
733 | return this->getOperation()->getNumSuccessors(); |
734 | } |
735 | |
736 | /// Return the successor at `index`. |
737 | Block *getSuccessor(unsigned i) { |
738 | return this->getOperation()->getSuccessor(i); |
739 | } |
740 | |
741 | /// Set the successor at `index`. |
742 | void setSuccessor(Block *block, unsigned i) { |
743 | return this->getOperation()->setSuccessor(block, i); |
744 | } |
745 | |
746 | /// Successor iterator access. |
747 | succ_iterator succ_begin() { return this->getOperation()->succ_begin(); } |
748 | succ_iterator succ_end() { return this->getOperation()->succ_end(); } |
749 | succ_range getSuccessors() { return this->getOperation()->getSuccessors(); } |
750 | }; |
751 | } // namespace detail |
752 | |
753 | /// This class provides APIs for ops that are known to have a single successor. |
754 | template <typename ConcreteType> |
755 | class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> { |
756 | public: |
757 | Block *getSuccessor() { return this->getOperation()->getSuccessor(0); } |
758 | void setSuccessor(Block *succ) { |
759 | this->getOperation()->setSuccessor(succ, 0); |
760 | } |
761 | |
762 | static LogicalResult verifyTrait(Operation *op) { |
763 | return impl::verifyOneSuccessor(op); |
764 | } |
765 | }; |
766 | |
767 | /// This class provides the API for ops that are known to have a specified |
768 | /// number of successors. |
769 | template <unsigned N> |
770 | class NSuccessors { |
771 | public: |
772 | static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2"); |
773 | |
774 | template <typename ConcreteType> |
775 | class Impl : public detail::MultiSuccessorTraitBase<ConcreteType, |
776 | NSuccessors<N>::Impl> { |
777 | public: |
778 | static LogicalResult verifyTrait(Operation *op) { |
779 | return impl::verifyNSuccessors(op, N); |
780 | } |
781 | }; |
782 | }; |
783 | |
784 | /// This class provides APIs for ops that are known to have at least a specified |
785 | /// number of successors. |
786 | template <unsigned N> |
787 | class AtLeastNSuccessors { |
788 | public: |
789 | template <typename ConcreteType> |
790 | class Impl |
791 | : public detail::MultiSuccessorTraitBase<ConcreteType, |
792 | AtLeastNSuccessors<N>::Impl> { |
793 | public: |
794 | static LogicalResult verifyTrait(Operation *op) { |
795 | return impl::verifyAtLeastNSuccessors(op, N); |
796 | } |
797 | }; |
798 | }; |
799 | |
800 | /// This class provides the API for ops which have an unknown number of |
801 | /// successors. |
802 | template <typename ConcreteType> |
803 | class VariadicSuccessors |
804 | : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> { |
805 | }; |
806 | |
807 | //===----------------------------------------------------------------------===// |
808 | // SingleBlock |
809 | |
810 | /// This class provides APIs and verifiers for ops with regions having a single |
811 | /// block. |
812 | template <typename ConcreteType> |
813 | struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> { |
814 | public: |
815 | static LogicalResult verifyTrait(Operation *op) { |
816 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
817 | Region ®ion = op->getRegion(i); |
818 | |
819 | // Empty regions are fine. |
820 | if (region.empty()) |
821 | continue; |
822 | |
823 | // Non-empty regions must contain a single basic block. |
824 | if (!llvm::hasSingleElement(region)) |
825 | return op->emitOpError("expects region #") |
826 | << i << " to have 0 or 1 blocks"; |
827 | |
828 | if (!ConcreteType::template hasTrait<NoTerminator>()) { |
829 | Block &block = region.front(); |
830 | if (block.empty()) |
831 | return op->emitOpError() << "expects a non-empty block"; |
832 | } |
833 | } |
834 | return success(); |
835 | } |
836 | |
837 | Block *getBody(unsigned idx = 0) { |
838 | Region ®ion = this->getOperation()->getRegion(idx); |
839 | assert(!region.empty() && "unexpected empty region")(static_cast <bool> (!region.empty() && "unexpected empty region" ) ? void (0) : __assert_fail ("!region.empty() && \"unexpected empty region\"" , "mlir/include/mlir/IR/OpDefinition.h", 839, __extension__ __PRETTY_FUNCTION__ )); |
840 | return ®ion.front(); |
841 | } |
842 | Region &getBodyRegion(unsigned idx = 0) { |
843 | return this->getOperation()->getRegion(idx); |
844 | } |
845 | |
846 | //===------------------------------------------------------------------===// |
847 | // Single Region Utilities |
848 | //===------------------------------------------------------------------===// |
849 | |
850 | /// The following are a set of methods only enabled when the parent |
851 | /// operation has a single region. Each of these methods take an additional |
852 | /// template parameter that represents the concrete operation so that we |
853 | /// can use SFINAE to disable the methods for non-single region operations. |
854 | template <typename OpT, typename T = void> |
855 | using enable_if_single_region = |
856 | typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; |
857 | |
858 | template <typename OpT = ConcreteType> |
859 | enable_if_single_region<OpT, Block::iterator> begin() { |
860 | return getBody()->begin(); |
861 | } |
862 | template <typename OpT = ConcreteType> |
863 | enable_if_single_region<OpT, Block::iterator> end() { |
864 | return getBody()->end(); |
865 | } |
866 | template <typename OpT = ConcreteType> |
867 | enable_if_single_region<OpT, Operation &> front() { |
868 | return *begin(); |
869 | } |
870 | |
871 | /// Insert the operation into the back of the body. |
872 | template <typename OpT = ConcreteType> |
873 | enable_if_single_region<OpT> push_back(Operation *op) { |
874 | insert(Block::iterator(getBody()->end()), op); |
875 | } |
876 | |
877 | /// Insert the operation at the given insertion point. |
878 | template <typename OpT = ConcreteType> |
879 | enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { |
880 | insert(Block::iterator(insertPt), op); |
881 | } |
882 | template <typename OpT = ConcreteType> |
883 | enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) { |
884 | getBody()->getOperations().insert(insertPt, op); |
885 | } |
886 | }; |
887 | |
888 | //===----------------------------------------------------------------------===// |
889 | // SingleBlockImplicitTerminator |
890 | |
891 | /// This class provides APIs and verifiers for ops with regions having a single |
892 | /// block that must terminate with `TerminatorOpType`. |
893 | template <typename TerminatorOpType> |
894 | struct SingleBlockImplicitTerminator { |
895 | template <typename ConcreteType> |
896 | class Impl : public SingleBlock<ConcreteType> { |
897 | private: |
898 | using Base = SingleBlock<ConcreteType>; |
899 | /// Builds a terminator operation without relying on OpBuilder APIs to avoid |
900 | /// cyclic header inclusion. |
901 | static Operation *buildTerminator(OpBuilder &builder, Location loc) { |
902 | OperationState state(loc, TerminatorOpType::getOperationName()); |
903 | TerminatorOpType::build(builder, state); |
904 | return Operation::create(state); |
905 | } |
906 | |
907 | public: |
908 | /// The type of the operation used as the implicit terminator type. |
909 | using ImplicitTerminatorOpT = TerminatorOpType; |
910 | |
911 | static LogicalResult verifyTrait(Operation *op) { |
912 | if (failed(Base::verifyTrait(op))) |
913 | return failure(); |
914 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { |
915 | Region ®ion = op->getRegion(i); |
916 | // Empty regions are fine. |
917 | if (region.empty()) |
918 | continue; |
919 | Operation &terminator = region.front().back(); |
920 | if (isa<TerminatorOpType>(terminator)) |
921 | continue; |
922 | |
923 | return op->emitOpError("expects regions to end with '" + |
924 | TerminatorOpType::getOperationName() + |
925 | "', found '" + |
926 | terminator.getName().getStringRef() + "'") |
927 | .attachNote() |
928 | << "in custom textual format, the absence of terminator implies " |
929 | "'" |
930 | << TerminatorOpType::getOperationName() << '\''; |
931 | } |
932 | |
933 | return success(); |
934 | } |
935 | |
936 | /// Ensure that the given region has the terminator required by this trait. |
937 | /// If OpBuilder is provided, use it to build the terminator and notify the |
938 | /// OpBuilder listeners accordingly. If only a Builder is provided, locally |
939 | /// construct an OpBuilder with no listeners; this should only be used if no |
940 | /// OpBuilder is available at the call site, e.g., in the parser. |
941 | static void ensureTerminator(Region ®ion, Builder &builder, |
942 | Location loc) { |
943 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
944 | buildTerminator); |
945 | } |
946 | static void ensureTerminator(Region ®ion, OpBuilder &builder, |
947 | Location loc) { |
948 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, |
949 | buildTerminator); |
950 | } |
951 | |
952 | //===------------------------------------------------------------------===// |
953 | // Single Region Utilities |
954 | //===------------------------------------------------------------------===// |
955 | using Base::getBody; |
956 | |
957 | template <typename OpT, typename T = void> |
958 | using enable_if_single_region = |
959 | typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; |
960 | |
961 | /// Insert the operation into the back of the body, before the terminator. |
962 | template <typename OpT = ConcreteType> |
963 | enable_if_single_region<OpT> push_back(Operation *op) { |
964 | insert(Block::iterator(getBody()->getTerminator()), op); |
965 | } |
966 | |
967 | /// Insert the operation at the given insertion point. Note: The operation |
968 | /// is never inserted after the terminator, even if the insertion point is |
969 | /// end(). |
970 | template <typename OpT = ConcreteType> |
971 | enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { |
972 | insert(Block::iterator(insertPt), op); |
973 | } |
974 | template <typename OpT = ConcreteType> |
975 | enable_if_single_region<OpT> insert(Block::iterator insertPt, |
976 | Operation *op) { |
977 | auto *body = getBody(); |
978 | if (insertPt == body->end()) |
979 | insertPt = Block::iterator(body->getTerminator()); |
980 | body->getOperations().insert(insertPt, op); |
981 | } |
982 | }; |
983 | }; |
984 | |
/// Detection alias: well-formed only if op type `T` declares an
/// `ImplicitTerminatorOpT` member type. Intended for use with
/// `llvm::is_detected`.
template <typename T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
989 | |
990 | /// Support to check if an operation has the SingleBlockImplicitTerminator |
991 | /// trait. We can't just use `hasTrait` because this class is templated on a |
992 | /// specific terminator op. |
993 | template <class Op, bool hasTerminator = |
994 | llvm::is_detected<has_implicit_terminator_t, Op>::value> |
995 | struct hasSingleBlockImplicitTerminator { |
996 | static constexpr bool value = std::is_base_of< |
997 | typename OpTrait::SingleBlockImplicitTerminator< |
998 | typename Op::ImplicitTerminatorOpT>::template Impl<Op>, |
999 | Op>::value; |
1000 | }; |
1001 | template <class Op> |
1002 | struct hasSingleBlockImplicitTerminator<Op, false> { |
1003 | static constexpr bool value = false; |
1004 | }; |
1005 | |
1006 | //===----------------------------------------------------------------------===// |
1007 | // Misc Traits |
1008 | |
1009 | /// This class provides verification for ops that are known to have the same |
1010 | /// operand shape: all operands are scalars, vectors/tensors of the same |
1011 | /// shape. |
1012 | template <typename ConcreteType> |
1013 | class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> { |
1014 | public: |
1015 | static LogicalResult verifyTrait(Operation *op) { |
1016 | return impl::verifySameOperandsShape(op); |
1017 | } |
1018 | }; |
1019 | |
1020 | /// This class provides verification for ops that are known to have the same |
1021 | /// operand and result shape: both are scalars, vectors/tensors of the same |
1022 | /// shape. |
1023 | template <typename ConcreteType> |
1024 | class SameOperandsAndResultShape |
1025 | : public TraitBase<ConcreteType, SameOperandsAndResultShape> { |
1026 | public: |
1027 | static LogicalResult verifyTrait(Operation *op) { |
1028 | return impl::verifySameOperandsAndResultShape(op); |
1029 | } |
1030 | }; |
1031 | |
1032 | /// This class provides verification for ops that are known to have the same |
1033 | /// operand element type (or the type itself if it is scalar). |
1034 | /// |
1035 | template <typename ConcreteType> |
1036 | class SameOperandsElementType |
1037 | : public TraitBase<ConcreteType, SameOperandsElementType> { |
1038 | public: |
1039 | static LogicalResult verifyTrait(Operation *op) { |
1040 | return impl::verifySameOperandsElementType(op); |
1041 | } |
1042 | }; |
1043 | |
1044 | /// This class provides verification for ops that are known to have the same |
1045 | /// operand and result element type (or the type itself if it is scalar). |
1046 | /// |
1047 | template <typename ConcreteType> |
1048 | class SameOperandsAndResultElementType |
1049 | : public TraitBase<ConcreteType, SameOperandsAndResultElementType> { |
1050 | public: |
1051 | static LogicalResult verifyTrait(Operation *op) { |
1052 | return impl::verifySameOperandsAndResultElementType(op); |
1053 | } |
1054 | }; |
1055 | |
1056 | /// This class provides verification for ops that are known to have the same |
1057 | /// operand and result type. |
1058 | /// |
1059 | /// Note: this trait subsumes the SameOperandsAndResultShape and |
1060 | /// SameOperandsAndResultElementType traits. |
1061 | template <typename ConcreteType> |
1062 | class SameOperandsAndResultType |
1063 | : public TraitBase<ConcreteType, SameOperandsAndResultType> { |
1064 | public: |
1065 | static LogicalResult verifyTrait(Operation *op) { |
1066 | return impl::verifySameOperandsAndResultType(op); |
1067 | } |
1068 | }; |
1069 | |
1070 | /// This class verifies that any results of the specified op have a boolean |
1071 | /// type, a vector thereof, or a tensor thereof. |
1072 | template <typename ConcreteType> |
1073 | class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> { |
1074 | public: |
1075 | static LogicalResult verifyTrait(Operation *op) { |
1076 | return impl::verifyResultsAreBoolLike(op); |
1077 | } |
1078 | }; |
1079 | |
1080 | /// This class verifies that any results of the specified op have a floating |
1081 | /// point type, a vector thereof, or a tensor thereof. |
1082 | template <typename ConcreteType> |
1083 | class ResultsAreFloatLike |
1084 | : public TraitBase<ConcreteType, ResultsAreFloatLike> { |
1085 | public: |
1086 | static LogicalResult verifyTrait(Operation *op) { |
1087 | return impl::verifyResultsAreFloatLike(op); |
1088 | } |
1089 | }; |
1090 | |
1091 | /// This class verifies that any results of the specified op have a signless |
1092 | /// integer or index type, a vector thereof, or a tensor thereof. |
1093 | template <typename ConcreteType> |
1094 | class ResultsAreSignlessIntegerLike |
1095 | : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> { |
1096 | public: |
1097 | static LogicalResult verifyTrait(Operation *op) { |
1098 | return impl::verifyResultsAreSignlessIntegerLike(op); |
1099 | } |
1100 | }; |
1101 | |
1102 | /// This class adds property that the operation is commutative. |
1103 | template <typename ConcreteType> |
1104 | class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {}; |
1105 | |
1106 | /// This class adds property that the operation is an involution. |
1107 | /// This means a unary to unary operation "f" that satisfies f(f(x)) = x |
1108 | template <typename ConcreteType> |
1109 | class IsInvolution : public TraitBase<ConcreteType, IsInvolution> { |
1110 | public: |
1111 | static LogicalResult verifyTrait(Operation *op) { |
1112 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1113 | "expected operation to produce one result"); |
1114 | static_assert(ConcreteType::template hasTrait<OneOperand>(), |
1115 | "expected operation to take one operand"); |
1116 | static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
1117 | "expected operation to preserve type"); |
1118 | // Involution requires the operation to be side effect free as well |
1119 | // but currently this check is under a FIXME and is not actually done. |
1120 | return impl::verifyIsInvolution(op); |
1121 | } |
1122 | |
1123 | static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
1124 | return impl::foldInvolution(op); |
1125 | } |
1126 | }; |
1127 | |
1128 | /// This class adds property that the operation is idempotent. |
1129 | /// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x), |
1130 | /// or a binary operation "g" that satisfies g(x, x) = x. |
1131 | template <typename ConcreteType> |
1132 | class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> { |
1133 | public: |
1134 | static LogicalResult verifyTrait(Operation *op) { |
1135 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1136 | "expected operation to produce one result"); |
1137 | static_assert(ConcreteType::template hasTrait<OneOperand>() || |
1138 | ConcreteType::template hasTrait<NOperands<2>::Impl>(), |
1139 | "expected operation to take one or two operands"); |
1140 | static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(), |
1141 | "expected operation to preserve type"); |
1142 | // Idempotent requires the operation to be side effect free as well |
1143 | // but currently this check is under a FIXME and is not actually done. |
1144 | return impl::verifyIsIdempotent(op); |
1145 | } |
1146 | |
1147 | static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) { |
1148 | return impl::foldIdempotent(op); |
1149 | } |
1150 | }; |
1151 | |
1152 | /// This class verifies that all operands of the specified op have a float type, |
1153 | /// a vector thereof, or a tensor thereof. |
1154 | template <typename ConcreteType> |
1155 | class OperandsAreFloatLike |
1156 | : public TraitBase<ConcreteType, OperandsAreFloatLike> { |
1157 | public: |
1158 | static LogicalResult verifyTrait(Operation *op) { |
1159 | return impl::verifyOperandsAreFloatLike(op); |
1160 | } |
1161 | }; |
1162 | |
1163 | /// This class verifies that all operands of the specified op have a signless |
1164 | /// integer or index type, a vector thereof, or a tensor thereof. |
1165 | template <typename ConcreteType> |
1166 | class OperandsAreSignlessIntegerLike |
1167 | : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> { |
1168 | public: |
1169 | static LogicalResult verifyTrait(Operation *op) { |
1170 | return impl::verifyOperandsAreSignlessIntegerLike(op); |
1171 | } |
1172 | }; |
1173 | |
1174 | /// This class verifies that all operands of the specified op have the same |
1175 | /// type. |
1176 | template <typename ConcreteType> |
1177 | class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> { |
1178 | public: |
1179 | static LogicalResult verifyTrait(Operation *op) { |
1180 | return impl::verifySameTypeOperands(op); |
1181 | } |
1182 | }; |
1183 | |
1184 | /// This class provides the API for a sub-set of ops that are known to be |
1185 | /// constant-like. These are non-side effecting operations with one result and |
1186 | /// zero operands that can always be folded to a specific attribute value. |
1187 | template <typename ConcreteType> |
1188 | class ConstantLike : public TraitBase<ConcreteType, ConstantLike> { |
1189 | public: |
1190 | static LogicalResult verifyTrait(Operation *op) { |
1191 | static_assert(ConcreteType::template hasTrait<OneResult>(), |
1192 | "expected operation to produce one result"); |
1193 | static_assert(ConcreteType::template hasTrait<ZeroOperands>(), |
1194 | "expected operation to take zero operands"); |
1195 | // TODO: We should verify that the operation can always be folded, but this |
1196 | // requires that the attributes of the op already be verified. We should add |
1197 | // support for verifying traits "after" the operation to enable this use |
1198 | // case. |
1199 | return success(); |
1200 | } |
1201 | }; |
1202 | |
1203 | /// This class provides the API for ops that are known to be isolated from |
1204 | /// above. |
1205 | template <typename ConcreteType> |
1206 | class IsIsolatedFromAbove |
1207 | : public TraitBase<ConcreteType, IsIsolatedFromAbove> { |
1208 | public: |
1209 | static LogicalResult verifyTrait(Operation *op) { |
1210 | return impl::verifyIsIsolatedFromAbove(op); |
1211 | } |
1212 | }; |
1213 | |
1214 | /// A trait of region holding operations that defines a new scope for polyhedral |
1215 | /// optimization purposes. Any SSA values of 'index' type that either dominate |
1216 | /// such an operation or are used at the top-level of such an operation |
1217 | /// automatically become valid symbols for the polyhedral scope defined by that |
1218 | /// operation. For more details, see `Traits.md#AffineScope`. |
1219 | template <typename ConcreteType> |
1220 | class AffineScope : public TraitBase<ConcreteType, AffineScope> { |
1221 | public: |
1222 | static LogicalResult verifyTrait(Operation *op) { |
1223 | static_assert(!ConcreteType::template hasTrait<ZeroRegion>(), |
1224 | "expected operation to have one or more regions"); |
1225 | return success(); |
1226 | } |
1227 | }; |
1228 | |
1229 | /// A trait of region holding operations that define a new scope for automatic |
1230 | /// allocations, i.e., allocations that are freed when control is transferred |
1231 | /// back from the operation's region. Any operations performing such allocations |
1232 | /// (for eg. memref.alloca) will have their allocations automatically freed at |
1233 | /// their closest enclosing operation with this trait. |
1234 | template <typename ConcreteType> |
1235 | class AutomaticAllocationScope |
1236 | : public TraitBase<ConcreteType, AutomaticAllocationScope> { |
1237 | public: |
1238 | static LogicalResult verifyTrait(Operation *op) { |
1239 | static_assert(!ConcreteType::template hasTrait<ZeroRegion>(), |
1240 | "expected operation to have one or more regions"); |
1241 | return success(); |
1242 | } |
1243 | }; |
1244 | |
1245 | /// This class provides a verifier for ops that are expecting their parent |
1246 | /// to be one of the given parent ops |
1247 | template <typename... ParentOpTypes> |
1248 | struct HasParent { |
1249 | template <typename ConcreteType> |
1250 | class Impl : public TraitBase<ConcreteType, Impl> { |
1251 | public: |
1252 | static LogicalResult verifyTrait(Operation *op) { |
1253 | if (llvm::isa<ParentOpTypes...>(op->getParentOp())) |
1254 | return success(); |
1255 | |
1256 | return op->emitOpError() |
1257 | << "expects parent op " |
1258 | << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'") |
1259 | << llvm::makeArrayRef({ParentOpTypes::getOperationName()...}) |
1260 | << "'"; |
1261 | } |
1262 | }; |
1263 | }; |
1264 | |
1265 | /// A trait for operations that have an attribute specifying operand segments. |
1266 | /// |
1267 | /// Certain operations can have multiple variadic operands and their size |
1268 | /// relationship is not always known statically. For such cases, we need |
1269 | /// a per-op-instance specification to divide the operands into logical groups |
1270 | /// or segments. This can be modeled by attributes. The attribute will be named |
1271 | /// as `operand_segment_sizes`. |
1272 | /// |
1273 | /// This trait verifies the attribute for specifying operand segments has |
1274 | /// the correct type (1D vector) and values (non-negative), etc. |
1275 | template <typename ConcreteType> |
1276 | class AttrSizedOperandSegments |
1277 | : public TraitBase<ConcreteType, AttrSizedOperandSegments> { |
1278 | public: |
1279 | static StringRef getOperandSegmentSizeAttr() { |
1280 | return "operand_segment_sizes"; |
1281 | } |
1282 | |
1283 | static LogicalResult verifyTrait(Operation *op) { |
1284 | return ::mlir::OpTrait::impl::verifyOperandSizeAttr( |
1285 | op, getOperandSegmentSizeAttr()); |
1286 | } |
1287 | }; |
1288 | |
1289 | /// Similar to AttrSizedOperandSegments but used for results. |
1290 | template <typename ConcreteType> |
1291 | class AttrSizedResultSegments |
1292 | : public TraitBase<ConcreteType, AttrSizedResultSegments> { |
1293 | public: |
1294 | static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; } |
1295 | |
1296 | static LogicalResult verifyTrait(Operation *op) { |
1297 | return ::mlir::OpTrait::impl::verifyResultSizeAttr( |
1298 | op, getResultSegmentSizeAttr()); |
1299 | } |
1300 | }; |
1301 | |
1302 | /// This trait provides a verifier for ops that are expecting their regions to |
1303 | /// not have any arguments |
1304 | template <typename ConcrentType> |
1305 | struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> { |
1306 | static LogicalResult verifyTrait(Operation *op) { |
1307 | return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); |
1308 | } |
1309 | }; |
1310 | |
1311 | // This trait is used to flag operations that consume or produce |
1312 | // values of `MemRef` type where those references can be 'normalized'. |
1313 | // TODO: Right now, the operands of an operation are either all normalizable, |
1314 | // or not. In the future, we may want to allow some of the operands to be |
1315 | // normalizable. |
1316 | template <typename ConcrentType> |
1317 | struct MemRefsNormalizable |
1318 | : public TraitBase<ConcrentType, MemRefsNormalizable> {}; |
1319 | |
1320 | /// This trait tags element-wise ops on vectors or tensors. |
1321 | /// |
1322 | /// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this |
1323 | /// trait. In particular, broadcasting behavior is not allowed. |
1324 | /// |
1325 | /// An `Elementwise` op must satisfy the following properties: |
1326 | /// |
1327 | /// 1. If any result is a vector/tensor then at least one operand must also be a |
1328 | /// vector/tensor. |
1329 | /// 2. If any operand is a vector/tensor then there must be at least one result |
1330 | /// and all results must be vectors/tensors. |
1331 | /// 3. All operand and result vector/tensor types must be of the same shape. The |
1332 | /// shape may be dynamic in which case the op's behaviour is undefined for |
1333 | /// non-matching shapes. |
1334 | /// 4. The operation must be elementwise on its vector/tensor operands and |
1335 | /// results. When applied to single-element vectors/tensors, the result must |
1336 | /// be the same per elememnt. |
1337 | /// |
1338 | /// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new |
1339 | /// interface `ElementwiseTypeInterface` that describes the container types for |
1340 | /// which the operation is elementwise. |
1341 | /// |
1342 | /// Rationale: |
1343 | /// - 1. and 2. guarantee a well-defined iteration space and exclude the cases |
1344 | /// of 0 non-scalar operands or 0 non-scalar results, which complicate a |
1345 | /// generic definition of the iteration space. |
1346 | /// - 3. guarantees that folding can be done across scalars/vectors/tensors with |
1347 | /// the same pattern, as otherwise lots of special handling for type |
1348 | /// mismatches would be needed. |
1349 | /// - 4. guarantees that no error handling is needed. Higher-level dialects |
1350 | /// should reify any needed guards or error handling code before lowering to |
1351 | /// an `Elementwise` op. |
1352 | template <typename ConcreteType> |
1353 | struct Elementwise : public TraitBase<ConcreteType, Elementwise> { |
1354 | static LogicalResult verifyTrait(Operation *op) { |
1355 | return ::mlir::OpTrait::impl::verifyElementwise(op); |
1356 | } |
1357 | }; |
1358 | |
1359 | /// This trait tags `Elementwise` operatons that can be systematically |
1360 | /// scalarized. All vector/tensor operands and results are then replaced by |
1361 | /// scalars of the respective element type. Semantically, this is the operation |
1362 | /// on a single element of the vector/tensor. |
1363 | /// |
1364 | /// Rationale: |
1365 | /// Allow to define the vector/tensor semantics of elementwise operations based |
1366 | /// on the same op's behavior on scalars. This provides a constructive procedure |
1367 | /// for IR transformations to, e.g., create scalar loop bodies from tensor ops. |
1368 | /// |
1369 | /// Example: |
1370 | /// ``` |
1371 | /// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val) |
1372 | /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
1373 | /// -> tensor<?xf32> |
1374 | /// ``` |
1375 | /// can be scalarized to |
1376 | /// |
1377 | /// ``` |
1378 | /// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar) |
1379 | /// : (i1, f32, f32) -> f32 |
1380 | /// ``` |
1381 | template <typename ConcreteType> |
1382 | struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> { |
1383 | static LogicalResult verifyTrait(Operation *op) { |
1384 | static_assert( |
1385 | ConcreteType::template hasTrait<Elementwise>(), |
1386 | "`Scalarizable` trait is only applicable to `Elementwise` ops."); |
1387 | return success(); |
1388 | } |
1389 | }; |
1390 | |
1391 | /// This trait tags `Elementwise` operatons that can be systematically |
1392 | /// vectorized. All scalar operands and results are then replaced by vectors |
1393 | /// with the respective element type. Semantically, this is the operation on |
1394 | /// multiple elements simultaneously. See also `Tensorizable`. |
1395 | /// |
1396 | /// Rationale: |
1397 | /// Provide the reverse to `Scalarizable` which, when chained together, allows |
1398 | /// reasoning about the relationship between the tensor and vector case. |
1399 | /// Additionally, it permits reasoning about promoting scalars to vectors via |
1400 | /// broadcasting in cases like `%select_scalar_pred` below. |
1401 | template <typename ConcreteType> |
1402 | struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> { |
1403 | static LogicalResult verifyTrait(Operation *op) { |
1404 | static_assert( |
1405 | ConcreteType::template hasTrait<Elementwise>(), |
1406 | "`Vectorizable` trait is only applicable to `Elementwise` ops."); |
1407 | return success(); |
1408 | } |
1409 | }; |
1410 | |
1411 | /// This trait tags `Elementwise` operatons that can be systematically |
1412 | /// tensorized. All scalar operands and results are then replaced by tensors |
1413 | /// with the respective element type. Semantically, this is the operation on |
1414 | /// multiple elements simultaneously. See also `Vectorizable`. |
1415 | /// |
1416 | /// Rationale: |
1417 | /// Provide the reverse to `Scalarizable` which, when chained together, allows |
1418 | /// reasoning about the relationship between the tensor and vector case. |
1419 | /// Additionally, it permits reasoning about promoting scalars to tensors via |
1420 | /// broadcasting in cases like `%select_scalar_pred` below. |
1421 | /// |
1422 | /// Examples: |
1423 | /// ``` |
1424 | /// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32 |
1425 | /// ``` |
1426 | /// can be tensorized to |
1427 | /// ``` |
1428 | /// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>) |
1429 | /// -> tensor<?xf32> |
1430 | /// ``` |
1431 | /// |
1432 | /// ``` |
1433 | /// %scalar_pred = "arith.select"(%pred, %true_val, %false_val) |
1434 | /// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32> |
1435 | /// ``` |
1436 | /// can be tensorized to |
1437 | /// ``` |
1438 | /// %tensor_pred = "arith.select"(%pred, %true_val, %false_val) |
1439 | /// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>) |
1440 | /// -> tensor<?xf32> |
1441 | /// ``` |
1442 | template <typename ConcreteType> |
1443 | struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> { |
1444 | static LogicalResult verifyTrait(Operation *op) { |
1445 | static_assert( |
1446 | ConcreteType::template hasTrait<Elementwise>(), |
1447 | "`Tensorizable` trait is only applicable to `Elementwise` ops."); |
1448 | return success(); |
1449 | } |
1450 | }; |
1451 | |
1452 | /// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable` |
1453 | /// provide an easy way for scalar operations to conveniently generalize their |
1454 | /// behavior to vectors/tensors, and systematize conversion between these forms. |
1455 | bool hasElementwiseMappableTraits(Operation *op); |
1456 | |
1457 | } // namespace OpTrait |
1458 | |
1459 | //===----------------------------------------------------------------------===// |
1460 | // Internal Trait Utilities |
1461 | //===----------------------------------------------------------------------===// |
1462 | |
1463 | namespace op_definition_impl { |
1464 | //===----------------------------------------------------------------------===// |
1465 | // Trait Existence |
1466 | |
1467 | /// Returns true if this given Trait ID matches the IDs of any of the provided |
1468 | /// trait types `Traits`. |
1469 | template <template <typename T> class... Traits> |
1470 | static bool hasTrait(TypeID traitID) { |
1471 | TypeID traitIDs[] = {TypeID::get<Traits>()...}; |
1472 | for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) |
1473 | if (traitIDs[i] == traitID) |
1474 | return true; |
1475 | return false; |
1476 | } |
1477 | |
1478 | //===----------------------------------------------------------------------===// |
1479 | // Trait Folding |
1480 | |
1481 | /// Trait to check if T provides a 'foldTrait' method for single result |
1482 | /// operations. |
1483 | template <typename T, typename... Args> |
1484 | using has_single_result_fold_trait = decltype(T::foldTrait( |
1485 | std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>())); |
1486 | template <typename T> |
1487 | using detect_has_single_result_fold_trait = |
1488 | llvm::is_detected<has_single_result_fold_trait, T>; |
1489 | /// Trait to check if T provides a general 'foldTrait' method. |
1490 | template <typename T, typename... Args> |
1491 | using has_fold_trait = |
1492 | decltype(T::foldTrait(std::declval<Operation *>(), |
1493 | std::declval<ArrayRef<Attribute>>(), |
1494 | std::declval<SmallVectorImpl<OpFoldResult> &>())); |
1495 | template <typename T> |
1496 | using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>; |
1497 | /// Trait to check if T provides any `foldTrait` method. |
1498 | /// NOTE: This should use std::disjunction when C++17 is available. |
1499 | template <typename T> |
1500 | using detect_has_any_fold_trait = |
1501 | std::conditional_t<bool(detect_has_fold_trait<T>::value), |
1502 | detect_has_fold_trait<T>, |
1503 | detect_has_single_result_fold_trait<T>>; |
1504 | |
1505 | /// Returns the result of folding a trait that implements a `foldTrait` function |
1506 | /// that is specialized for operations that have a single result. |
1507 | template <typename Trait> |
1508 | static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value, |
1509 | LogicalResult> |
1510 | foldTrait(Operation *op, ArrayRef<Attribute> operands, |
1511 | SmallVectorImpl<OpFoldResult> &results) { |
1512 | assert(op->hasTrait<OpTrait::OneResult>() &&(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__ )) |
1513 | "expected trait on non single-result operation to implement the "(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__ )) |
1514 | "general `foldTrait` method")(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1514, __extension__ __PRETTY_FUNCTION__ )); |
1515 | // If a previous trait has already been folded and replaced this operation, we |
1516 | // fail to fold this trait. |
1517 | if (!results.empty()) |
1518 | return failure(); |
1519 | |
1520 | if (OpFoldResult result = Trait::foldTrait(op, operands)) { |
1521 | if (result.template dyn_cast<Value>() != op->getResult(0)) |
1522 | results.push_back(result); |
1523 | return success(); |
1524 | } |
1525 | return failure(); |
1526 | } |
1527 | /// Returns the result of folding a trait that implements a generalized |
1528 | /// `foldTrait` function that is supports any operation type. |
1529 | template <typename Trait> |
1530 | static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult> |
1531 | foldTrait(Operation *op, ArrayRef<Attribute> operands, |
1532 | SmallVectorImpl<OpFoldResult> &results) { |
1533 | // If a previous trait has already been folded and replaced this operation, we |
1534 | // fail to fold this trait. |
1535 | return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); |
1536 | } |
1537 | |
1538 | /// The internal implementation of `foldTraits` below that returns the result of |
1539 | /// folding a set of trait types `Ts` that implement a `foldTrait` method. |
1540 | template <typename... Ts> |
1541 | static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands, |
1542 | SmallVectorImpl<OpFoldResult> &results, |
1543 | std::tuple<Ts...> *) { |
1544 | bool anyFolded = false; |
1545 | (void)std::initializer_list<int>{ |
1546 | (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...}; |
1547 | return success(anyFolded); |
1548 | } |
1549 | |
1550 | /// Given a tuple type containing a set of traits that contain a `foldTrait` |
1551 | /// method, return the result of folding the given operation. |
1552 | template <typename TraitTupleT> |
1553 | static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult> |
1554 | foldTraits(Operation *op, ArrayRef<Attribute> operands, |
1555 | SmallVectorImpl<OpFoldResult> &results) { |
1556 | return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr); |
1557 | } |
1558 | /// A variant of the method above that is specialized when there are no traits |
1559 | /// that contain a `foldTrait` method. |
1560 | template <typename TraitTupleT> |
1561 | static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult> |
1562 | foldTraits(Operation *op, ArrayRef<Attribute> operands, |
1563 | SmallVectorImpl<OpFoldResult> &results) { |
1564 | return failure(); |
1565 | } |
1566 | |
1567 | //===----------------------------------------------------------------------===// |
1568 | // Trait Verification |
1569 | |
1570 | /// Trait to check if T provides a `verifyTrait` method. |
1571 | template <typename T, typename... Args> |
1572 | using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>())); |
1573 | template <typename T> |
1574 | using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>; |
1575 | |
1576 | /// The internal implementation of `verifyTraits` below that returns the result |
1577 | /// of verifying the current operation with all of the provided trait types |
1578 | /// `Ts`. |
1579 | template <typename... Ts> |
1580 | static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) { |
1581 | LogicalResult result = success(); |
1582 | (void)std::initializer_list<int>{ |
1583 | (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...}; |
1584 | return result; |
1585 | } |
1586 | |
1587 | /// Given a tuple type containing a set of traits that contain a |
1588 | /// `verifyTrait` method, return the result of verifying the given operation. |
1589 | template <typename TraitTupleT> |
1590 | static LogicalResult verifyTraits(Operation *op) { |
1591 | return verifyTraitsImpl(op, (TraitTupleT *)nullptr); |
1592 | } |
1593 | } // namespace op_definition_impl |
1594 | |
1595 | //===----------------------------------------------------------------------===// |
1596 | // Operation Definition classes |
1597 | //===----------------------------------------------------------------------===// |
1598 | |
1599 | /// This provides public APIs that all operations should have. The template |
1600 | /// argument 'ConcreteType' should be the concrete type by CRTP and the others |
1601 | /// are base classes by the policy pattern. |
1602 | template <typename ConcreteType, template <typename T> class... Traits> |
1603 | class Op : public OpState, public Traits<ConcreteType>... { |
1604 | public: |
1605 | /// Inherit getOperation from `OpState`. |
1606 | using OpState::getOperation; |
1607 | using OpState::verifyInvariants; |
1608 | |
1609 | /// Return if this operation contains the provided trait. |
1610 | template <template <typename T> class Trait> |
1611 | static constexpr bool hasTrait() { |
1612 | return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value; |
1613 | } |
1614 | |
1615 | /// Create a deep copy of this operation. |
1616 | ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); } |
1617 | |
1618 | /// Create a partial copy of this operation without traversing into attached |
1619 | /// regions. The new operation will have the same number of regions as the |
1620 | /// original one, but they will be left empty. |
1621 | ConcreteType cloneWithoutRegions() { |
1622 | return cast<ConcreteType>(getOperation()->cloneWithoutRegions()); |
1623 | } |
1624 | |
1625 | /// Return true if this "op class" can match against the specified operation. |
1626 | static bool classof(Operation *op) { |
1627 | if (auto info = op->getRegisteredInfo()) |
1628 | return TypeID::get<ConcreteType>() == info->getTypeID(); |
1629 | #ifndef NDEBUG |
1630 | if (op->getName().getStringRef() == ConcreteType::getOperationName()) |
1631 | llvm::report_fatal_error( |
1632 | "classof on '" + ConcreteType::getOperationName() + |
1633 | "' failed due to the operation not being registered"); |
1634 | #endif |
1635 | return false; |
1636 | } |
1637 | /// Provide `classof` support for other OpBase derived classes, such as |
1638 | /// Interfaces. |
1639 | template <typename T> |
1640 | static std::enable_if_t<std::is_base_of<OpState, T>::value, bool> |
1641 | classof(const T *op) { |
1642 | return classof(const_cast<T *>(op)->getOperation()); |
1643 | } |
1644 | |
1645 | /// Expose the type we are instantiated on to template machinery that may want |
1646 | /// to introspect traits on this operation. |
1647 | using ConcreteOpType = ConcreteType; |
1648 | |
1649 | /// This is a public constructor. Any op can be initialized to null. |
1650 | explicit Op() : OpState(nullptr) {} |
1651 | Op(std::nullptr_t) : OpState(nullptr) {} |
1652 | |
1653 | /// This is a public constructor to enable access via the llvm::cast family of |
1654 | /// methods. This should not be used directly. |
1655 | explicit Op(Operation *state) : OpState(state) {} |
1656 | |
1657 | /// Methods for supporting PointerLikeTypeTraits. |
1658 | const void *getAsOpaquePointer() const { |
1659 | return static_cast<const void *>((Operation *)*this); |
1660 | } |
1661 | static ConcreteOpType getFromOpaquePointer(const void *pointer) { |
1662 | return ConcreteOpType( |
1663 | reinterpret_cast<Operation *>(const_cast<void *>(pointer))); |
1664 | } |
1665 | |
1666 | /// Attach the given models as implementations of the corresponding interfaces |
1667 | /// for the concrete operation. |
1668 | template <typename... Models> |
1669 | static void attachInterface(MLIRContext &context) { |
1670 | Optional<RegisteredOperationName> info = RegisteredOperationName::lookup( |
1671 | ConcreteType::getOperationName(), &context); |
1672 | if (!info) |
1673 | llvm::report_fatal_error( |
1674 | "Attempting to attach an interface to an unregistered operation " + |
1675 | ConcreteType::getOperationName() + "."); |
1676 | info->attachInterface<Models...>(); |
1677 | } |
1678 | |
1679 | private: |
1680 | /// Trait to check if T provides a 'fold' method for a single result op. |
1681 | template <typename T, typename... Args> |
1682 | using has_single_result_fold = |
1683 | decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>())); |
1684 | template <typename T> |
1685 | using detect_has_single_result_fold = |
1686 | llvm::is_detected<has_single_result_fold, T>; |
1687 | /// Trait to check if T provides a general 'fold' method. |
1688 | template <typename T, typename... Args> |
1689 | using has_fold = decltype(std::declval<T>().fold( |
1690 | std::declval<ArrayRef<Attribute>>(), |
1691 | std::declval<SmallVectorImpl<OpFoldResult> &>())); |
1692 | template <typename T> |
1693 | using detect_has_fold = llvm::is_detected<has_fold, T>; |
1694 | /// Trait to check if T provides a 'print' method. |
1695 | template <typename T, typename... Args> |
1696 | using has_print = |
1697 | decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>())); |
1698 | template <typename T> |
1699 | using detect_has_print = llvm::is_detected<has_print, T>; |
1700 | /// A tuple type containing the traits that have a `foldTrait` function. |
1701 | using FoldableTraitsTupleT = typename detail::FilterTypes< |
1702 | op_definition_impl::detect_has_any_fold_trait, |
1703 | Traits<ConcreteType>...>::type; |
1704 | /// A tuple type containing the traits that have a verify function. |
1705 | using VerifiableTraitsTupleT = |
1706 | typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait, |
1707 | Traits<ConcreteType>...>::type; |
1708 | |
1709 | /// Returns an interface map containing the interfaces registered to this |
1710 | /// operation. |
1711 | static detail::InterfaceMap getInterfaceMap() { |
1712 | return detail::InterfaceMap::template get<Traits<ConcreteType>...>(); |
1713 | } |
1714 | |
1715 | /// Return the internal implementations of each of the OperationName |
1716 | /// hooks. |
1717 | /// Implementation of `FoldHookFn` OperationName hook. |
1718 | static OperationName::FoldHookFn getFoldHookFn() { |
1719 | return getFoldHookFnImpl<ConcreteType>(); |
1720 | } |
1721 | /// The internal implementation of `getFoldHookFn` above that is invoked if |
1722 | /// the operation is single result and defines a `fold` method. |
1723 | template <typename ConcreteOpT> |
1724 | static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>, |
1725 | Traits<ConcreteOpT>...>::value && |
1726 | detect_has_single_result_fold<ConcreteOpT>::value, |
1727 | OperationName::FoldHookFn> |
1728 | getFoldHookFnImpl() { |
1729 | return [](Operation *op, ArrayRef<Attribute> operands, |
1730 | SmallVectorImpl<OpFoldResult> &results) { |
1731 | return foldSingleResultHook<ConcreteOpT>(op, operands, results); |
1732 | }; |
1733 | } |
1734 | /// The internal implementation of `getFoldHookFn` above that is invoked if |
1735 | /// the operation is not single result and defines a `fold` method. |
1736 | template <typename ConcreteOpT> |
1737 | static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>, |
1738 | Traits<ConcreteOpT>...>::value && |
1739 | detect_has_fold<ConcreteOpT>::value, |
1740 | OperationName::FoldHookFn> |
1741 | getFoldHookFnImpl() { |
1742 | return [](Operation *op, ArrayRef<Attribute> operands, |
1743 | SmallVectorImpl<OpFoldResult> &results) { |
1744 | return foldHook<ConcreteOpT>(op, operands, results); |
1745 | }; |
1746 | } |
1747 | /// The internal implementation of `getFoldHookFn` above that is invoked if |
1748 | /// the operation does not define a `fold` method. |
1749 | template <typename ConcreteOpT> |
1750 | static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value && |
1751 | !detect_has_fold<ConcreteOpT>::value, |
1752 | OperationName::FoldHookFn> |
1753 | getFoldHookFnImpl() { |
1754 | return [](Operation *op, ArrayRef<Attribute> operands, |
1755 | SmallVectorImpl<OpFoldResult> &results) { |
1756 | // In this case, we only need to fold the traits of the operation. |
1757 | return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands, |
1758 | results); |
1759 | }; |
1760 | } |
1761 | /// Return the result of folding a single result operation that defines a |
1762 | /// `fold` method. |
1763 | template <typename ConcreteOpT> |
1764 | static LogicalResult |
1765 | foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands, |
1766 | SmallVectorImpl<OpFoldResult> &results) { |
1767 | OpFoldResult result = cast<ConcreteOpT>(op).fold(operands); |
1768 | |
1769 | // If the fold failed or was in-place, try to fold the traits of the |
1770 | // operation. |
1771 | if (!result || result.template dyn_cast<Value>() == op->getResult(0)) { |
1772 | if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>( |
1773 | op, operands, results))) |
1774 | return success(); |
1775 | return success(static_cast<bool>(result)); |
1776 | } |
1777 | results.push_back(result); |
1778 | return success(); |
1779 | } |
1780 | /// Return the result of folding an operation that defines a `fold` method. |
1781 | template <typename ConcreteOpT> |
1782 | static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands, |
1783 | SmallVectorImpl<OpFoldResult> &results) { |
1784 | LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results); |
1785 | |
1786 | // If the fold failed or was in-place, try to fold the traits of the |
1787 | // operation. |
1788 | if (failed(result) || results.empty()) { |
1789 | if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>( |
1790 | op, operands, results))) |
1791 | return success(); |
1792 | } |
1793 | return result; |
1794 | } |
1795 | |
1796 | /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook. |
1797 | static OperationName::GetCanonicalizationPatternsFn |
1798 | getGetCanonicalizationPatternsFn() { |
1799 | return &ConcreteType::getCanonicalizationPatterns; |
1800 | } |
1801 | /// Implementation of `GetHasTraitFn` |
1802 | static OperationName::HasTraitFn getHasTraitFn() { |
1803 | return |
1804 | [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); }; |
1805 | } |
1806 | /// Implementation of `ParseAssemblyFn` OperationName hook. |
1807 | static OperationName::ParseAssemblyFn getParseAssemblyFn() { |
1808 | return &ConcreteType::parse; |
1809 | } |
1810 | /// Implementation of `PrintAssemblyFn` OperationName hook. |
1811 | static OperationName::PrintAssemblyFn getPrintAssemblyFn() { |
1812 | return getPrintAssemblyFnImpl<ConcreteType>(); |
1813 | } |
1814 | /// The internal implementation of `getPrintAssemblyFn` that is invoked when |
1815 | /// the concrete operation does not define a `print` method. |
1816 | template <typename ConcreteOpT> |
1817 | static std::enable_if_t<!detect_has_print<ConcreteOpT>::value, |
1818 | OperationName::PrintAssemblyFn> |
1819 | getPrintAssemblyFnImpl() { |
1820 | return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) { |
1821 | return OpState::print(op, printer, defaultDialect); |
1822 | }; |
1823 | } |
1824 | /// The internal implementation of `getPrintAssemblyFn` that is invoked when |
1825 | /// the concrete operation defines a `print` method. |
1826 | template <typename ConcreteOpT> |
1827 | static std::enable_if_t<detect_has_print<ConcreteOpT>::value, |
1828 | OperationName::PrintAssemblyFn> |
1829 | getPrintAssemblyFnImpl() { |
1830 | return &printAssembly; |
1831 | } |
1832 | static void printAssembly(Operation *op, OpAsmPrinter &p, |
1833 | StringRef defaultDialect) { |
1834 | OpState::printOpName(op, p, defaultDialect); |
1835 | return cast<ConcreteType>(op).print(p); |
1836 | } |
1837 | /// Implementation of `VerifyInvariantsFn` OperationName hook. |
1838 | static LogicalResult verifyInvariants(Operation *op) { |
1839 | static_assert(hasNoDataMembers(), |
1840 | "Op class shouldn't define new data members"); |
1841 | return failure( |
1842 | failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) || |
1843 | failed(cast<ConcreteType>(op).verifyInvariants())); |
1844 | } |
1845 | static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() { |
1846 | return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants); |
1847 | } |
1848 | |
1849 | static constexpr bool hasNoDataMembers() { |
1850 | // Checking that the derived class does not define any member by comparing |
1851 | // its size to an ad-hoc EmptyOp. |
1852 | class EmptyOp : public Op<EmptyOp, Traits...> {}; |
1853 | return sizeof(ConcreteType) == sizeof(EmptyOp); |
1854 | } |
1855 | |
1856 | /// Allow access to internal implementation methods. |
1857 | friend RegisteredOperationName; |
1858 | }; |
1859 | |
1860 | /// This class represents the base of an operation interface. See the definition |
1861 | /// of `detail::Interface` for requirements on the `Traits` type. |
1862 | template <typename ConcreteType, typename Traits> |
1863 | class OpInterface |
1864 | : public detail::Interface<ConcreteType, Operation *, Traits, |
1865 | Op<ConcreteType>, OpTrait::TraitBase> { |
1866 | public: |
1867 | using Base = OpInterface<ConcreteType, Traits>; |
1868 | using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits, |
1869 | Op<ConcreteType>, OpTrait::TraitBase>; |
1870 | |
1871 | /// Inherit the base class constructor. |
1872 | using InterfaceBase::InterfaceBase; |
1873 | |
1874 | protected: |
1875 | /// Returns the impl interface instance for the given operation. |
1876 | static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) { |
1877 | OperationName name = op->getName(); |
1878 | |
1879 | // Access the raw interface from the operation info. |
1880 | if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) { |
1881 | if (auto *opIface = rInfo->getInterface<ConcreteType>()) |
1882 | return opIface; |
1883 | // Fallback to the dialect to provide it with a chance to implement this |
1884 | // interface for this operation. |
1885 | return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>( |
1886 | op->getName()); |
1887 | } |
1888 | // Fallback to the dialect to provide it with a chance to implement this |
1889 | // interface for this operation. |
1890 | if (Dialect *dialect = name.getDialect()) |
1891 | return dialect->getRegisteredInterfaceForOp<ConcreteType>(name); |
1892 | return nullptr; |
1893 | } |
1894 | |
1895 | /// Allow access to `getInterfaceFor`. |
1896 | friend InterfaceBase; |
1897 | }; |
1898 | |
1899 | //===----------------------------------------------------------------------===// |
1900 | // CastOpInterface utilities |
1901 | //===----------------------------------------------------------------------===// |
1902 | |
1903 | // These functions are out-of-line implementations of the methods in |
1904 | // CastOpInterface, which avoids them being template instantiated/duplicated. |
1905 | namespace impl { |
1906 | /// Attempt to fold the given cast operation. |
1907 | LogicalResult foldCastInterfaceOp(Operation *op, |
1908 | ArrayRef<Attribute> attrOperands, |
1909 | SmallVectorImpl<OpFoldResult> &foldResults); |
1910 | /// Attempt to verify the given cast operation. |
1911 | LogicalResult verifyCastInterfaceOp( |
1912 | Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible); |
1913 | } // namespace impl |
1914 | } // namespace mlir |
1915 | |
1916 | namespace llvm { |
1917 | |
1918 | template <typename T> |
1919 | struct DenseMapInfo< |
1920 | T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> { |
1921 | static inline T getEmptyKey() { |
1922 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
1923 | return T::getFromOpaquePointer(pointer); |
1924 | } |
1925 | static inline T getTombstoneKey() { |
1926 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
1927 | return T::getFromOpaquePointer(pointer); |
1928 | } |
1929 | static unsigned getHashValue(T val) { |
1930 | return hash_value(val.getAsOpaquePointer()); |
1931 | } |
1932 | static bool isEqual(T lhs, T rhs) { return lhs == rhs; } |
1933 | }; |
1934 | |
1935 | } // namespace llvm |
1936 | |
1937 | #endif |
1 | //===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_SUPPORT_LOGICALRESULT_H |
10 | #define MLIR_SUPPORT_LOGICALRESULT_H |
11 | |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "llvm/ADT/Optional.h" |
14 | |
15 | namespace mlir { |
16 | |
17 | /// This class represents an efficient way to signal success or failure. It |
18 | /// should be preferred over the use of `bool` when appropriate, as it avoids |
19 | /// all of the ambiguity that arises in interpreting a boolean result. This |
20 | /// class is marked as NODISCARD to ensure that the result is processed. Users |
21 | /// may explicitly discard a result by using `(void)`, e.g. |
22 | /// `(void)functionThatReturnsALogicalResult();`. Given the intended nature of |
23 | /// this class, it generally shouldn't be used as the result of functions that |
24 | /// very frequently have the result ignored. This class is intended to be used |
25 | /// in conjunction with the utility functions below. |
26 | struct LLVM_NODISCARD[[clang::warn_unused_result]] LogicalResult { |
27 | public: |
28 | /// If isSuccess is true a `success` result is generated, otherwise a |
29 | /// 'failure' result is generated. |
30 | static LogicalResult success(bool isSuccess = true) { |
31 | return LogicalResult(isSuccess); |
32 | } |
33 | |
34 | /// If isFailure is true a `failure` result is generated, otherwise a |
35 | /// 'success' result is generated. |
36 | static LogicalResult failure(bool isFailure = true) { |
37 | return success(!isFailure); |
38 | } |
39 | |
40 | /// Returns true if the provided LogicalResult corresponds to a success value. |
41 | bool succeeded() const { return isSuccess; } |
42 | |
43 | /// Returns true if the provided LogicalResult corresponds to a failure value. |
44 | bool failed() const { return !succeeded(); } |
45 | |
46 | private: |
47 | LogicalResult(bool isSuccess) : isSuccess(isSuccess) {} |
48 | |
49 | /// Boolean indicating if this is a success result, if false this is a |
50 | /// failure result. |
51 | bool isSuccess; |
52 | }; |
53 | |
54 | /// Utility function to generate a LogicalResult. If isSuccess is true a |
55 | /// `success` result is generated, otherwise a 'failure' result is generated. |
56 | inline LogicalResult success(bool isSuccess = true) { |
57 | return LogicalResult::success(isSuccess); |
58 | } |
59 | |
60 | /// Utility function to generate a LogicalResult. If isFailure is true a |
61 | /// `failure` result is generated, otherwise a 'success' result is generated. |
62 | inline LogicalResult failure(bool isFailure = true) { |
63 | return LogicalResult::failure(isFailure); |
64 | } |
65 | |
66 | /// Utility function that returns true if the provided LogicalResult corresponds |
67 | /// to a success value. |
68 | inline bool succeeded(LogicalResult result) { return result.succeeded(); } |
69 | |
70 | /// Utility function that returns true if the provided LogicalResult corresponds |
71 | /// to a failure value. |
72 | inline bool failed(LogicalResult result) { return result.failed(); } |
73 | |
74 | /// This class provides support for representing a failure result, or a valid |
75 | /// value of type `T`. This allows for integrating with LogicalResult, while |
76 | /// also providing a value on the success path. |
77 | template <typename T> class LLVM_NODISCARD[[clang::warn_unused_result]] FailureOr : public Optional<T> { |
78 | public: |
79 | /// Allow constructing from a LogicalResult. The result *must* be a failure. |
80 | /// Success results should use a proper instance of type `T`. |
81 | FailureOr(LogicalResult result) { |
82 | assert(failed(result) &&(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'" ) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\"" , "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__ __PRETTY_FUNCTION__)) |
83 | "success should be constructed with an instance of 'T'")(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'" ) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\"" , "mlir/include/mlir/Support/LogicalResult.h", 83, __extension__ __PRETTY_FUNCTION__)); |
84 | } |
85 | FailureOr() : FailureOr(failure()) {} |
86 | FailureOr(T &&y) : Optional<T>(std::forward<T>(y)) {} |
87 | FailureOr(const T &y) : Optional<T>(y) {} |
88 | template <typename U, |
89 | std::enable_if_t<std::is_constructible<T, U>::value> * = nullptr> |
90 | FailureOr(const FailureOr<U> &other) |
91 | : Optional<T>(failed(other) ? Optional<T>() : Optional<T>(*other)) {} |
92 | |
93 | operator LogicalResult() const { return success(this->hasValue()); } |
94 | |
95 | private: |
96 | /// Hide the bool conversion as it easily creates confusion. |
97 | using Optional<T>::operator bool; |
98 | using Optional<T>::hasValue; |
99 | }; |
100 | |
101 | } // namespace mlir |
102 | |
103 | #endif // MLIR_SUPPORT_LOGICALRESULT_H |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | |
24 | class Builder; |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // AsmPrinter |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | /// This base class exposes generic asm printer hooks, usable across the various |
31 | /// derived printers. |
32 | class AsmPrinter { |
33 | public: |
34 | /// This class contains the internal default implementation of the base |
35 | /// printer methods. |
36 | class Impl; |
37 | |
38 | /// Initialize the printer with the given internal implementation. |
39 | AsmPrinter(Impl &impl) : impl(&impl) {} |
40 | virtual ~AsmPrinter(); |
41 | |
42 | /// Return the raw output stream used by this printer. |
43 | virtual raw_ostream &getStream() const; |
44 | |
45 | /// Print the given floating point value in a stabilized form that can be |
46 | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
47 | /// hook on the AsmParser. |
48 | virtual void printFloat(const APFloat &value); |
49 | |
50 | virtual void printType(Type type); |
51 | virtual void printAttribute(Attribute attr); |
52 | |
53 | /// Trait to check if `AttrType` provides a `print` method. |
54 | template <typename AttrOrType> |
55 | using has_print_method = |
56 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
57 | template <typename AttrOrType> |
58 | using detect_has_print_method = |
59 | llvm::is_detected<has_print_method, AttrOrType>; |
60 | |
61 | /// Print the provided attribute in the context of an operation custom |
62 | /// printer/parser: this will invoke directly the print method on the |
63 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. |
64 | template <typename AttrOrType, |
65 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
66 | *sfinae = nullptr> |
67 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
68 | if (succeeded(printAlias(attrOrType))) |
69 | return; |
70 | attrOrType.print(*this); |
71 | } |
72 | |
73 | /// SFINAE for printing the provided attribute in the context of an operation |
74 | /// custom printer in the case where the attribute does not define a print |
75 | /// method. |
76 | template <typename AttrOrType, |
77 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
78 | *sfinae = nullptr> |
79 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
80 | *this << attrOrType; |
81 | } |
82 | |
83 | /// Print the given attribute without its type. The corresponding parser must |
84 | /// provide a valid type for the attribute. |
85 | virtual void printAttributeWithoutType(Attribute attr); |
86 | |
87 | /// Print the given string as a keyword, or a quoted and escaped string if it |
88 | /// has any special or non-printable characters in it. |
89 | virtual void printKeywordOrString(StringRef keyword); |
90 | |
91 | /// Print the given string as a symbol reference, i.e. a form representable by |
92 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed |
93 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any |
94 | /// special or non-printable characters in it. |
95 | virtual void printSymbolName(StringRef symbolRef); |
96 | |
97 | /// Print an optional arrow followed by a type list. |
98 | template <typename TypeRange> |
99 | void printOptionalArrowTypeList(TypeRange &&types) { |
100 | if (types.begin() != types.end()) |
101 | printArrowTypeList(types); |
102 | } |
103 | template <typename TypeRange> |
104 | void printArrowTypeList(TypeRange &&types) { |
105 | auto &os = getStream() << " -> "; |
106 | |
107 | bool wrapped = !llvm::hasSingleElement(types) || |
108 | (*types.begin()).template isa<FunctionType>(); |
109 | if (wrapped) |
110 | os << '('; |
111 | llvm::interleaveComma(types, *this); |
112 | if (wrapped) |
113 | os << ')'; |
114 | } |
115 | |
116 | /// Print the two given type ranges in a functional form. |
117 | template <typename InputRangeT, typename ResultRangeT> |
118 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
119 | auto &os = getStream(); |
120 | os << '('; |
121 | llvm::interleaveComma(inputs, *this); |
122 | os << ')'; |
123 | printArrowTypeList(results); |
124 | } |
125 | |
126 | protected: |
127 | /// Initialize the printer with no internal implementation. In this case, all |
128 | /// virtual methods of this class must be overriden. |
129 | AsmPrinter() {} |
130 | |
131 | private: |
132 | AsmPrinter(const AsmPrinter &) = delete; |
133 | void operator=(const AsmPrinter &) = delete; |
134 | |
135 | /// Print the alias for the given attribute, return failure if no alias could |
136 | /// be printed. |
137 | virtual LogicalResult printAlias(Attribute attr); |
138 | |
139 | /// Print the alias for the given type, return failure if no alias could |
140 | /// be printed. |
141 | virtual LogicalResult printAlias(Type type); |
142 | |
143 | /// The internal implementation of the printer. |
144 | Impl *impl{nullptr}; |
145 | }; |
146 | |
147 | template <typename AsmPrinterT> |
148 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
149 | AsmPrinterT &> |
150 | operator<<(AsmPrinterT &p, Type type) { |
151 | p.printType(type); |
152 | return p; |
153 | } |
154 | |
155 | template <typename AsmPrinterT> |
156 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
157 | AsmPrinterT &> |
158 | operator<<(AsmPrinterT &p, Attribute attr) { |
159 | p.printAttribute(attr); |
160 | return p; |
161 | } |
162 | |
163 | template <typename AsmPrinterT> |
164 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
165 | AsmPrinterT &> |
166 | operator<<(AsmPrinterT &p, const APFloat &value) { |
167 | p.printFloat(value); |
168 | return p; |
169 | } |
170 | template <typename AsmPrinterT> |
171 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
172 | AsmPrinterT &> |
173 | operator<<(AsmPrinterT &p, float value) { |
174 | return p << APFloat(value); |
175 | } |
176 | template <typename AsmPrinterT> |
177 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
178 | AsmPrinterT &> |
179 | operator<<(AsmPrinterT &p, double value) { |
180 | return p << APFloat(value); |
181 | } |
182 | |
183 | // Support printing anything that isn't convertible to one of the other |
184 | // streamable types, even if it isn't exactly one of them. For example, we want |
185 | // to print FunctionType with the Type version above, not have it match this. |
186 | template < |
187 | typename AsmPrinterT, typename T, |
188 | typename std::enable_if<!std::is_convertible<T &, Value &>::value && |
189 | !std::is_convertible<T &, Type &>::value && |
190 | !std::is_convertible<T &, Attribute &>::value && |
191 | !std::is_convertible<T &, ValueRange>::value && |
192 | !std::is_convertible<T &, APFloat &>::value && |
193 | !llvm::is_one_of<T, bool, float, double>::value, |
194 | T>::type * = nullptr> |
195 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
196 | AsmPrinterT &> |
197 | operator<<(AsmPrinterT &p, const T &other) { |
198 | p.getStream() << other; |
199 | return p; |
200 | } |
201 | |
202 | template <typename AsmPrinterT> |
203 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
204 | AsmPrinterT &> |
205 | operator<<(AsmPrinterT &p, bool value) { |
206 | return p << (value ? StringRef("true") : "false"); |
207 | } |
208 | |
209 | template <typename AsmPrinterT, typename ValueRangeT> |
210 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
211 | AsmPrinterT &> |
212 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
213 | llvm::interleaveComma(types, p); |
214 | return p; |
215 | } |
216 | template <typename AsmPrinterT> |
217 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
218 | AsmPrinterT &> |
219 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
220 | llvm::interleaveComma(types, p); |
221 | return p; |
222 | } |
223 | template <typename AsmPrinterT, typename ElementT> |
224 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
225 | AsmPrinterT &> |
226 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
227 | llvm::interleaveComma(types, p); |
228 | return p; |
229 | } |
230 | |
231 | //===----------------------------------------------------------------------===// |
232 | // OpAsmPrinter |
233 | //===----------------------------------------------------------------------===// |
234 | |
235 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
236 | /// necessary to implement a custom print() method. |
237 | class OpAsmPrinter : public AsmPrinter { |
238 | public: |
239 | using AsmPrinter::AsmPrinter; |
240 | ~OpAsmPrinter() override; |
241 | |
242 | /// Print a newline and indent the printer to the start of the current |
243 | /// operation. |
244 | virtual void printNewline() = 0; |
245 | |
246 | /// Print a block argument in the usual format of: |
247 | /// %ssaName : type {attr1=42} loc("here") |
248 | /// where location printing is controlled by the standard internal option. |
249 | /// You may pass omitType=true to not print a type, and pass an empty |
250 | /// attribute list if you don't care for attributes. |
251 | virtual void printRegionArgument(BlockArgument arg, |
252 | ArrayRef<NamedAttribute> argAttrs = {}, |
253 | bool omitType = false) = 0; |
254 | |
255 | /// Print implementations for various things an operation contains. |
256 | virtual void printOperand(Value value) = 0; |
257 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
258 | |
259 | /// Print a comma separated list of operands. |
260 | template <typename ContainerType> |
261 | void printOperands(const ContainerType &container) { |
262 | printOperands(container.begin(), container.end()); |
263 | } |
264 | |
265 | /// Print a comma separated list of operands. |
266 | template <typename IteratorType> |
267 | void printOperands(IteratorType it, IteratorType end) { |
268 | if (it == end) |
269 | return; |
270 | printOperand(*it); |
271 | for (++it; it != end; ++it) { |
272 | getStream() << ", "; |
273 | printOperand(*it); |
274 | } |
275 | } |
276 | |
277 | /// Print the given successor. |
278 | virtual void printSuccessor(Block *successor) = 0; |
279 | |
280 | /// Print the successor and its operands. |
281 | virtual void printSuccessorAndUseList(Block *successor, |
282 | ValueRange succOperands) = 0; |
283 | |
284 | /// If the specified operation has attributes, print out an attribute |
285 | /// dictionary with their values. elidedAttrs allows the client to ignore |
286 | /// specific well known attributes, commonly used if the attribute value is |
287 | /// printed some other way (like as a fixed operand). |
288 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
289 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
290 | |
291 | /// If the specified operation has attributes, print out an attribute |
292 | /// dictionary prefixed with 'attributes'. |
293 | virtual void |
294 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
295 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
296 | |
297 | /// Print the entire operation with the default generic assembly form. |
298 | /// If `printOpName` is true, then the operation name is printed (the default) |
299 | /// otherwise it is omitted and the print will start with the operand list. |
300 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
301 | |
302 | /// Prints a region. |
303 | /// If 'printEntryBlockArgs' is false, the arguments of the |
304 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
305 | /// operation of the block is not printed. If printEmptyBlock is true, then |
306 | /// the block header is printed even if the block is empty. |
307 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
308 | bool printBlockTerminators = true, |
309 | bool printEmptyBlock = false) = 0; |
310 | |
311 | /// Renumber the arguments for the specified region to the same names as the |
312 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
313 | /// operations. If any entry in namesToUse is null, the corresponding |
314 | /// argument name is left alone. |
315 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
316 | |
317 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
318 | /// of dims/symbols. |
319 | /// Operand values must come from single-result sources, and be valid |
320 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
321 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
322 | ValueRange operands) = 0; |
323 | |
324 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
325 | /// dims and symbols. |
326 | /// Operand values must come from single-result sources, and be valid |
327 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
328 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
329 | ValueRange symOperands) = 0; |
330 | |
331 | /// Print the complete type of an operation in functional form. |
332 | void printFunctionalType(Operation *op); |
333 | using AsmPrinter::printFunctionalType; |
334 | }; |
335 | |
336 | // Make the implementations convenient to use. |
337 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
338 | p.printOperand(value); |
339 | return p; |
340 | } |
341 | |
342 | template <typename T, |
343 | typename std::enable_if<std::is_convertible<T &, ValueRange>::value && |
344 | !std::is_convertible<T &, Value &>::value, |
345 | T>::type * = nullptr> |
346 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
347 | p.printOperands(values); |
348 | return p; |
349 | } |
350 | |
351 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
352 | p.printSuccessor(value); |
353 | return p; |
354 | } |
355 | |
356 | //===----------------------------------------------------------------------===// |
357 | // AsmParser |
358 | //===----------------------------------------------------------------------===// |
359 | |
360 | /// This base class exposes generic asm parser hooks, usable across the various |
361 | /// derived parsers. |
362 | class AsmParser { |
363 | public: |
364 | AsmParser() = default; |
365 | virtual ~AsmParser(); |
366 | |
367 | MLIRContext *getContext() const; |
368 | |
369 | /// Return the location of the original name token. |
370 | virtual SMLoc getNameLoc() const = 0; |
371 | |
372 | //===--------------------------------------------------------------------===// |
373 | // Utilities |
374 | //===--------------------------------------------------------------------===// |
375 | |
376 | /// Emit a diagnostic at the specified location and return failure. |
377 | virtual InFlightDiagnostic emitError(SMLoc loc, |
378 | const Twine &message = {}) = 0; |
379 | |
380 | /// Return a builder which provides useful access to MLIRContext, global |
381 | /// objects like types and attributes. |
382 | virtual Builder &getBuilder() const = 0; |
383 | |
384 | /// Get the location of the next token and store it into the argument. This |
385 | /// always succeeds. |
386 | virtual SMLoc getCurrentLocation() = 0; |
387 | ParseResult getCurrentLocation(SMLoc *loc) { |
388 | *loc = getCurrentLocation(); |
389 | return success(); |
390 | } |
391 | |
392 | /// Re-encode the given source location as an MLIR location and return it. |
393 | /// Note: This method should only be used when a `Location` is necessary, as |
394 | /// the encoding process is not efficient. |
395 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
396 | |
397 | //===--------------------------------------------------------------------===// |
398 | // Token Parsing |
399 | //===--------------------------------------------------------------------===// |
400 | |
401 | /// Parse a '->' token. |
402 | virtual ParseResult parseArrow() = 0; |
403 | |
404 | /// Parse a '->' token if present |
405 | virtual ParseResult parseOptionalArrow() = 0; |
406 | |
407 | /// Parse a `{` token. |
408 | virtual ParseResult parseLBrace() = 0; |
409 | |
410 | /// Parse a `{` token if present. |
411 | virtual ParseResult parseOptionalLBrace() = 0; |
412 | |
413 | /// Parse a `}` token. |
414 | virtual ParseResult parseRBrace() = 0; |
415 | |
416 | /// Parse a `}` token if present. |
417 | virtual ParseResult parseOptionalRBrace() = 0; |
418 | |
419 | /// Parse a `:` token. |
420 | virtual ParseResult parseColon() = 0; |
421 | |
422 | /// Parse a `:` token if present. |
423 | virtual ParseResult parseOptionalColon() = 0; |
424 | |
425 | /// Parse a `,` token. |
426 | virtual ParseResult parseComma() = 0; |
427 | |
428 | /// Parse a `,` token if present. |
429 | virtual ParseResult parseOptionalComma() = 0; |
430 | |
431 | /// Parse a `=` token. |
432 | virtual ParseResult parseEqual() = 0; |
433 | |
434 | /// Parse a `=` token if present. |
435 | virtual ParseResult parseOptionalEqual() = 0; |
436 | |
437 | /// Parse a '<' token. |
438 | virtual ParseResult parseLess() = 0; |
439 | |
440 | /// Parse a '<' token if present. |
441 | virtual ParseResult parseOptionalLess() = 0; |
442 | |
443 | /// Parse a '>' token. |
444 | virtual ParseResult parseGreater() = 0; |
445 | |
446 | /// Parse a '>' token if present. |
447 | virtual ParseResult parseOptionalGreater() = 0; |
448 | |
449 | /// Parse a '?' token. |
450 | virtual ParseResult parseQuestion() = 0; |
451 | |
452 | /// Parse a '?' token if present. |
453 | virtual ParseResult parseOptionalQuestion() = 0; |
454 | |
455 | /// Parse a '+' token. |
456 | virtual ParseResult parsePlus() = 0; |
457 | |
458 | /// Parse a '+' token if present. |
459 | virtual ParseResult parseOptionalPlus() = 0; |
460 | |
461 | /// Parse a '*' token. |
462 | virtual ParseResult parseStar() = 0; |
463 | |
464 | /// Parse a '*' token if present. |
465 | virtual ParseResult parseOptionalStar() = 0; |
466 | |
467 | /// Parse a quoted string token. |
468 | ParseResult parseString(std::string *string) { |
469 | auto loc = getCurrentLocation(); |
470 | if (parseOptionalString(string)) |
471 | return emitError(loc, "expected string"); |
472 | return success(); |
473 | } |
474 | |
475 | /// Parse a quoted string token if present. |
476 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
477 | |
478 | /// Parse a given keyword. |
479 | ParseResult parseKeyword(StringRef keyword, const Twine &msg = "") { |
480 | auto loc = getCurrentLocation(); |
481 | if (parseOptionalKeyword(keyword)) |
482 | return emitError(loc, "expected '") << keyword << "'" << msg; |
483 | return success(); |
484 | } |
485 | |
486 | /// Parse a keyword into 'keyword'. |
487 | ParseResult parseKeyword(StringRef *keyword) { |
488 | auto loc = getCurrentLocation(); |
489 | if (parseOptionalKeyword(keyword)) |
490 | return emitError(loc, "expected valid keyword"); |
491 | return success(); |
492 | } |
493 | |
494 | /// Parse the given keyword if present. |
495 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
496 | |
497 | /// Parse a keyword, if present, into 'keyword'. |
498 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
499 | |
500 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
501 | /// into 'keyword' |
502 | virtual ParseResult |
503 | parseOptionalKeyword(StringRef *keyword, |
504 | ArrayRef<StringRef> allowedValues) = 0; |
505 | |
506 | /// Parse a keyword or a quoted string. |
507 | ParseResult parseKeywordOrString(std::string *result) { |
508 | if (failed(parseOptionalKeywordOrString(result))) |
509 | return emitError(getCurrentLocation()) |
510 | << "expected valid keyword or string"; |
511 | return success(); |
512 | } |
513 | |
514 | /// Parse an optional keyword or string. |
515 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
516 | |
517 | /// Parse a `(` token. |
518 | virtual ParseResult parseLParen() = 0; |
519 | |
520 | /// Parse a `(` token if present. |
521 | virtual ParseResult parseOptionalLParen() = 0; |
522 | |
523 | /// Parse a `)` token. |
524 | virtual ParseResult parseRParen() = 0; |
525 | |
526 | /// Parse a `)` token if present. |
527 | virtual ParseResult parseOptionalRParen() = 0; |
528 | |
529 | /// Parse a `[` token. |
530 | virtual ParseResult parseLSquare() = 0; |
531 | |
532 | /// Parse a `[` token if present. |
533 | virtual ParseResult parseOptionalLSquare() = 0; |
534 | |
535 | /// Parse a `]` token. |
536 | virtual ParseResult parseRSquare() = 0; |
537 | |
538 | /// Parse a `]` token if present. |
539 | virtual ParseResult parseOptionalRSquare() = 0; |
540 | |
541 | /// Parse a `...` token if present; |
542 | virtual ParseResult parseOptionalEllipsis() = 0; |
543 | |
544 | /// Parse a floating point value from the stream. |
545 | virtual ParseResult parseFloat(double &result) = 0; |
546 | |
547 | /// Parse an integer value from the stream. |
548 | template <typename IntT> |
549 | ParseResult parseInteger(IntT &result) { |
550 | auto loc = getCurrentLocation(); |
551 | OptionalParseResult parseResult = parseOptionalInteger(result); |
552 | if (!parseResult.hasValue()) |
553 | return emitError(loc, "expected integer value"); |
554 | return *parseResult; |
555 | } |
556 | |
557 | /// Parse an optional integer value from the stream. |
558 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
559 | |
560 | template <typename IntT> |
561 | OptionalParseResult parseOptionalInteger(IntT &result) { |
562 | auto loc = getCurrentLocation(); |
563 | |
564 | // Parse the unsigned variant. |
565 | APInt uintResult; |
566 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
567 | if (!parseResult.hasValue() || failed(*parseResult)) |
568 | return parseResult; |
569 | |
570 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
571 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
572 | // zero for non-negated integers. |
573 | result = |
574 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); |
575 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
576 | return emitError(loc, "integer value too large"); |
577 | return success(); |
578 | } |
579 | |
580 | /// These are the supported delimiters around operand lists and region |
581 | /// argument lists, used by parseOperandList and parseRegionArgumentList. |
582 | enum class Delimiter { |
583 | /// Zero or more operands with no delimiters. |
584 | None, |
585 | /// Parens surrounding zero or more operands. |
586 | Paren, |
587 | /// Square brackets surrounding zero or more operands. |
588 | Square, |
589 | /// <> brackets surrounding zero or more operands. |
590 | LessGreater, |
591 | /// {} brackets surrounding zero or more operands. |
592 | Braces, |
593 | /// Parens supporting zero or more operands, or nothing. |
594 | OptionalParen, |
595 | /// Square brackets supporting zero or more ops, or nothing. |
596 | OptionalSquare, |
597 | /// <> brackets supporting zero or more ops, or nothing. |
598 | OptionalLessGreater, |
599 | /// {} brackets surrounding zero or more operands, or nothing. |
600 | OptionalBraces, |
601 | }; |
602 | |
603 | /// Parse a list of comma-separated items with an optional delimiter. If a |
604 | /// delimiter is provided, then an empty list is allowed. If not, then at |
605 | /// least one element will be parsed. |
606 | /// |
607 | /// contextMessage is an optional message appended to "expected '('" sorts of |
608 | /// diagnostics when parsing the delimeters. |
609 | virtual ParseResult |
610 | parseCommaSeparatedList(Delimiter delimiter, |
611 | function_ref<ParseResult()> parseElementFn, |
612 | StringRef contextMessage = StringRef()) = 0; |
613 | |
614 | /// Parse a comma separated list of elements that must have at least one entry |
615 | /// in it. |
616 | ParseResult |
617 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
618 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); |
619 | } |
620 | |
621 | //===--------------------------------------------------------------------===// |
622 | // Attribute/Type Parsing |
623 | //===--------------------------------------------------------------------===// |
624 | |
625 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
626 | /// the provided location to emit errors in the case of failure. Note that |
627 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
628 | /// context parameter. |
629 | template <typename T, typename... ParamsT> |
630 | T getChecked(SMLoc loc, ParamsT &&... params) { |
631 | return T::getChecked([&] { return emitError(loc); }, |
632 | std::forward<ParamsT>(params)...); |
633 | } |
634 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
635 | /// errors. |
636 | template <typename T, typename... ParamsT> |
637 | T getChecked(ParamsT &&... params) { |
638 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
639 | std::forward<ParamsT>(params)...); |
640 | } |
641 | |
642 | //===--------------------------------------------------------------------===// |
643 | // Attribute Parsing |
644 | //===--------------------------------------------------------------------===// |
645 | |
646 | /// Parse an arbitrary attribute of a given type and return it in result. |
647 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
648 | |
649 | /// Parse a custom attribute with the provided callback, unless the next |
650 | /// token is `#`, in which case the generic parser is invoked. |
651 | virtual ParseResult parseCustomAttributeWithFallback( |
652 | Attribute &result, Type type, |
653 | function_ref<ParseResult(Attribute &result, Type type)> |
654 | parseAttribute) = 0; |
655 | |
656 | /// Parse an attribute of a specific kind and type. |
657 | template <typename AttrType> |
658 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
659 | SMLoc loc = getCurrentLocation(); |
660 | |
661 | // Parse any kind of attribute. |
662 | Attribute attr; |
663 | if (parseAttribute(attr, type)) |
664 | return failure(); |
665 | |
666 | // Check for the right kind of attribute. |
667 | if (!(result = attr.dyn_cast<AttrType>())) |
668 | return emitError(loc, "invalid kind of attribute specified"); |
669 | |
670 | return success(); |
671 | } |
672 | |
673 | /// Parse an arbitrary attribute and return it in result. This also adds the |
674 | /// attribute to the specified attribute list with the specified name. |
675 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
676 | NamedAttrList &attrs) { |
677 | return parseAttribute(result, Type(), attrName, attrs); |
678 | } |
679 | |
680 | /// Parse an attribute of a specific kind and type. |
681 | template <typename AttrType> |
682 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
683 | NamedAttrList &attrs) { |
684 | return parseAttribute(result, Type(), attrName, attrs); |
685 | } |
686 | |
687 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
688 | /// This also adds the attribute to the specified attribute list with the |
689 | /// specified name. |
690 | template <typename AttrType> |
691 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
692 | NamedAttrList &attrs) { |
693 | SMLoc loc = getCurrentLocation(); |
694 | |
695 | // Parse any kind of attribute. |
696 | Attribute attr; |
697 | if (parseAttribute(attr, type)) |
698 | return failure(); |
699 | |
700 | // Check for the right kind of attribute. |
701 | result = attr.dyn_cast<AttrType>(); |
702 | if (!result) |
703 | return emitError(loc, "invalid kind of attribute specified"); |
704 | |
705 | attrs.append(attrName, result); |
706 | return success(); |
707 | } |
708 | |
709 | /// Trait to check if `AttrType` provides a `parse` method. |
710 | template <typename AttrType> |
711 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
712 | std::declval<Type>())); |
713 | template <typename AttrType> |
714 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
715 | |
716 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
717 | /// which case the generic parser is invoked. The parsed attribute is |
718 | /// populated in `result` and also added to the specified attribute list with |
719 | /// the specified name. |
720 | template <typename AttrType> |
721 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
722 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
723 | StringRef attrName, NamedAttrList &attrs) { |
724 | SMLoc loc = getCurrentLocation(); |
725 | |
726 | // Parse any kind of attribute. |
727 | Attribute attr; |
728 | if (parseCustomAttributeWithFallback( |
729 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
730 | result = AttrType::parse(*this, type); |
731 | if (!result) |
732 | return failure(); |
733 | return success(); |
734 | })) |
735 | return failure(); |
736 | |
737 | // Check for the right kind of attribute. |
738 | result = attr.dyn_cast<AttrType>(); |
739 | if (!result) |
740 | return emitError(loc, "invalid kind of attribute specified"); |
741 | |
742 | attrs.append(attrName, result); |
743 | return success(); |
744 | } |
745 | |
746 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
747 | template <typename AttrType> |
748 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
749 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
750 | StringRef attrName, NamedAttrList &attrs) { |
751 | return parseAttribute(result, type, attrName, attrs); |
752 | } |
753 | |
754 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
755 | /// which case the generic parser is invoked. The parsed attribute is |
756 | /// populated in `result`. |
757 | template <typename AttrType> |
758 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
759 | parseCustomAttributeWithFallback(AttrType &result) { |
760 | SMLoc loc = getCurrentLocation(); |
761 | |
762 | // Parse any kind of attribute. |
763 | Attribute attr; |
764 | if (parseCustomAttributeWithFallback( |
765 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
766 | result = AttrType::parse(*this, type); |
767 | return success(!!result); |
768 | })) |
769 | return failure(); |
770 | |
771 | // Check for the right kind of attribute. |
772 | result = attr.dyn_cast<AttrType>(); |
773 | if (!result) |
774 | return emitError(loc, "invalid kind of attribute specified"); |
775 | return success(); |
776 | } |
777 | |
778 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
779 | template <typename AttrType> |
780 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
781 | parseCustomAttributeWithFallback(AttrType &result) { |
782 | return parseAttribute(result); |
783 | } |
784 | |
785 | /// Parse an arbitrary optional attribute of a given type and return it in |
786 | /// result. |
787 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
788 | Type type = {}) = 0; |
789 | |
790 | /// Parse an optional array attribute and return it in result. |
791 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
792 | Type type = {}) = 0; |
793 | |
794 | /// Parse an optional string attribute and return it in result. |
795 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
796 | Type type = {}) = 0; |
797 | |
798 | /// Parse an optional attribute of a specific type and add it to the list with |
799 | /// the specified name. |
800 | template <typename AttrType> |
801 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
802 | StringRef attrName, |
803 | NamedAttrList &attrs) { |
804 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
805 | } |
806 | |
807 | /// Parse an optional attribute of a specific type and add it to the list with |
808 | /// the specified name. |
809 | template <typename AttrType> |
810 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
811 | StringRef attrName, |
812 | NamedAttrList &attrs) { |
813 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
814 | if (parseResult.hasValue() && succeeded(*parseResult)) |
815 | attrs.append(attrName, result); |
816 | return parseResult; |
817 | } |
818 | |
819 | /// Parse a named dictionary into 'result' if it is present. |
820 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
821 | |
822 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
823 | /// present. |
824 | virtual ParseResult |
825 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
826 | |
827 | /// Parse an affine map instance into 'map'. |
828 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
829 | |
830 | /// Parse an integer set instance into 'set'. |
831 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
832 | |
833 | //===--------------------------------------------------------------------===// |
834 | // Identifier Parsing |
835 | //===--------------------------------------------------------------------===// |
836 | |
837 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
838 | /// attribute named 'attrName'. |
839 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
840 | NamedAttrList &attrs) { |
841 | if (failed(parseOptionalSymbolName(result, attrName, attrs))) |
842 | return emitError(getCurrentLocation()) |
843 | << "expected valid '@'-identifier for symbol name"; |
844 | return success(); |
845 | } |
846 | |
847 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
848 | /// string attribute named 'attrName'. |
849 | virtual ParseResult parseOptionalSymbolName(StringAttr &result, |
850 | StringRef attrName, |
851 | NamedAttrList &attrs) = 0; |
852 | |
853 | //===--------------------------------------------------------------------===// |
854 | // Type Parsing |
855 | //===--------------------------------------------------------------------===// |
856 | |
857 | /// Parse a type. |
858 | virtual ParseResult parseType(Type &result) = 0; |
859 | |
860 | /// Parse a custom type with the provided callback, unless the next |
861 | /// token is `#`, in which case the generic parser is invoked. |
862 | virtual ParseResult parseCustomTypeWithFallback( |
863 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
864 | |
865 | /// Parse an optional type. |
866 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
867 | |
868 | /// Parse a type of a specific type. |
869 | template <typename TypeT> |
870 | ParseResult parseType(TypeT &result) { |
871 | SMLoc loc = getCurrentLocation(); |
872 | |
873 | // Parse any kind of type. |
874 | Type type; |
875 | if (parseType(type)) |
876 | return failure(); |
877 | |
878 | // Check for the right kind of type. |
879 | result = type.dyn_cast<TypeT>(); |
880 | if (!result) |
881 | return emitError(loc, "invalid kind of type specified"); |
882 | |
883 | return success(); |
884 | } |
885 | |
886 | /// Trait to check if `TypeT` provides a `parse` method. |
887 | template <typename TypeT> |
888 | using type_has_parse_method = |
889 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
890 | template <typename TypeT> |
891 | using detect_type_has_parse_method = |
892 | llvm::is_detected<type_has_parse_method, TypeT>; |
893 | |
894 | /// Parse a custom Type of a given type unless the next token is `#`, in |
895 | /// which case the generic parser is invoked. The parsed Type is |
896 | /// populated in `result`. |
897 | template <typename TypeT> |
898 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
899 | parseCustomTypeWithFallback(TypeT &result) { |
900 | SMLoc loc = getCurrentLocation(); |
901 | |
902 | // Parse any kind of Type. |
903 | Type type; |
904 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
905 | result = TypeT::parse(*this); |
906 | return success(!!result); |
907 | })) |
908 | return failure(); |
909 | |
910 | // Check for the right kind of Type. |
911 | result = type.dyn_cast<TypeT>(); |
912 | if (!result) |
913 | return emitError(loc, "invalid kind of Type specified"); |
914 | return success(); |
915 | } |
916 | |
917 | /// SFINAE parsing method for Type that don't implement a parse method. |
918 | template <typename TypeT> |
919 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
920 | parseCustomTypeWithFallback(TypeT &result) { |
921 | return parseType(result); |
922 | } |
923 | |
924 | /// Parse a type list. |
925 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
926 | do { |
927 | Type type; |
928 | if (parseType(type)) |
929 | return failure(); |
930 | result.push_back(type); |
931 | } while (succeeded(parseOptionalComma())); |
932 | return success(); |
933 | } |
934 | |
935 | /// Parse an arrow followed by a type list. |
936 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
937 | |
938 | /// Parse an optional arrow followed by a type list. |
939 | virtual ParseResult |
940 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
941 | |
942 | /// Parse a colon followed by a type. |
943 | virtual ParseResult parseColonType(Type &result) = 0; |
944 | |
945 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
946 | template <typename TypeType> |
947 | ParseResult parseColonType(TypeType &result) { |
948 | SMLoc loc = getCurrentLocation(); |
949 | |
950 | // Parse any kind of type. |
951 | Type type; |
952 | if (parseColonType(type)) |
953 | return failure(); |
954 | |
955 | // Check for the right kind of type. |
956 | result = type.dyn_cast<TypeType>(); |
957 | if (!result) |
958 | return emitError(loc, "invalid kind of type specified"); |
959 | |
960 | return success(); |
961 | } |
962 | |
963 | /// Parse a colon followed by a type list, which must have at least one type. |
964 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
965 | |
966 | /// Parse an optional colon followed by a type list, which if present must |
967 | /// have at least one type. |
968 | virtual ParseResult |
969 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
970 | |
971 | /// Parse a keyword followed by a type. |
972 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
973 | return failure(parseKeyword(keyword) || parseType(result)); |
974 | } |
975 | |
976 | /// Add the specified type to the end of the specified type list and return |
977 | /// success. This is a helper designed to allow parse methods to be simple |
978 | /// and chain through || operators. |
979 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
980 | result.push_back(type); |
981 | return success(); |
982 | } |
983 | |
984 | /// Add the specified types to the end of the specified type list and return |
985 | /// success. This is a helper designed to allow parse methods to be simple |
986 | /// and chain through || operators. |
987 | ParseResult addTypesToList(ArrayRef<Type> types, |
988 | SmallVectorImpl<Type> &result) { |
989 | result.append(types.begin(), types.end()); |
990 | return success(); |
991 | } |
992 | |
993 | /// Parse a 'x' separated dimension list. This populates the dimension list, |
994 | /// using -1 for the `?` dimensions if `allowDynamic` is set and errors out on |
995 | /// `?` otherwise. |
996 | /// |
997 | /// dimension-list ::= (dimension `x`)* |
998 | /// dimension ::= `?` | integer |
999 | /// |
1000 | /// When `allowDynamic` is not set, this is used to parse: |
1001 | /// |
1002 | /// static-dimension-list ::= (integer `x`)* |
1003 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1004 | bool allowDynamic = true) = 0; |
1005 | |
1006 | /// Parse an 'x' token in a dimension list, handling the case where the x is |
1007 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
1008 | /// next token. |
1009 | virtual ParseResult parseXInDimensionList() = 0; |
1010 | |
1011 | private: |
1012 | AsmParser(const AsmParser &) = delete; |
1013 | void operator=(const AsmParser &) = delete; |
1014 | }; |
1015 | |
1016 | //===----------------------------------------------------------------------===// |
1017 | // OpAsmParser |
1018 | //===----------------------------------------------------------------------===// |
1019 | |
1020 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1021 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1022 | /// that is designed to reduce/constrain syntax innovation in individual |
1023 | /// operations. |
1024 | /// |
1025 | /// For example, consider an op like this: |
1026 | /// |
1027 | /// %x = load %p[%1, %2] : memref<...> |
1028 | /// |
1029 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
1030 | /// custom op parser. This can be supported by calling `parseOperandList` to |
1031 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to |
1032 | /// parse the indices, then calling `parseColonTypeList` to parse the result |
1033 | /// type. |
1034 | /// |
1035 | class OpAsmParser : public AsmParser { |
1036 | public: |
1037 | using AsmParser::AsmParser; |
1038 | ~OpAsmParser() override; |
1039 | |
1040 | /// Parse a loc(...) specifier if present, filling in result if so. |
1041 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1042 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1043 | /// completes. |
1044 | virtual ParseResult |
1045 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1046 | |
1047 | /// Return the name of the specified result in the specified syntax, as well |
1048 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1049 | /// invalid result numbers. For example, in this operation: |
1050 | /// |
1051 | /// %x, %y:2, %z = foo.op |
1052 | /// |
1053 | /// getResultName(0) == {"x", 0 } |
1054 | /// getResultName(1) == {"y", 0 } |
1055 | /// getResultName(2) == {"y", 1 } |
1056 | /// getResultName(3) == {"z", 0 } |
1057 | /// getResultName(4) == {"", ~0U } |
1058 | virtual std::pair<StringRef, unsigned> |
1059 | getResultName(unsigned resultNo) const = 0; |
1060 | |
1061 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1062 | /// example in the comment for `getResultName`. |
1063 | virtual size_t getNumResults() const = 0; |
1064 | |
1065 | // These methods emit an error and return failure or success. This allows |
1066 | // these to be chained together into a linear sequence of || expressions in |
1067 | // many cases. |
1068 | |
1069 | /// Parse an operation in its generic form. |
1070 | /// The parsed operation is parsed in the current context and inserted in the |
1071 | /// provided block and insertion point. The results produced by this operation |
1072 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1073 | /// failure. |
1074 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1075 | Block::iterator insertPt) = 0; |
1076 | |
1077 | /// Parse the name of an operation, in the custom form. On success, return a |
1078 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1079 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1080 | |
1081 | //===--------------------------------------------------------------------===// |
1082 | // Operand Parsing |
1083 | //===--------------------------------------------------------------------===// |
1084 | |
1085 | /// This is the representation of an operand reference. |
1086 | struct OperandType { |
1087 | SMLoc location; // Location of the token. |
1088 | StringRef name; // Value name, e.g. %42 or %abc |
1089 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1090 | }; |
1091 | |
1092 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1093 | /// region(s), attribute(s) and function-type, of the generic form of an |
1094 | /// operation instance and populate the input operation-state 'result' with |
1095 | /// those components. If any of the components is explicitly provided, then |
1096 | /// skip parsing that component. |
1097 | virtual ParseResult parseGenericOperationAfterOpName( |
1098 | OperationState &result, |
1099 | Optional<ArrayRef<OperandType>> parsedOperandType = llvm::None, |
1100 | Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None, |
1101 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1102 | llvm::None, |
1103 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None, |
1104 | Optional<FunctionType> parsedFnType = llvm::None) = 0; |
1105 | |
1106 | /// Parse a single operand. |
1107 | virtual ParseResult parseOperand(OperandType &result) = 0; |
1108 | |
1109 | /// Parse a single operand if present. |
1110 | virtual OptionalParseResult parseOptionalOperand(OperandType &result) = 0; |
1111 | |
1112 | /// Parse zero or more SSA comma-separated operand references with a specified |
1113 | /// surrounding delimiter, and an optional required operand count. |
1114 | virtual ParseResult |
1115 | parseOperandList(SmallVectorImpl<OperandType> &result, |
1116 | int requiredOperandCount = -1, |
1117 | Delimiter delimiter = Delimiter::None) = 0; |
1118 | ParseResult parseOperandList(SmallVectorImpl<OperandType> &result, |
1119 | Delimiter delimiter) { |
1120 | return parseOperandList(result, /*requiredOperandCount=*/-1, delimiter); |
1121 | } |
1122 | |
1123 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1124 | /// references with a specified surrounding delimiter, and an optional |
1125 | /// required operand count. A leading comma is expected before the operands. |
1126 | virtual ParseResult |
1127 | parseTrailingOperandList(SmallVectorImpl<OperandType> &result, |
1128 | int requiredOperandCount = -1, |
1129 | Delimiter delimiter = Delimiter::None) = 0; |
1130 | ParseResult parseTrailingOperandList(SmallVectorImpl<OperandType> &result, |
1131 | Delimiter delimiter) { |
1132 | return parseTrailingOperandList(result, /*requiredOperandCount=*/-1, |
1133 | delimiter); |
1134 | } |
1135 | |
1136 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1137 | virtual ParseResult resolveOperand(const OperandType &operand, Type type, |
1138 | SmallVectorImpl<Value> &result) = 0; |
1139 | |
1140 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1141 | /// appending the results to the list on success. This method should be used |
1142 | /// when all operands have the same type. |
1143 | ParseResult resolveOperands(ArrayRef<OperandType> operands, Type type, |
1144 | SmallVectorImpl<Value> &result) { |
1145 | for (auto elt : operands) |
1146 | if (resolveOperand(elt, type, result)) |
1147 | return failure(); |
1148 | return success(); |
1149 | } |
1150 | |
1151 | /// Resolve a list of operands and a list of operand types to SSA values, |
1152 | /// emitting an error and returning failure, or appending the results |
1153 | /// to the list on success. |
1154 | ParseResult resolveOperands(ArrayRef<OperandType> operands, |
1155 | ArrayRef<Type> types, SMLoc loc, |
1156 | SmallVectorImpl<Value> &result) { |
1157 | if (operands.size() != types.size()) |
1158 | return emitError(loc) |
1159 | << operands.size() << " operands present, but expected " |
1160 | << types.size(); |
1161 | |
1162 | for (unsigned i = 0, e = operands.size(); i != e; ++i) |
1163 | if (resolveOperand(operands[i], types[i], result)) |
1164 | return failure(); |
1165 | return success(); |
1166 | } |
1167 | template <typename Operands> |
1168 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1169 | SmallVectorImpl<Value> &result) { |
1170 | return resolveOperands(std::forward<Operands>(operands), |
1171 | ArrayRef<Type>(type), loc, result); |
1172 | } |
1173 | template <typename Operands, typename Types> |
1174 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1175 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1176 | SmallVectorImpl<Value> &result) { |
1177 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1178 | size_t typeSize = std::distance(types.begin(), types.end()); |
1179 | if (operandSize != typeSize) |
1180 | return emitError(loc) |
1181 | << operandSize << " operands present, but expected " << typeSize; |
1182 | |
1183 | for (auto it : llvm::zip(operands, types)) |
1184 | if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) |
1185 | return failure(); |
1186 | return success(); |
1187 | } |
1188 | |
1189 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1190 | /// Operand values must come from single-result sources, and be valid |
1191 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1192 | virtual ParseResult |
1193 | parseAffineMapOfSSAIds(SmallVectorImpl<OperandType> &operands, Attribute &map, |
1194 | StringRef attrName, NamedAttrList &attrs, |
1195 | Delimiter delimiter = Delimiter::Square) = 0; |
1196 | |
1197 | /// Parses an affine expression where dims and symbols are SSA operands. |
1198 | /// Operand values must come from single-result sources, and be valid |
1199 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1200 | virtual ParseResult |
1201 | parseAffineExprOfSSAIds(SmallVectorImpl<OperandType> &dimOperands, |
1202 | SmallVectorImpl<OperandType> &symbOperands, |
1203 | AffineExpr &expr) = 0; |
1204 | |
1205 | //===--------------------------------------------------------------------===// |
1206 | // Region Parsing |
1207 | //===--------------------------------------------------------------------===// |
1208 | |
1209 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1210 | /// moved to the op regions after the op is created. The first block of the |
1211 | /// region takes 'arguments' of types 'argTypes'. If `argLocations` is |
1212 | /// non-empty it contains a location to be attached to each argument. If |
1213 | /// 'enableNameShadowing' is set to true, the argument names are allowed to |
1214 | /// shadow the names of other existing SSA values defined above the region |
1215 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1216 | /// to operations that are 'IsolatedFromAbove'. |
1217 | virtual ParseResult parseRegion(Region ®ion, |
1218 | ArrayRef<OperandType> arguments = {}, |
1219 | ArrayRef<Type> argTypes = {}, |
1220 | ArrayRef<Location> argLocations = {}, |
1221 | bool enableNameShadowing = false) = 0; |
1222 | |
1223 | /// Parses a region if present. |
1224 | virtual OptionalParseResult |
1225 | parseOptionalRegion(Region ®ion, ArrayRef<OperandType> arguments = {}, |
1226 | ArrayRef<Type> argTypes = {}, |
1227 | ArrayRef<Location> argLocations = {}, |
1228 | bool enableNameShadowing = false) = 0; |
1229 | |
1230 | /// Parses a region if present. If the region is present, a new region is |
1231 | /// allocated and placed in `region`. If no region is present or on failure, |
1232 | /// `region` remains untouched. |
1233 | virtual OptionalParseResult parseOptionalRegion( |
1234 | std::unique_ptr<Region> ®ion, ArrayRef<OperandType> arguments = {}, |
1235 | ArrayRef<Type> argTypes = {}, bool enableNameShadowing = false) = 0; |
1236 | |
1237 | /// Parse a region argument, this argument is resolved when calling |
1238 | /// 'parseRegion'. |
1239 | virtual ParseResult parseRegionArgument(OperandType &argument) = 0; |
1240 | |
1241 | /// Parse zero or more region arguments with a specified surrounding |
1242 | /// delimiter, and an optional required argument count. Region arguments |
1243 | /// define new values; so this also checks if values with the same names have |
1244 | /// not been defined yet. |
1245 | virtual ParseResult |
1246 | parseRegionArgumentList(SmallVectorImpl<OperandType> &result, |
1247 | int requiredOperandCount = -1, |
1248 | Delimiter delimiter = Delimiter::None) = 0; |
1249 | virtual ParseResult |
1250 | parseRegionArgumentList(SmallVectorImpl<OperandType> &result, |
1251 | Delimiter delimiter) { |
1252 | return parseRegionArgumentList(result, /*requiredOperandCount=*/-1, |
1253 | delimiter); |
1254 | } |
1255 | |
1256 | /// Parse a region argument if present. |
1257 | virtual ParseResult parseOptionalRegionArgument(OperandType &argument) = 0; |
1258 | |
1259 | //===--------------------------------------------------------------------===// |
1260 | // Successor Parsing |
1261 | //===--------------------------------------------------------------------===// |
1262 | |
1263 | /// Parse a single operation successor. |
1264 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1265 | |
1266 | /// Parse an optional operation successor. |
1267 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1268 | |
1269 | /// Parse a single operation successor and its operand list. |
1270 | virtual ParseResult |
1271 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1272 | |
1273 | //===--------------------------------------------------------------------===// |
1274 | // Type Parsing |
1275 | //===--------------------------------------------------------------------===// |
1276 | |
1277 | /// Parse a list of assignments of the form |
1278 | /// (%x1 = %y1, %x2 = %y2, ...) |
1279 | ParseResult parseAssignmentList(SmallVectorImpl<OperandType> &lhs, |
1280 | SmallVectorImpl<OperandType> &rhs) { |
1281 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1282 | if (!result.hasValue()) |
1283 | return emitError(getCurrentLocation(), "expected '('"); |
1284 | return result.getValue(); |
1285 | } |
1286 | |
1287 | virtual OptionalParseResult |
1288 | parseOptionalAssignmentList(SmallVectorImpl<OperandType> &lhs, |
1289 | SmallVectorImpl<OperandType> &rhs) = 0; |
1290 | |
1291 | /// Parse a list of assignments of the form |
1292 | /// (%x1 = %y1 : type1, %x2 = %y2 : type2, ...) |
1293 | ParseResult parseAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs, |
1294 | SmallVectorImpl<OperandType> &rhs, |
1295 | SmallVectorImpl<Type> &types) { |
1296 | OptionalParseResult result = |
1297 | parseOptionalAssignmentListWithTypes(lhs, rhs, types); |
1298 | if (!result.hasValue()) |
1299 | return emitError(getCurrentLocation(), "expected '('"); |
1300 | return result.getValue(); |
1301 | } |
1302 | |
1303 | virtual OptionalParseResult |
1304 | parseOptionalAssignmentListWithTypes(SmallVectorImpl<OperandType> &lhs, |
1305 | SmallVectorImpl<OperandType> &rhs, |
1306 | SmallVectorImpl<Type> &types) = 0; |
1307 | |
1308 | private: |
1309 | /// Parse either an operand list or a region argument list depending on |
1310 | /// whether isOperandList is true. |
1311 | ParseResult parseOperandOrRegionArgList(SmallVectorImpl<OperandType> &result, |
1312 | bool isOperandList, |
1313 | int requiredOperandCount, |
1314 | Delimiter delimiter); |
1315 | }; |
1316 | |
1317 | //===--------------------------------------------------------------------===// |
1318 | // Dialect OpAsm interface. |
1319 | //===--------------------------------------------------------------------===// |
1320 | |
1321 | /// A functor used to set the name of the start of a result group of an |
1322 | /// operation. See 'getAsmResultNames' below for more details. |
1323 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1324 | |
1325 | class OpAsmDialectInterface |
1326 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1327 | public: |
1328 | /// Holds the result of `getAlias` hook call. |
1329 | enum class AliasResult { |
1330 | /// The object (type or attribute) is not supported by the hook |
1331 | /// and an alias was not provided. |
1332 | NoAlias, |
1333 | /// An alias was provided, but it might be overriden by other hook. |
1334 | OverridableAlias, |
1335 | /// An alias was provided and it should be used |
1336 | /// (no other hooks will be checked). |
1337 | FinalAlias |
1338 | }; |
1339 | |
1340 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1341 | |
1342 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1343 | /// not necessarily a part of this dialect. The identifier is used in place of |
1344 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1345 | /// end with a numeric digit([0-9]+). |
1346 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1347 | return AliasResult::NoAlias; |
1348 | } |
1349 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1350 | return AliasResult::NoAlias; |
1351 | } |
1352 | |
1353 | }; |
1354 | } // namespace mlir |
1355 | |
1356 | //===--------------------------------------------------------------------===// |
1357 | // Operation OpAsm interface. |
1358 | //===--------------------------------------------------------------------===// |
1359 | |
1360 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1361 | #include "mlir/IR/OpAsmInterface.h.inc" |
1362 | |
1363 | #endif |