Bug Summary

File:build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/test/lib/Dialect/Test/TestDialect.cpp
Warning: line 844, column 26
The right operand of '<' is a garbage value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name TestDialect.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_INCLUDE_TESTS -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/test/lib/Dialect/Test -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/test/lib/Dialect/Test -I include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-09-04-125545-48738-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/test/lib/Dialect/Test/TestDialect.cpp

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/test/lib/Dialect/Test/TestDialect.cpp

1//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "TestDialect.h"
10#include "TestAttributes.h"
11#include "TestInterfaces.h"
12#include "TestTypes.h"
13#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
14#include "mlir/Dialect/DLTI/DLTI.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/Dialect/Tensor/IR/Tensor.h"
17#include "mlir/IR/AsmState.h"
18#include "mlir/IR/BuiltinAttributes.h"
19#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/DialectImplementation.h"
22#include "mlir/IR/ExtensibleDialect.h"
23#include "mlir/IR/MLIRContext.h"
24#include "mlir/IR/OperationSupport.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/IR/TypeUtilities.h"
27#include "mlir/IR/Verifier.h"
28#include "mlir/Interfaces/InferIntRangeInterface.h"
29#include "mlir/Reducer/ReductionPatternInterface.h"
30#include "mlir/Transforms/FoldUtils.h"
31#include "mlir/Transforms/InliningUtils.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringSwitch.h"
35
36// Include this before the using namespace lines below to
37// test that we don't have namespace dependencies.
38#include "TestOpsDialect.cpp.inc"
39
40using namespace mlir;
41using namespace test;
42
43void test::registerTestDialect(DialectRegistry &registry) {
44 registry.insert<TestDialect>();
45}
46
47//===----------------------------------------------------------------------===//
48// TestDialect Interfaces
49//===----------------------------------------------------------------------===//
50
51namespace {
52
53/// Testing the correctness of some traits.
54static_assert(
55 llvm::is_detected<OpTrait::has_implicit_terminator_t,
56 SingleBlockImplicitTerminatorOp>::value,
57 "has_implicit_terminator_t does not match SingleBlockImplicitTerminatorOp");
58static_assert(OpTrait::hasSingleBlockImplicitTerminator<
59 SingleBlockImplicitTerminatorOp>::value,
60 "hasSingleBlockImplicitTerminator does not match "
61 "SingleBlockImplicitTerminatorOp");
62
63struct TestResourceBlobManagerInterface
64 : public ResourceBlobManagerDialectInterfaceBase<
65 TestDialectResourceBlobHandle> {
66 using ResourceBlobManagerDialectInterfaceBase<
67 TestDialectResourceBlobHandle>::ResourceBlobManagerDialectInterfaceBase;
68};
69
70// Test support for interacting with the AsmPrinter.
71struct TestOpAsmInterface : public OpAsmDialectInterface {
72 using OpAsmDialectInterface::OpAsmDialectInterface;
73 TestOpAsmInterface(Dialect *dialect, TestResourceBlobManagerInterface &mgr)
74 : OpAsmDialectInterface(dialect), blobManager(mgr) {}
75
76 //===------------------------------------------------------------------===//
77 // Aliases
78 //===------------------------------------------------------------------===//
79
80 AliasResult getAlias(Attribute attr, raw_ostream &os) const final {
81 StringAttr strAttr = attr.dyn_cast<StringAttr>();
82 if (!strAttr)
83 return AliasResult::NoAlias;
84
85 // Check the contents of the string attribute to see what the test alias
86 // should be named.
87 Optional<StringRef> aliasName =
88 StringSwitch<Optional<StringRef>>(strAttr.getValue())
89 .Case("alias_test:dot_in_name", StringRef("test.alias"))
90 .Case("alias_test:trailing_digit", StringRef("test_alias0"))
91 .Case("alias_test:prefixed_digit", StringRef("0_test_alias"))
92 .Case("alias_test:sanitize_conflict_a",
93 StringRef("test_alias_conflict0"))
94 .Case("alias_test:sanitize_conflict_b",
95 StringRef("test_alias_conflict0_"))
96 .Case("alias_test:tensor_encoding", StringRef("test_encoding"))
97 .Default(llvm::None);
98 if (!aliasName)
99 return AliasResult::NoAlias;
100
101 os << *aliasName;
102 return AliasResult::FinalAlias;
103 }
104
105 AliasResult getAlias(Type type, raw_ostream &os) const final {
106 if (auto tupleType = type.dyn_cast<TupleType>()) {
107 if (tupleType.size() > 0 &&
108 llvm::all_of(tupleType.getTypes(), [](Type elemType) {
109 return elemType.isa<SimpleAType>();
110 })) {
111 os << "test_tuple";
112 return AliasResult::FinalAlias;
113 }
114 }
115 if (auto intType = type.dyn_cast<TestIntegerType>()) {
116 if (intType.getSignedness() ==
117 TestIntegerType::SignednessSemantics::Unsigned &&
118 intType.getWidth() == 8) {
119 os << "test_ui8";
120 return AliasResult::FinalAlias;
121 }
122 }
123 if (auto recType = type.dyn_cast<TestRecursiveType>()) {
124 if (recType.getName() == "type_to_alias") {
125 // We only make alias for a specific recursive type.
126 os << "testrec";
127 return AliasResult::FinalAlias;
128 }
129 }
130 return AliasResult::NoAlias;
131 }
132
133 //===------------------------------------------------------------------===//
134 // Resources
135 //===------------------------------------------------------------------===//
136
137 std::string
138 getResourceKey(const AsmDialectResourceHandle &handle) const override {
139 return cast<TestDialectResourceBlobHandle>(handle).getKey().str();
140 }
141
142 FailureOr<AsmDialectResourceHandle>
143 declareResource(StringRef key) const final {
144 return blobManager.insert(key);
145 }
146
147 LogicalResult parseResource(AsmParsedResourceEntry &entry) const final {
148 FailureOr<AsmResourceBlob> blob = entry.parseAsBlob();
149 if (failed(blob))
150 return failure();
151
152 // Update the blob for this entry.
153 blobManager.update(entry.getKey(), std::move(*blob));
154 return success();
155 }
156
157 void
158 buildResources(Operation *op,
159 const SetVector<AsmDialectResourceHandle> &referencedResources,
160 AsmResourceBuilder &provider) const final {
161 blobManager.buildResources(provider, referencedResources.getArrayRef());
162 }
163
164private:
165 /// The blob manager for the dialect.
166 TestResourceBlobManagerInterface &blobManager;
167};
168
169struct TestDialectFoldInterface : public DialectFoldInterface {
170 using DialectFoldInterface::DialectFoldInterface;
171
172 /// Registered hook to check if the given region, which is attached to an
173 /// operation that is *not* isolated from above, should be used when
174 /// materializing constants.
175 bool shouldMaterializeInto(Region *region) const final {
176 // If this is a one region operation, then insert into it.
177 return isa<OneRegionOp>(region->getParentOp());
178 }
179};
180
181/// This class defines the interface for handling inlining with standard
182/// operations.
183struct TestInlinerInterface : public DialectInlinerInterface {
184 using DialectInlinerInterface::DialectInlinerInterface;
185
186 //===--------------------------------------------------------------------===//
187 // Analysis Hooks
188 //===--------------------------------------------------------------------===//
189
190 bool isLegalToInline(Operation *call, Operation *callable,
191 bool wouldBeCloned) const final {
192 // Don't allow inlining calls that are marked `noinline`.
193 return !call->hasAttr("noinline");
194 }
195 bool isLegalToInline(Region *, Region *, bool,
196 BlockAndValueMapping &) const final {
197 // Inlining into test dialect regions is legal.
198 return true;
199 }
200 bool isLegalToInline(Operation *, Region *, bool,
201 BlockAndValueMapping &) const final {
202 return true;
203 }
204
205 bool shouldAnalyzeRecursively(Operation *op) const final {
206 // Analyze recursively if this is not a functional region operation, it
207 // froms a separate functional scope.
208 return !isa<FunctionalRegionOp>(op);
209 }
210
211 //===--------------------------------------------------------------------===//
212 // Transformation Hooks
213 //===--------------------------------------------------------------------===//
214
215 /// Handle the given inlined terminator by replacing it with a new operation
216 /// as necessary.
217 void handleTerminator(Operation *op,
218 ArrayRef<Value> valuesToRepl) const final {
219 // Only handle "test.return" here.
220 auto returnOp = dyn_cast<TestReturnOp>(op);
221 if (!returnOp)
222 return;
223
224 // Replace the values directly with the return operands.
225 assert(returnOp.getNumOperands() == valuesToRepl.size())(static_cast <bool> (returnOp.getNumOperands() == valuesToRepl
.size()) ? void (0) : __assert_fail ("returnOp.getNumOperands() == valuesToRepl.size()"
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 225, __extension__
__PRETTY_FUNCTION__))
;
226 for (const auto &it : llvm::enumerate(returnOp.getOperands()))
227 valuesToRepl[it.index()].replaceAllUsesWith(it.value());
228 }
229
230 /// Attempt to materialize a conversion for a type mismatch between a call
231 /// from this dialect, and a callable region. This method should generate an
232 /// operation that takes 'input' as the only operand, and produces a single
233 /// result of 'resultType'. If a conversion can not be generated, nullptr
234 /// should be returned.
235 Operation *materializeCallConversion(OpBuilder &builder, Value input,
236 Type resultType,
237 Location conversionLoc) const final {
238 // Only allow conversion for i16/i32 types.
239 if (!(resultType.isSignlessInteger(16) ||
240 resultType.isSignlessInteger(32)) ||
241 !(input.getType().isSignlessInteger(16) ||
242 input.getType().isSignlessInteger(32)))
243 return nullptr;
244 return builder.create<TestCastOp>(conversionLoc, resultType, input);
245 }
246
247 void processInlinedCallBlocks(
248 Operation *call,
249 iterator_range<Region::iterator> inlinedBlocks) const final {
250 if (!isa<ConversionCallOp>(call))
251 return;
252
253 // Set attributed on all ops in the inlined blocks.
254 for (Block &block : inlinedBlocks) {
255 block.walk([&](Operation *op) {
256 op->setAttr("inlined_conversion", UnitAttr::get(call->getContext()));
257 });
258 }
259 }
260};
261
262struct TestReductionPatternInterface : public DialectReductionPatternInterface {
263public:
264 TestReductionPatternInterface(Dialect *dialect)
265 : DialectReductionPatternInterface(dialect) {}
266
267 void populateReductionPatterns(RewritePatternSet &patterns) const final {
268 populateTestReductionPatterns(patterns);
269 }
270};
271
272} // namespace
273
274//===----------------------------------------------------------------------===//
275// Dynamic operations
276//===----------------------------------------------------------------------===//
277
278std::unique_ptr<DynamicOpDefinition> getDynamicGenericOp(TestDialect *dialect) {
279 return DynamicOpDefinition::get(
280 "dynamic_generic", dialect, [](Operation *op) { return success(); },
281 [](Operation *op) { return success(); });
282}
283
284std::unique_ptr<DynamicOpDefinition>
285getDynamicOneOperandTwoResultsOp(TestDialect *dialect) {
286 return DynamicOpDefinition::get(
287 "dynamic_one_operand_two_results", dialect,
288 [](Operation *op) {
289 if (op->getNumOperands() != 1) {
290 op->emitOpError()
291 << "expected 1 operand, but had " << op->getNumOperands();
292 return failure();
293 }
294 if (op->getNumResults() != 2) {
295 op->emitOpError()
296 << "expected 2 results, but had " << op->getNumResults();
297 return failure();
298 }
299 return success();
300 },
301 [](Operation *op) { return success(); });
302}
303
304std::unique_ptr<DynamicOpDefinition>
305getDynamicCustomParserPrinterOp(TestDialect *dialect) {
306 auto verifier = [](Operation *op) {
307 if (op->getNumOperands() == 0 && op->getNumResults() == 0)
308 return success();
309 op->emitError() << "operation should have no operands and no results";
310 return failure();
311 };
312 auto regionVerifier = [](Operation *op) { return success(); };
313
314 auto parser = [](OpAsmParser &parser, OperationState &state) {
315 return parser.parseKeyword("custom_keyword");
316 };
317
318 auto printer = [](Operation *op, OpAsmPrinter &printer, llvm::StringRef) {
319 printer << op->getName() << " custom_keyword";
320 };
321
322 return DynamicOpDefinition::get("dynamic_custom_parser_printer", dialect,
323 verifier, regionVerifier, parser, printer);
324}
325
326//===----------------------------------------------------------------------===//
327// TestDialect
328//===----------------------------------------------------------------------===//
329
330static void testSideEffectOpGetEffect(
331 Operation *op,
332 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>> &effects);
333
334// This is the implementation of a dialect fallback for `TestEffectOpInterface`.
335struct TestOpEffectInterfaceFallback
336 : public TestEffectOpInterface::FallbackModel<
337 TestOpEffectInterfaceFallback> {
338 static bool classof(Operation *op) {
339 bool isSupportedOp =
340 op->getName().getStringRef() == "test.unregistered_side_effect_op";
341 assert(isSupportedOp && "Unexpected dispatch")(static_cast <bool> (isSupportedOp && "Unexpected dispatch"
) ? void (0) : __assert_fail ("isSupportedOp && \"Unexpected dispatch\""
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 341, __extension__
__PRETTY_FUNCTION__))
;
342 return isSupportedOp;
343 }
344
345 void
346 getEffects(Operation *op,
347 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
348 &effects) const {
349 testSideEffectOpGetEffect(op, effects);
350 }
351};
352
/// Sets up the test dialect. It registers attributes, types, the ODS-generated
/// operation list, three dynamically defined ops and the dialect interfaces.
/// It also allows unknown operations in this dialect and allocates the
/// fallback `TestEffectOpInterface` model.
void TestDialect::initialize() {
  registerAttributes();
  registerTypes();
  // The generated op list is spliced in as the template argument list.
  addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
      >();
  registerDynamicOp(getDynamicGenericOp(this));
  registerDynamicOp(getDynamicOneOperandTwoResultsOp(this));
  registerDynamicOp(getDynamicCustomParserPrinterOp(this));

  // TestOpAsmInterface keeps a reference to the blob manager interface, so
  // the blob manager must be added first and handed to it.
  auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
  addInterface<TestOpAsmInterface>(blobInterface);

  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
                TestReductionPatternInterface>();
  allowUnknownOperations();

  // Instantiate our fallback op interface that we'll use on specific
  // unregistered op.
  fallbackEffectOpInterfaces = new TestOpEffectInterfaceFallback;
}
/// Frees the fallback model allocated in `initialize()`. The member is cast
/// back to its concrete type before deletion. NOTE(review): this suggests it
/// is stored type-erased (declared in the dialect's generated header); confirm
/// there.
TestDialect::~TestDialect() {
  delete static_cast<TestOpEffectInterfaceFallback *>(
      fallbackEffectOpInterfaces);
}
379
380Operation *TestDialect::materializeConstant(OpBuilder &builder, Attribute value,
381 Type type, Location loc) {
382 return builder.create<TestOpConstant>(loc, type, value);
383}
384
385::mlir::LogicalResult FormatInferType2Op::inferReturnTypes(
386 ::mlir::MLIRContext *context, ::llvm::Optional<::mlir::Location> location,
387 ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes,
388 ::mlir::RegionRange regions,
389 ::llvm::SmallVectorImpl<::mlir::Type> &inferredReturnTypes) {
390 inferredReturnTypes.assign({::mlir::IntegerType::get(context, 16)});
391 return ::mlir::success();
392}
393
394void *TestDialect::getRegisteredInterfaceForOp(TypeID typeID,
395 OperationName opName) {
396 if (opName.getIdentifier() == "test.unregistered_side_effect_op" &&
397 typeID == TypeID::get<TestEffectOpInterface>())
398 return fallbackEffectOpInterfaces;
399 return nullptr;
400}
401
402LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
403 NamedAttribute namedAttr) {
404 if (namedAttr.getName() == "test.invalid_attr")
405 return op->emitError() << "invalid to use 'test.invalid_attr'";
406 return success();
407}
408
409LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
410 unsigned regionIndex,
411 unsigned argIndex,
412 NamedAttribute namedAttr) {
413 if (namedAttr.getName() == "test.invalid_attr")
414 return op->emitError() << "invalid to use 'test.invalid_attr'";
415 return success();
416}
417
418LogicalResult
419TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
420 unsigned resultIndex,
421 NamedAttribute namedAttr) {
422 if (namedAttr.getName() == "test.invalid_attr")
423 return op->emitError() << "invalid to use 'test.invalid_attr'";
424 return success();
425}
426
427Optional<Dialect::ParseOpHook>
428TestDialect::getParseOperationHook(StringRef opName) const {
429 if (opName == "test.dialect_custom_printer") {
430 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
431 return parser.parseKeyword("custom_format");
432 }};
433 }
434 if (opName == "test.dialect_custom_format_fallback") {
435 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
436 return parser.parseKeyword("custom_format_fallback");
437 }};
438 }
439 if (opName == "test.dialect_custom_printer.with.dot") {
440 return ParseOpHook{[](OpAsmParser &parser, OperationState &state) {
441 return ParseResult::success();
442 }};
443 }
444 return None;
445}
446
447llvm::unique_function<void(Operation *, OpAsmPrinter &)>
448TestDialect::getOperationPrinter(Operation *op) const {
449 StringRef opName = op->getName().getStringRef();
450 if (opName == "test.dialect_custom_printer") {
451 return [](Operation *op, OpAsmPrinter &printer) {
452 printer.getStream() << " custom_format";
453 };
454 }
455 if (opName == "test.dialect_custom_format_fallback") {
456 return [](Operation *op, OpAsmPrinter &printer) {
457 printer.getStream() << " custom_format_fallback";
458 };
459 }
460 return {};
461}
462
463//===----------------------------------------------------------------------===//
464// TypedAttrOp
465//===----------------------------------------------------------------------===//
466
467/// Parse an attribute with a given type.
468static ParseResult parseAttrElideType(AsmParser &parser, TypeAttr type,
469 Attribute &attr) {
470 return parser.parseAttribute(attr, type.getValue());
471}
472
473/// Print an attribute without its type.
474static void printAttrElideType(AsmPrinter &printer, Operation *op,
475 TypeAttr type, Attribute attr) {
476 printer.printAttributeWithoutType(attr);
477}
478
479//===----------------------------------------------------------------------===//
480// TestBranchOp
481//===----------------------------------------------------------------------===//
482
483SuccessorOperands TestBranchOp::getSuccessorOperands(unsigned index) {
484 assert(index == 0 && "invalid successor index")(static_cast <bool> (index == 0 && "invalid successor index"
) ? void (0) : __assert_fail ("index == 0 && \"invalid successor index\""
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 484, __extension__
__PRETTY_FUNCTION__))
;
485 return SuccessorOperands(getTargetOperandsMutable());
486}
487
488//===----------------------------------------------------------------------===//
489// TestProducingBranchOp
490//===----------------------------------------------------------------------===//
491
492SuccessorOperands TestProducingBranchOp::getSuccessorOperands(unsigned index) {
493 assert(index <= 1 && "invalid successor index")(static_cast <bool> (index <= 1 && "invalid successor index"
) ? void (0) : __assert_fail ("index <= 1 && \"invalid successor index\""
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 493, __extension__
__PRETTY_FUNCTION__))
;
494 if (index == 1)
495 return SuccessorOperands(getFirstOperandsMutable());
496 return SuccessorOperands(getSecondOperandsMutable());
497}
498
499//===----------------------------------------------------------------------===//
// TestInternalBranchOp
501//===----------------------------------------------------------------------===//
502
503SuccessorOperands TestInternalBranchOp::getSuccessorOperands(unsigned index) {
504 assert(index <= 1 && "invalid successor index")(static_cast <bool> (index <= 1 && "invalid successor index"
) ? void (0) : __assert_fail ("index <= 1 && \"invalid successor index\""
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 504, __extension__
__PRETTY_FUNCTION__))
;
505 if (index == 0)
506 return SuccessorOperands(0, getSuccessOperandsMutable());
507 return SuccessorOperands(1, getErrorOperandsMutable());
508}
509
510//===----------------------------------------------------------------------===//
511// TestDialectCanonicalizerOp
512//===----------------------------------------------------------------------===//
513
514static LogicalResult
515dialectCanonicalizationPattern(TestDialectCanonicalizerOp op,
516 PatternRewriter &rewriter) {
517 rewriter.replaceOpWithNewOp<arith::ConstantOp>(
518 op, rewriter.getI32IntegerAttr(42));
519 return success();
520}
521
522void TestDialect::getCanonicalizationPatterns(
523 RewritePatternSet &results) const {
524 results.add(&dialectCanonicalizationPattern);
525}
526
527//===----------------------------------------------------------------------===//
528// TestCallOp
529//===----------------------------------------------------------------------===//
530
531LogicalResult TestCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
532 // Check that the callee attribute was specified.
533 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
534 if (!fnAttr)
535 return emitOpError("requires a 'callee' symbol reference attribute");
536 if (!symbolTable.lookupNearestSymbolFrom<FunctionOpInterface>(*this, fnAttr))
537 return emitOpError() << "'" << fnAttr.getValue()
538 << "' does not reference a valid function";
539 return success();
540}
541
542//===----------------------------------------------------------------------===//
543// TestFoldToCallOp
544//===----------------------------------------------------------------------===//
545
546namespace {
547struct FoldToCallOpPattern : public OpRewritePattern<FoldToCallOp> {
548 using OpRewritePattern<FoldToCallOp>::OpRewritePattern;
549
550 LogicalResult matchAndRewrite(FoldToCallOp op,
551 PatternRewriter &rewriter) const override {
552 rewriter.replaceOpWithNewOp<func::CallOp>(op, TypeRange(),
553 op.getCalleeAttr(), ValueRange());
554 return success();
555 }
556};
557} // namespace
558
559void FoldToCallOp::getCanonicalizationPatterns(RewritePatternSet &results,
560 MLIRContext *context) {
561 results.add<FoldToCallOpPattern>(context);
562}
563
564//===----------------------------------------------------------------------===//
565// Test Format* operations
566//===----------------------------------------------------------------------===//
567
568//===----------------------------------------------------------------------===//
569// Parsing
570
571static ParseResult parseCustomOptionalOperand(
572 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
573 if (succeeded(parser.parseOptionalLParen())) {
574 optOperand.emplace();
575 if (parser.parseOperand(*optOperand) || parser.parseRParen())
576 return failure();
577 }
578 return success();
579}
580
581static ParseResult parseCustomDirectiveOperands(
582 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
583 Optional<OpAsmParser::UnresolvedOperand> &optOperand,
584 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands) {
585 if (parser.parseOperand(operand))
586 return failure();
587 if (succeeded(parser.parseOptionalComma())) {
588 optOperand.emplace();
589 if (parser.parseOperand(*optOperand))
590 return failure();
591 }
592 if (parser.parseArrow() || parser.parseLParen() ||
593 parser.parseOperandList(varOperands) || parser.parseRParen())
594 return failure();
595 return success();
596}
597static ParseResult
598parseCustomDirectiveResults(OpAsmParser &parser, Type &operandType,
599 Type &optOperandType,
600 SmallVectorImpl<Type> &varOperandTypes) {
601 if (parser.parseColon())
602 return failure();
603
604 if (parser.parseType(operandType))
605 return failure();
606 if (succeeded(parser.parseOptionalComma())) {
607 if (parser.parseType(optOperandType))
608 return failure();
609 }
610 if (parser.parseArrow() || parser.parseLParen() ||
611 parser.parseTypeList(varOperandTypes) || parser.parseRParen())
612 return failure();
613 return success();
614}
615static ParseResult
616parseCustomDirectiveWithTypeRefs(OpAsmParser &parser, Type operandType,
617 Type optOperandType,
618 const SmallVectorImpl<Type> &varOperandTypes) {
619 if (parser.parseKeyword("type_refs_capture"))
620 return failure();
621
622 Type operandType2, optOperandType2;
623 SmallVector<Type, 1> varOperandTypes2;
624 if (parseCustomDirectiveResults(parser, operandType2, optOperandType2,
625 varOperandTypes2))
626 return failure();
627
628 if (operandType != operandType2 || optOperandType != optOperandType2 ||
629 varOperandTypes != varOperandTypes2)
630 return failure();
631
632 return success();
633}
634static ParseResult parseCustomDirectiveOperandsAndTypes(
635 OpAsmParser &parser, OpAsmParser::UnresolvedOperand &operand,
636 Optional<OpAsmParser::UnresolvedOperand> &optOperand,
637 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &varOperands,
638 Type &operandType, Type &optOperandType,
639 SmallVectorImpl<Type> &varOperandTypes) {
640 if (parseCustomDirectiveOperands(parser, operand, optOperand, varOperands) ||
641 parseCustomDirectiveResults(parser, operandType, optOperandType,
642 varOperandTypes))
643 return failure();
644 return success();
645}
646static ParseResult parseCustomDirectiveRegions(
647 OpAsmParser &parser, Region &region,
648 SmallVectorImpl<std::unique_ptr<Region>> &varRegions) {
649 if (parser.parseRegion(region))
650 return failure();
651 if (failed(parser.parseOptionalComma()))
652 return success();
653 std::unique_ptr<Region> varRegion = std::make_unique<Region>();
654 if (parser.parseRegion(*varRegion))
655 return failure();
656 varRegions.emplace_back(std::move(varRegion));
657 return success();
658}
659static ParseResult
660parseCustomDirectiveSuccessors(OpAsmParser &parser, Block *&successor,
661 SmallVectorImpl<Block *> &varSuccessors) {
662 if (parser.parseSuccessor(successor))
663 return failure();
664 if (failed(parser.parseOptionalComma()))
665 return success();
666 Block *varSuccessor;
667 if (parser.parseSuccessor(varSuccessor))
668 return failure();
669 varSuccessors.append(2, varSuccessor);
670 return success();
671}
672static ParseResult parseCustomDirectiveAttributes(OpAsmParser &parser,
673 IntegerAttr &attr,
674 IntegerAttr &optAttr) {
675 if (parser.parseAttribute(attr))
676 return failure();
677 if (succeeded(parser.parseOptionalComma())) {
678 if (parser.parseAttribute(optAttr))
679 return failure();
680 }
681 return success();
682}
683
684static ParseResult parseCustomDirectiveAttrDict(OpAsmParser &parser,
685 NamedAttrList &attrs) {
686 return parser.parseOptionalAttrDict(attrs);
687}
688static ParseResult parseCustomDirectiveOptionalOperandRef(
689 OpAsmParser &parser, Optional<OpAsmParser::UnresolvedOperand> &optOperand) {
690 int64_t operandCount = 0;
691 if (parser.parseInteger(operandCount))
692 return failure();
693 bool expectedOptionalOperand = operandCount == 0;
694 return success(expectedOptionalOperand != optOperand.has_value());
695}
696
697//===----------------------------------------------------------------------===//
698// Printing
699
700static void printCustomOptionalOperand(OpAsmPrinter &printer, Operation *,
701 Value optOperand) {
702 if (optOperand)
703 printer << "(" << optOperand << ") ";
704}
705
706static void printCustomDirectiveOperands(OpAsmPrinter &printer, Operation *,
707 Value operand, Value optOperand,
708 OperandRange varOperands) {
709 printer << operand;
710 if (optOperand)
711 printer << ", " << optOperand;
712 printer << " -> (" << varOperands << ")";
713}
714static void printCustomDirectiveResults(OpAsmPrinter &printer, Operation *,
715 Type operandType, Type optOperandType,
716 TypeRange varOperandTypes) {
717 printer << " : " << operandType;
718 if (optOperandType)
719 printer << ", " << optOperandType;
720 printer << " -> (" << varOperandTypes << ")";
721}
722static void printCustomDirectiveWithTypeRefs(OpAsmPrinter &printer,
723 Operation *op, Type operandType,
724 Type optOperandType,
725 TypeRange varOperandTypes) {
726 printer << " type_refs_capture ";
727 printCustomDirectiveResults(printer, op, operandType, optOperandType,
728 varOperandTypes);
729}
730static void printCustomDirectiveOperandsAndTypes(
731 OpAsmPrinter &printer, Operation *op, Value operand, Value optOperand,
732 OperandRange varOperands, Type operandType, Type optOperandType,
733 TypeRange varOperandTypes) {
734 printCustomDirectiveOperands(printer, op, operand, optOperand, varOperands);
735 printCustomDirectiveResults(printer, op, operandType, optOperandType,
736 varOperandTypes);
737}
738static void printCustomDirectiveRegions(OpAsmPrinter &printer, Operation *,
739 Region &region,
740 MutableArrayRef<Region> varRegions) {
741 printer.printRegion(region);
742 if (!varRegions.empty()) {
743 printer << ", ";
744 for (Region &region : varRegions)
745 printer.printRegion(region);
746 }
747}
748static void printCustomDirectiveSuccessors(OpAsmPrinter &printer, Operation *,
749 Block *successor,
750 SuccessorRange varSuccessors) {
751 printer << successor;
752 if (!varSuccessors.empty())
753 printer << ", " << varSuccessors.front();
754}
755static void printCustomDirectiveAttributes(OpAsmPrinter &printer, Operation *,
756 Attribute attribute,
757 Attribute optAttribute) {
758 printer << attribute;
759 if (optAttribute)
760 printer << ", " << optAttribute;
761}
762
763static void printCustomDirectiveAttrDict(OpAsmPrinter &printer, Operation *op,
764 DictionaryAttr attrs) {
765 printer.printOptionalAttrDict(attrs.getValue());
766}
767
768static void printCustomDirectiveOptionalOperandRef(OpAsmPrinter &printer,
769 Operation *op,
770 Value optOperand) {
771 printer << (optOperand ? "1" : "0");
772}
773
774//===----------------------------------------------------------------------===//
775// Test IsolatedRegionOp - parse passthrough region arguments.
776//===----------------------------------------------------------------------===//
777
778ParseResult IsolatedRegionOp::parse(OpAsmParser &parser,
779 OperationState &result) {
780 // Parse the input operand.
781 OpAsmParser::Argument argInfo;
782 argInfo.type = parser.getBuilder().getIndexType();
783 if (parser.parseOperand(argInfo.ssaName) ||
784 parser.resolveOperand(argInfo.ssaName, argInfo.type, result.operands))
785 return failure();
786
787 // Parse the body region, and reuse the operand info as the argument info.
788 Region *body = result.addRegion();
789 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/true);
790}
791
792void IsolatedRegionOp::print(OpAsmPrinter &p) {
793 p << "test.isolated_region ";
794 p.printOperand(getOperand());
795 p.shadowRegionArgs(getRegion(), getOperand());
796 p << ' ';
797 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
798}
799
800//===----------------------------------------------------------------------===//
801// Test SSACFGRegionOp
802//===----------------------------------------------------------------------===//
803
804RegionKind SSACFGRegionOp::getRegionKind(unsigned index) {
805 return RegionKind::SSACFG;
806}
807
808//===----------------------------------------------------------------------===//
809// Test GraphRegionOp
810//===----------------------------------------------------------------------===//
811
812RegionKind GraphRegionOp::getRegionKind(unsigned index) {
813 return RegionKind::Graph;
814}
815
816//===----------------------------------------------------------------------===//
817// Test AffineScopeOp
818//===----------------------------------------------------------------------===//
819
820ParseResult AffineScopeOp::parse(OpAsmParser &parser, OperationState &result) {
821 // Parse the body region, and reuse the operand info as the argument info.
822 Region *body = result.addRegion();
823 return parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{});
824}
825
826void AffineScopeOp::print(OpAsmPrinter &p) {
827 p << "test.affine_scope ";
828 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
829}
830
831//===----------------------------------------------------------------------===//
832// Test parser.
833//===----------------------------------------------------------------------===//
834
835ParseResult ParseIntegerLiteralOp::parse(OpAsmParser &parser,
836 OperationState &result) {
837 if (parser.parseOptionalColon())
1
Taking false branch
838 return success();
839 uint64_t numResults;
2
'numResults' declared without an initial value
840 if (parser.parseInteger(numResults))
3
Calling 'AsmParser::parseInteger'
11
Returning from 'AsmParser::parseInteger'
12
Taking false branch
841 return failure();
842
843 IndexType type = parser.getBuilder().getIndexType();
844 for (unsigned i = 0; i < numResults; ++i)
13
The right operand of '<' is a garbage value
845 result.addTypes(type);
846 return success();
847}
848
849void ParseIntegerLiteralOp::print(OpAsmPrinter &p) {
850 if (unsigned numResults = getNumResults())
851 p << " : " << numResults;
852}
853
854ParseResult ParseWrappedKeywordOp::parse(OpAsmParser &parser,
855 OperationState &result) {
856 StringRef keyword;
857 if (parser.parseKeyword(&keyword))
858 return failure();
859 result.addAttribute("keyword", parser.getBuilder().getStringAttr(keyword));
860 return success();
861}
862
863void ParseWrappedKeywordOp::print(OpAsmPrinter &p) { p << " " << getKeyword(); }
864
865//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
867
868ParseResult WrappingRegionOp::parse(OpAsmParser &parser,
869 OperationState &result) {
870 if (parser.parseKeyword("wraps"))
871 return failure();
872
873 // Parse the wrapped op in a region
874 Region &body = *result.addRegion();
875 body.push_back(new Block);
876 Block &block = body.back();
877 Operation *wrappedOp = parser.parseGenericOperation(&block, block.begin());
878 if (!wrappedOp)
879 return failure();
880
881 // Create a return terminator in the inner region, pass as operand to the
882 // terminator the returned values from the wrapped operation.
883 SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
884 OpBuilder builder(parser.getContext());
885 builder.setInsertionPointToEnd(&block);
886 builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);
887
888 // Get the results type for the wrapping op from the terminator operands.
889 Operation &returnOp = body.back().back();
890 result.types.append(returnOp.operand_type_begin(),
891 returnOp.operand_type_end());
892
893 // Use the location of the wrapped op for the "test.wrapping_region" op.
894 result.location = wrappedOp->getLoc();
895
896 return success();
897}
898
899void WrappingRegionOp::print(OpAsmPrinter &p) {
900 p << " wraps ";
901 p.printGenericOp(&getRegion().front().front());
902}
903
904//===----------------------------------------------------------------------===//
905// Test PrettyPrintedRegionOp - exercising the following parser APIs
906// parseGenericOperationAfterOpName
907// parseCustomOperationName
908//===----------------------------------------------------------------------===//
909
910ParseResult PrettyPrintedRegionOp::parse(OpAsmParser &parser,
911 OperationState &result) {
912
913 SMLoc loc = parser.getCurrentLocation();
914 Location currLocation = parser.getEncodedSourceLoc(loc);
915
916 // Parse the operands.
917 SmallVector<OpAsmParser::UnresolvedOperand, 2> operands;
918 if (parser.parseOperandList(operands))
919 return failure();
920
921 // Check if we are parsing the pretty-printed version
922 // test.pretty_printed_region start <inner-op> end : <functional-type>
923 // Else fallback to parsing the "non pretty-printed" version.
924 if (!succeeded(parser.parseOptionalKeyword("start")))
925 return parser.parseGenericOperationAfterOpName(
926 result, llvm::makeArrayRef(operands));
927
928 FailureOr<OperationName> parseOpNameInfo = parser.parseCustomOperationName();
929 if (failed(parseOpNameInfo))
930 return failure();
931
932 StringAttr innerOpName = parseOpNameInfo->getIdentifier();
933
934 FunctionType opFntype;
935 Optional<Location> explicitLoc;
936 if (parser.parseKeyword("end") || parser.parseColon() ||
937 parser.parseType(opFntype) ||
938 parser.parseOptionalLocationSpecifier(explicitLoc))
939 return failure();
940
941 // If location of the op is explicitly provided, then use it; Else use
942 // the parser's current location.
943 Location opLoc = explicitLoc.value_or(currLocation);
944
945 // Derive the SSA-values for op's operands.
946 if (parser.resolveOperands(operands, opFntype.getInputs(), loc,
947 result.operands))
948 return failure();
949
950 // Add a region for op.
951 Region &region = *result.addRegion();
952
953 // Create a basic-block inside op's region.
954 Block &block = region.emplaceBlock();
955
956 // Create and insert an "inner-op" operation in the block.
957 // Just for testing purposes, we can assume that inner op is a binary op with
958 // result and operand types all same as the test-op's first operand.
959 Type innerOpType = opFntype.getInput(0);
960 Value lhs = block.addArgument(innerOpType, opLoc);
961 Value rhs = block.addArgument(innerOpType, opLoc);
962
963 OpBuilder builder(parser.getBuilder().getContext());
964 builder.setInsertionPointToStart(&block);
965
966 Operation *innerOp =
967 builder.create(opLoc, innerOpName, /*operands=*/{lhs, rhs}, innerOpType);
968
969 // Insert a return statement in the block returning the inner-op's result.
970 builder.create<TestReturnOp>(innerOp->getLoc(), innerOp->getResults());
971
972 // Populate the op operation-state with result-type and location.
973 result.addTypes(opFntype.getResults());
974 result.location = innerOp->getLoc();
975
976 return success();
977}
978
979void PrettyPrintedRegionOp::print(OpAsmPrinter &p) {
980 p << ' ';
981 p.printOperands(getOperands());
982
983 Operation &innerOp = getRegion().front().front();
984 // Assuming that region has a single non-terminator inner-op, if the inner-op
985 // meets some criteria (which in this case is a simple one based on the name
986 // of inner-op), then we can print the entire region in a succinct way.
987 // Here we assume that the prototype of "special.op" can be trivially derived
988 // while parsing it back.
989 if (innerOp.getName().getStringRef().equals("special.op")) {
990 p << " start special.op end";
991 } else {
992 p << " (";
993 p.printRegion(getRegion());
994 p << ")";
995 }
996
997 p << " : ";
998 p.printFunctionalType(*this);
999}
1000
1001//===----------------------------------------------------------------------===//
1002// Test PolyForOp - parse list of region arguments.
1003//===----------------------------------------------------------------------===//
1004
1005ParseResult PolyForOp::parse(OpAsmParser &parser, OperationState &result) {
1006 SmallVector<OpAsmParser::Argument, 4> ivsInfo;
1007 // Parse list of region arguments without a delimiter.
1008 if (parser.parseArgumentList(ivsInfo, OpAsmParser::Delimiter::None))
1009 return failure();
1010
1011 // Parse the body region.
1012 Region *body = result.addRegion();
1013 for (auto &iv : ivsInfo)
1014 iv.type = parser.getBuilder().getIndexType();
1015 return parser.parseRegion(*body, ivsInfo);
1016}
1017
1018void PolyForOp::print(OpAsmPrinter &p) { p.printGenericOp(*this); }
1019
1020void PolyForOp::getAsmBlockArgumentNames(Region &region,
1021 OpAsmSetValueNameFn setNameFn) {
1022 auto arrayAttr = getOperation()->getAttrOfType<ArrayAttr>("arg_names");
1023 if (!arrayAttr)
1024 return;
1025 auto args = getRegion().front().getArguments();
1026 auto e = std::min(arrayAttr.size(), args.size());
1027 for (unsigned i = 0; i < e; ++i) {
1028 if (auto strAttr = arrayAttr[i].dyn_cast<StringAttr>())
1029 setNameFn(args[i], strAttr.getValue());
1030 }
1031}
1032
1033//===----------------------------------------------------------------------===//
1034// Test removing op with inner ops.
1035//===----------------------------------------------------------------------===//
1036
1037namespace {
1038struct TestRemoveOpWithInnerOps
1039 : public OpRewritePattern<TestOpWithRegionPattern> {
1040 using OpRewritePattern<TestOpWithRegionPattern>::OpRewritePattern;
1041
1042 void initialize() { setDebugName("TestRemoveOpWithInnerOps"); }
1043
1044 LogicalResult matchAndRewrite(TestOpWithRegionPattern op,
1045 PatternRewriter &rewriter) const override {
1046 rewriter.eraseOp(op);
1047 return success();
1048 }
1049};
1050} // namespace
1051
1052void TestOpWithRegionPattern::getCanonicalizationPatterns(
1053 RewritePatternSet &results, MLIRContext *context) {
1054 results.add<TestRemoveOpWithInnerOps>(context);
1055}
1056
1057OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
1058 return getOperand();
1059}
1060
1061OpFoldResult TestOpConstant::fold(ArrayRef<Attribute> operands) {
1062 return getValue();
1063}
1064
1065LogicalResult TestOpWithVariadicResultsAndFolder::fold(
1066 ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
1067 for (Value input : this->getOperands()) {
1068 results.push_back(input);
1069 }
1070 return success();
1071}
1072
1073OpFoldResult TestOpInPlaceFold::fold(ArrayRef<Attribute> operands) {
1074 assert(operands.size() == 1)(static_cast <bool> (operands.size() == 1) ? void (0) :
__assert_fail ("operands.size() == 1", "mlir/test/lib/Dialect/Test/TestDialect.cpp"
, 1074, __extension__ __PRETTY_FUNCTION__))
;
1075 if (operands.front()) {
1076 (*this)->setAttr("attr", operands.front());
1077 return getResult();
1078 }
1079 return {};
1080}
1081
1082OpFoldResult TestPassthroughFold::fold(ArrayRef<Attribute> operands) {
1083 return getOperand();
1084}
1085
1086LogicalResult OpWithInferTypeInterfaceOp::inferReturnTypes(
1087 MLIRContext *, Optional<Location> location, ValueRange operands,
1088 DictionaryAttr attributes, RegionRange regions,
1089 SmallVectorImpl<Type> &inferredReturnTypes) {
1090 if (operands[0].getType() != operands[1].getType()) {
1091 return emitOptionalError(location, "operand type mismatch ",
1092 operands[0].getType(), " vs ",
1093 operands[1].getType());
1094 }
1095 inferredReturnTypes.assign({operands[0].getType()});
1096 return success();
1097}
1098
1099// TODO: We should be able to only define either inferReturnType or
1100// refineReturnType, currently only refineReturnType can be omitted.
1101LogicalResult OpWithRefineTypeInterfaceOp::inferReturnTypes(
1102 MLIRContext *context, Optional<Location> location, ValueRange operands,
1103 DictionaryAttr attributes, RegionRange regions,
1104 SmallVectorImpl<Type> &returnTypes) {
1105 returnTypes.clear();
1106 return OpWithRefineTypeInterfaceOp::refineReturnTypes(
1107 context, location, operands, attributes, regions, returnTypes);
1108}
1109
1110LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
1111 MLIRContext *, Optional<Location> location, ValueRange operands,
1112 DictionaryAttr attributes, RegionRange regions,
1113 SmallVectorImpl<Type> &returnTypes) {
1114 if (operands[0].getType() != operands[1].getType()) {
1115 return emitOptionalError(location, "operand type mismatch ",
1116 operands[0].getType(), " vs ",
1117 operands[1].getType());
1118 }
1119 // TODO: Add helper to make this more concise to write.
1120 if (returnTypes.empty())
1121 returnTypes.resize(1, nullptr);
1122 if (returnTypes[0] && returnTypes[0] != operands[0].getType())
1123 return emitOptionalError(location,
1124 "required first operand and result to match");
1125 returnTypes[0] = operands[0].getType();
1126 return success();
1127}
1128
1129LogicalResult OpWithShapedTypeInferTypeInterfaceOp::inferReturnTypeComponents(
1130 MLIRContext *context, Optional<Location> location, ValueShapeRange operands,
1131 DictionaryAttr attributes, RegionRange regions,
1132 SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
1133 // Create return type consisting of the last element of the first operand.
1134 auto operandType = operands.front().getType();
1135 auto sval = operandType.dyn_cast<ShapedType>();
1136 if (!sval) {
1137 return emitOptionalError(location, "only shaped type operands allowed");
1138 }
1139 int64_t dim =
1140 sval.hasRank() ? sval.getShape().front() : ShapedType::kDynamicSize;
1141 auto type = IntegerType::get(context, 17);
1142 inferredReturnShapes.push_back(ShapedTypeComponents({dim}, type));
1143 return success();
1144}
1145
1146LogicalResult OpWithShapedTypeInferTypeInterfaceOp::reifyReturnTypeShapes(
1147 OpBuilder &builder, ValueRange operands,
1148 llvm::SmallVectorImpl<Value> &shapes) {
1149 shapes = SmallVector<Value, 1>{
1150 builder.createOrFold<tensor::DimOp>(getLoc(), operands.front(), 0)};
1151 return success();
1152}
1153
1154LogicalResult OpWithResultShapeInterfaceOp::reifyReturnTypeShapes(
1155 OpBuilder &builder, ValueRange operands,
1156 llvm::SmallVectorImpl<Value> &shapes) {
1157 Location loc = getLoc();
1158 shapes.reserve(operands.size());
1159 for (Value operand : llvm::reverse(operands)) {
1160 auto rank = operand.getType().cast<RankedTensorType>().getRank();
1161 auto currShape = llvm::to_vector<4>(
1162 llvm::map_range(llvm::seq<int64_t>(0, rank), [&](int64_t dim) -> Value {
1163 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1164 }));
1165 shapes.push_back(builder.create<tensor::FromElementsOp>(
1166 getLoc(), RankedTensorType::get({rank}, builder.getIndexType()),
1167 currShape));
1168 }
1169 return success();
1170}
1171
1172LogicalResult OpWithResultShapePerDimInterfaceOp::reifyResultShapes(
1173 OpBuilder &builder, ReifiedRankedShapedTypeDims &shapes) {
1174 Location loc = getLoc();
1175 shapes.reserve(getNumOperands());
1176 for (Value operand : llvm::reverse(getOperands())) {
1177 auto currShape = llvm::to_vector<4>(llvm::map_range(
1178 llvm::seq<int64_t>(
1179 0, operand.getType().cast<RankedTensorType>().getRank()),
1180 [&](int64_t dim) -> Value {
1181 return builder.createOrFold<tensor::DimOp>(loc, operand, dim);
1182 }));
1183 shapes.emplace_back(std::move(currShape));
1184 }
1185 return success();
1186}
1187
1188//===----------------------------------------------------------------------===//
1189// Test SideEffect interfaces
1190//===----------------------------------------------------------------------===//
1191
1192namespace {
1193/// A test resource for side effects.
1194struct TestResource : public SideEffects::Resource::Base<TestResource> {
1195 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResource)static ::mlir::TypeID resolveTypeID() { static ::mlir::SelfOwningTypeID
id; return id; } static_assert( ::mlir::detail::InlineTypeIDResolver
::has_resolve_typeid< TestResource>::value, "`MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID` must be placed in a "
"public section of `" "TestResource" "`");
1196
1197 StringRef getName() final { return "<Test>"; }
1198};
1199} // namespace
1200
1201static void testSideEffectOpGetEffect(
1202 Operation *op,
1203 SmallVectorImpl<SideEffects::EffectInstance<TestEffects::Effect>>
1204 &effects) {
1205 auto effectsAttr = op->getAttrOfType<AffineMapAttr>("effect_parameter");
1206 if (!effectsAttr)
1207 return;
1208
1209 effects.emplace_back(TestEffects::Concrete::get(), effectsAttr);
1210}
1211
1212void SideEffectOp::getEffects(
1213 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1214 // Check for an effects attribute on the op instance.
1215 ArrayAttr effectsAttr = (*this)->getAttrOfType<ArrayAttr>("effects");
1216 if (!effectsAttr)
1217 return;
1218
1219 // If there is one, it is an array of dictionary attributes that hold
1220 // information on the effects of this operation.
1221 for (Attribute element : effectsAttr) {
1222 DictionaryAttr effectElement = element.cast<DictionaryAttr>();
1223
1224 // Get the specific memory effect.
1225 MemoryEffects::Effect *effect =
1226 StringSwitch<MemoryEffects::Effect *>(
1227 effectElement.get("effect").cast<StringAttr>().getValue())
1228 .Case("allocate", MemoryEffects::Allocate::get())
1229 .Case("free", MemoryEffects::Free::get())
1230 .Case("read", MemoryEffects::Read::get())
1231 .Case("write", MemoryEffects::Write::get());
1232
1233 // Check for a non-default resource to use.
1234 SideEffects::Resource *resource = SideEffects::DefaultResource::get();
1235 if (effectElement.get("test_resource"))
1236 resource = TestResource::get();
1237
1238 // Check for a result to affect.
1239 if (effectElement.get("on_result"))
1240 effects.emplace_back(effect, getResult(), resource);
1241 else if (Attribute ref = effectElement.get("on_reference"))
1242 effects.emplace_back(effect, ref.cast<SymbolRefAttr>(), resource);
1243 else
1244 effects.emplace_back(effect, resource);
1245 }
1246}
1247
1248void SideEffectOp::getEffects(
1249 SmallVectorImpl<TestEffects::EffectInstance> &effects) {
1250 testSideEffectOpGetEffect(getOperation(), effects);
1251}
1252
1253//===----------------------------------------------------------------------===//
1254// StringAttrPrettyNameOp
1255//===----------------------------------------------------------------------===//
1256
1257// This op has fancy handling of its SSA result name.
1258ParseResult StringAttrPrettyNameOp::parse(OpAsmParser &parser,
1259 OperationState &result) {
1260 // Add the result types.
1261 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i)
1262 result.addTypes(parser.getBuilder().getIntegerType(32));
1263
1264 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1265 return failure();
1266
1267 // If the attribute dictionary contains no 'names' attribute, infer it from
1268 // the SSA name (if specified).
1269 bool hadNames = llvm::any_of(result.attributes, [](NamedAttribute attr) {
1270 return attr.getName() == "names";
1271 });
1272
1273 // If there was no name specified, check to see if there was a useful name
1274 // specified in the asm file.
1275 if (hadNames || parser.getNumResults() == 0)
1276 return success();
1277
1278 SmallVector<StringRef, 4> names;
1279 auto *context = result.getContext();
1280
1281 for (size_t i = 0, e = parser.getNumResults(); i != e; ++i) {
1282 auto resultName = parser.getResultName(i);
1283 StringRef nameStr;
1284 if (!resultName.first.empty() && !isdigit(resultName.first[0]))
1285 nameStr = resultName.first;
1286
1287 names.push_back(nameStr);
1288 }
1289
1290 auto namesAttr = parser.getBuilder().getStrArrayAttr(names);
1291 result.attributes.push_back({StringAttr::get(context, "names"), namesAttr});
1292 return success();
1293}
1294
1295void StringAttrPrettyNameOp::print(OpAsmPrinter &p) {
1296 // Note that we only need to print the "name" attribute if the asmprinter
1297 // result name disagrees with it. This can happen in strange cases, e.g.
1298 // when there are conflicts.
1299 bool namesDisagree = getNames().size() != getNumResults();
1300
1301 SmallString<32> resultNameStr;
1302 for (size_t i = 0, e = getNumResults(); i != e && !namesDisagree; ++i) {
1303 resultNameStr.clear();
1304 llvm::raw_svector_ostream tmpStream(resultNameStr);
1305 p.printOperand(getResult(i), tmpStream);
1306
1307 auto expectedName = getNames()[i].dyn_cast<StringAttr>();
1308 if (!expectedName ||
1309 tmpStream.str().drop_front() != expectedName.getValue()) {
1310 namesDisagree = true;
1311 }
1312 }
1313
1314 if (namesDisagree)
1315 p.printOptionalAttrDictWithKeyword((*this)->getAttrs());
1316 else
1317 p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), {"names"});
1318}
1319
1320// We set the SSA name in the asm syntax to the contents of the name
1321// attribute.
1322void StringAttrPrettyNameOp::getAsmResultNames(
1323 function_ref<void(Value, StringRef)> setNameFn) {
1324
1325 auto value = getNames();
1326 for (size_t i = 0, e = value.size(); i != e; ++i)
1327 if (auto str = value[i].dyn_cast<StringAttr>())
1328 if (!str.getValue().empty())
1329 setNameFn(getResult(i), str.getValue());
1330}
1331
1332void CustomResultsNameOp::getAsmResultNames(
1333 function_ref<void(Value, StringRef)> setNameFn) {
1334 ArrayAttr value = getNames();
1335 for (size_t i = 0, e = value.size(); i != e; ++i)
1336 if (auto str = value[i].dyn_cast<StringAttr>())
1337 if (!str.getValue().empty())
1338 setNameFn(getResult(i), str.getValue());
1339}
1340
1341//===----------------------------------------------------------------------===//
1342// ResultTypeWithTraitOp
1343//===----------------------------------------------------------------------===//
1344
1345LogicalResult ResultTypeWithTraitOp::verify() {
1346 if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
1347 return success();
1348 return emitError("result type should have trait 'TestTypeTrait'");
1349}
1350
1351//===----------------------------------------------------------------------===//
1352// AttrWithTraitOp
1353//===----------------------------------------------------------------------===//
1354
1355LogicalResult AttrWithTraitOp::verify() {
1356 if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
1357 return success();
1358 return emitError("'attr' attribute should have trait 'TestAttrTrait'");
1359}
1360
1361//===----------------------------------------------------------------------===//
1362// RegionIfOp
1363//===----------------------------------------------------------------------===//
1364
1365void RegionIfOp::print(OpAsmPrinter &p) {
1366 p << " ";
1367 p.printOperands(getOperands());
1368 p << ": " << getOperandTypes();
1369 p.printArrowTypeList(getResultTypes());
1370 p << " then ";
1371 p.printRegion(getThenRegion(),
1372 /*printEntryBlockArgs=*/true,
1373 /*printBlockTerminators=*/true);
1374 p << " else ";
1375 p.printRegion(getElseRegion(),
1376 /*printEntryBlockArgs=*/true,
1377 /*printBlockTerminators=*/true);
1378 p << " join ";
1379 p.printRegion(getJoinRegion(),
1380 /*printEntryBlockArgs=*/true,
1381 /*printBlockTerminators=*/true);
1382}
1383
1384ParseResult RegionIfOp::parse(OpAsmParser &parser, OperationState &result) {
1385 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandInfos;
1386 SmallVector<Type, 2> operandTypes;
1387
1388 result.regions.reserve(3);
1389 Region *thenRegion = result.addRegion();
1390 Region *elseRegion = result.addRegion();
1391 Region *joinRegion = result.addRegion();
1392
1393 // Parse operand, type and arrow type lists.
1394 if (parser.parseOperandList(operandInfos) ||
1395 parser.parseColonTypeList(operandTypes) ||
1396 parser.parseArrowTypeList(result.types))
1397 return failure();
1398
1399 // Parse all attached regions.
1400 if (parser.parseKeyword("then") || parser.parseRegion(*thenRegion, {}, {}) ||
1401 parser.parseKeyword("else") || parser.parseRegion(*elseRegion, {}, {}) ||
1402 parser.parseKeyword("join") || parser.parseRegion(*joinRegion, {}, {}))
1403 return failure();
1404
1405 return parser.resolveOperands(operandInfos, operandTypes,
1406 parser.getCurrentLocation(), result.operands);
1407}
1408
1409OperandRange RegionIfOp::getSuccessorEntryOperands(Optional<unsigned> index) {
1410 assert(index && *index < 2 && "invalid region index")(static_cast <bool> (index && *index < 2 &&
"invalid region index") ? void (0) : __assert_fail ("index && *index < 2 && \"invalid region index\""
, "mlir/test/lib/Dialect/Test/TestDialect.cpp", 1410, __extension__
__PRETTY_FUNCTION__))
;
1411 return getOperands();
1412}
1413
1414void RegionIfOp::getSuccessorRegions(
1415 Optional<unsigned> index, ArrayRef<Attribute> operands,
1416 SmallVectorImpl<RegionSuccessor> &regions) {
1417 // We always branch to the join region.
1418 if (index.has_value()) {
1419 if (index.value() < 2)
1420 regions.push_back(RegionSuccessor(&getJoinRegion(), getJoinArgs()));
1421 else
1422 regions.push_back(RegionSuccessor(getResults()));
1423 return;
1424 }
1425
1426 // The then and else regions are the entry regions of this op.
1427 regions.push_back(RegionSuccessor(&getThenRegion(), getThenArgs()));
1428 regions.push_back(RegionSuccessor(&getElseRegion(), getElseArgs()));
1429}
1430
1431void RegionIfOp::getRegionInvocationBounds(
1432 ArrayRef<Attribute> operands,
1433 SmallVectorImpl<InvocationBounds> &invocationBounds) {
1434 // Each region is invoked at most once.
1435 invocationBounds.assign(/*NumElts=*/3, /*Elt=*/{0, 1});
1436}
1437
1438//===----------------------------------------------------------------------===//
1439// AnyCondOp
1440//===----------------------------------------------------------------------===//
1441
1442void AnyCondOp::getSuccessorRegions(Optional<unsigned> index,
1443 ArrayRef<Attribute> operands,
1444 SmallVectorImpl<RegionSuccessor> &regions) {
1445 // The parent op branches into the only region, and the region branches back
1446 // to the parent op.
1447 if (!index)
1448 regions.emplace_back(&getRegion());
1449 else
1450 regions.emplace_back(getResults());
1451}
1452
1453void AnyCondOp::getRegionInvocationBounds(
1454 ArrayRef<Attribute> operands,
1455 SmallVectorImpl<InvocationBounds> &invocationBounds) {
1456 invocationBounds.emplace_back(1, 1);
1457}
1458
1459//===----------------------------------------------------------------------===//
1460// SingleNoTerminatorCustomAsmOp
1461//===----------------------------------------------------------------------===//
1462
1463ParseResult SingleNoTerminatorCustomAsmOp::parse(OpAsmParser &parser,
1464 OperationState &state) {
1465 Region *body = state.addRegion();
1466 if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
1467 return failure();
1468 return success();
1469}
1470
1471void SingleNoTerminatorCustomAsmOp::print(OpAsmPrinter &printer) {
1472 printer.printRegion(
1473 getRegion(), /*printEntryBlockArgs=*/false,
1474 // This op has a single block without terminators. But explicitly mark
1475 // as not printing block terminators for testing.
1476 /*printBlockTerminators=*/false);
1477}
1478
1479//===----------------------------------------------------------------------===//
1480// TestVerifiersOp
1481//===----------------------------------------------------------------------===//
1482
1483LogicalResult TestVerifiersOp::verify() {
1484 if (!getRegion().hasOneBlock())
1485 return emitOpError("`hasOneBlock` trait hasn't been verified");
1486
1487 Operation *definingOp = getInput().getDefiningOp();
1488 if (definingOp && failed(mlir::verify(definingOp)))
1489 return emitOpError("operand hasn't been verified");
1490
1491 emitRemark("success run of verifier");
1492
1493 return success();
1494}
1495
1496LogicalResult TestVerifiersOp::verifyRegions() {
1497 if (!getRegion().hasOneBlock())
1498 return emitOpError("`hasOneBlock` trait hasn't been verified");
1499
1500 for (Block &block : getRegion())
1501 for (Operation &op : block)
1502 if (failed(mlir::verify(&op)))
1503 return emitOpError("nested op hasn't been verified");
1504
1505 emitRemark("success run of region verifier");
1506
1507 return success();
1508}
1509
1510//===----------------------------------------------------------------------===//
1511// Test InferIntRangeInterface
1512//===----------------------------------------------------------------------===//
1513
1514void TestWithBoundsOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1515 SetIntRangeFn setResultRanges) {
1516 setResultRanges(getResult(), {getUmin(), getUmax(), getSmin(), getSmax()});
1517}
1518
1519ParseResult TestWithBoundsRegionOp::parse(OpAsmParser &parser,
1520 OperationState &result) {
1521 if (parser.parseOptionalAttrDict(result.attributes))
1522 return failure();
1523
1524 // Parse the input argument
1525 OpAsmParser::Argument argInfo;
1526 argInfo.type = parser.getBuilder().getIndexType();
1527 if (failed(parser.parseArgument(argInfo)))
1528 return failure();
1529
1530 // Parse the body region, and reuse the operand info as the argument info.
1531 Region *body = result.addRegion();
1532 return parser.parseRegion(*body, argInfo, /*enableNameShadowing=*/false);
1533}
1534
1535void TestWithBoundsRegionOp::print(OpAsmPrinter &p) {
1536 p.printOptionalAttrDict((*this)->getAttrs());
1537 p << ' ';
1538 p.printRegionArgument(getRegion().getArgument(0), /*argAttrs=*/{},
1539 /*omitType=*/true);
1540 p << ' ';
1541 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
1542}
1543
1544void TestWithBoundsRegionOp::inferResultRanges(
1545 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1546 Value arg = getRegion().getArgument(0);
1547 setResultRanges(arg, {getUmin(), getUmax(), getSmin(), getSmax()});
1548}
1549
1550void TestIncrementOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
1551 SetIntRangeFn setResultRanges) {
1552 const ConstantIntRanges &range = argRanges[0];
1553 APInt one(range.umin().getBitWidth(), 1);
1554 setResultRanges(getResult(),
1555 {range.umin().uadd_sat(one), range.umax().uadd_sat(one),
1556 range.smin().sadd_sat(one), range.smax().sadd_sat(one)});
1557}
1558
1559void TestReflectBoundsOp::inferResultRanges(
1560 ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
1561 const ConstantIntRanges &range = argRanges[0];
1562 MLIRContext *ctx = getContext();
1563 Builder b(ctx);
1564 setUminAttr(b.getIndexAttr(range.umin().getZExtValue()));
1565 setUmaxAttr(b.getIndexAttr(range.umax().getZExtValue()));
1566 setSminAttr(b.getIndexAttr(range.smin().getSExtValue()));
1567 setSmaxAttr(b.getIndexAttr(range.smax().getSExtValue()));
1568 setResultRanges(getResult(), range);
1569}
1570
1571#include "TestOpEnums.cpp.inc"
1572#include "TestOpInterfaces.cpp.inc"
1573#include "TestTypeInterfaces.cpp.inc"
1574
1575#define GET_OP_CLASSES
1576#include "TestOps.cpp.inc"

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These classes are used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
218 AsmPrinter() {}
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a newline and indent the printer to the start of the current
331 /// operation.
332 virtual void printNewline() = 0;
333
334 /// Print a block argument in the usual format of:
335 /// %ssaName : type {attr1=42} loc("here")
336 /// where location printing is controlled by the standard internal option.
337 /// You may pass omitType=true to not print a type, and pass an empty
338 /// attribute list if you don't care for attributes.
339 virtual void printRegionArgument(BlockArgument arg,
340 ArrayRef<NamedAttribute> argAttrs = {},
341 bool omitType = false) = 0;
342
343 /// Print implementations for various things an operation contains.
344 virtual void printOperand(Value value) = 0;
345 virtual void printOperand(Value value, raw_ostream &os) = 0;
346
347 /// Print a comma separated list of operands.
348 template <typename ContainerType>
349 void printOperands(const ContainerType &container) {
350 printOperands(container.begin(), container.end());
351 }
352
353 /// Print a comma separated list of operands.
354 template <typename IteratorType>
355 void printOperands(IteratorType it, IteratorType end) {
356 if (it == end)
357 return;
358 printOperand(*it);
359 for (++it; it != end; ++it) {
360 getStream() << ", ";
361 printOperand(*it);
362 }
363 }
364
365 /// Print the given successor.
366 virtual void printSuccessor(Block *successor) = 0;
367
368 /// Print the successor and its operands.
369 virtual void printSuccessorAndUseList(Block *successor,
370 ValueRange succOperands) = 0;
371
372 /// If the specified operation has attributes, print out an attribute
373 /// dictionary with their values. elidedAttrs allows the client to ignore
374 /// specific well known attributes, commonly used if the attribute value is
375 /// printed some other way (like as a fixed operand).
376 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
377 ArrayRef<StringRef> elidedAttrs = {}) = 0;
378
379 /// If the specified operation has attributes, print out an attribute
380 /// dictionary prefixed with 'attributes'.
381 virtual void
382 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
383 ArrayRef<StringRef> elidedAttrs = {}) = 0;
384
385 /// Print the entire operation with the default generic assembly form.
386 /// If `printOpName` is true, then the operation name is printed (the default)
387 /// otherwise it is omitted and the print will start with the operand list.
388 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
389
390 /// Prints a region.
391 /// If 'printEntryBlockArgs' is false, the arguments of the
392 /// block are not printed. If 'printBlockTerminator' is false, the terminator
393 /// operation of the block is not printed. If printEmptyBlock is true, then
394 /// the block header is printed even if the block is empty.
395 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
396 bool printBlockTerminators = true,
397 bool printEmptyBlock = false) = 0;
398
399 /// Renumber the arguments for the specified region to the same names as the
400 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
401 /// operations. If any entry in namesToUse is null, the corresponding
402 /// argument name is left alone.
403 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
404
405 /// Prints an affine map of SSA ids, where SSA id names are used in place
406 /// of dims/symbols.
407 /// Operand values must come from single-result sources, and be valid
408 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
409 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
410 ValueRange operands) = 0;
411
412 /// Prints an affine expression of SSA ids with SSA id names used instead of
413 /// dims and symbols.
414 /// Operand values must come from single-result sources, and be valid
415 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
416 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
417 ValueRange symOperands) = 0;
418
419 /// Print the complete type of an operation in functional form.
420 void printFunctionalType(Operation *op);
421 using AsmPrinter::printFunctionalType;
422};
423
424// Make the implementations convenient to use.
425inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
426 p.printOperand(value);
427 return p;
428}
429
430template <typename T,
431 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
432 !std::is_convertible<T &, Value &>::value,
433 T> * = nullptr>
434inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
435 p.printOperands(values);
436 return p;
437}
438
439inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
440 p.printSuccessor(value);
441 return p;
442}
443
444//===----------------------------------------------------------------------===//
445// AsmParser
446//===----------------------------------------------------------------------===//
447
448/// This base class exposes generic asm parser hooks, usable across the various
449/// derived parsers.
450class AsmParser {
451public:
452 AsmParser() = default;
453 virtual ~AsmParser();
454
455 MLIRContext *getContext() const;
456
457 /// Return the location of the original name token.
458 virtual SMLoc getNameLoc() const = 0;
459
460 //===--------------------------------------------------------------------===//
461 // Utilities
462 //===--------------------------------------------------------------------===//
463
464 /// Emit a diagnostic at the specified location and return failure.
465 virtual InFlightDiagnostic emitError(SMLoc loc,
466 const Twine &message = {}) = 0;
467
468 /// Return a builder which provides useful access to MLIRContext, global
469 /// objects like types and attributes.
470 virtual Builder &getBuilder() const = 0;
471
472 /// Get the location of the next token and store it into the argument. This
473 /// always succeeds.
474 virtual SMLoc getCurrentLocation() = 0;
475 ParseResult getCurrentLocation(SMLoc *loc) {
476 *loc = getCurrentLocation();
477 return success();
478 }
479
480 /// Re-encode the given source location as an MLIR location and return it.
481 /// Note: This method should only be used when a `Location` is necessary, as
482 /// the encoding process is not efficient.
483 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
484
485 //===--------------------------------------------------------------------===//
486 // Token Parsing
487 //===--------------------------------------------------------------------===//
488
489 /// Parse a '->' token.
490 virtual ParseResult parseArrow() = 0;
491
492 /// Parse a '->' token if present
493 virtual ParseResult parseOptionalArrow() = 0;
494
495 /// Parse a `{` token.
496 virtual ParseResult parseLBrace() = 0;
497
498 /// Parse a `{` token if present.
499 virtual ParseResult parseOptionalLBrace() = 0;
500
501 /// Parse a `}` token.
502 virtual ParseResult parseRBrace() = 0;
503
504 /// Parse a `}` token if present.
505 virtual ParseResult parseOptionalRBrace() = 0;
506
507 /// Parse a `:` token.
508 virtual ParseResult parseColon() = 0;
509
510 /// Parse a `:` token if present.
511 virtual ParseResult parseOptionalColon() = 0;
512
513 /// Parse a `,` token.
514 virtual ParseResult parseComma() = 0;
515
516 /// Parse a `,` token if present.
517 virtual ParseResult parseOptionalComma() = 0;
518
519 /// Parse a `=` token.
520 virtual ParseResult parseEqual() = 0;
521
522 /// Parse a `=` token if present.
523 virtual ParseResult parseOptionalEqual() = 0;
524
525 /// Parse a '<' token.
526 virtual ParseResult parseLess() = 0;
527
528 /// Parse a '<' token if present.
529 virtual ParseResult parseOptionalLess() = 0;
530
531 /// Parse a '>' token.
532 virtual ParseResult parseGreater() = 0;
533
534 /// Parse a '>' token if present.
535 virtual ParseResult parseOptionalGreater() = 0;
536
537 /// Parse a '?' token.
538 virtual ParseResult parseQuestion() = 0;
539
540 /// Parse a '?' token if present.
541 virtual ParseResult parseOptionalQuestion() = 0;
542
543 /// Parse a '+' token.
544 virtual ParseResult parsePlus() = 0;
545
546 /// Parse a '+' token if present.
547 virtual ParseResult parseOptionalPlus() = 0;
548
549 /// Parse a '*' token.
550 virtual ParseResult parseStar() = 0;
551
552 /// Parse a '*' token if present.
553 virtual ParseResult parseOptionalStar() = 0;
554
555 /// Parse a '|' token.
556 virtual ParseResult parseVerticalBar() = 0;
557
558 /// Parse a '|' token if present.
559 virtual ParseResult parseOptionalVerticalBar() = 0;
560
561 /// Parse a quoted string token.
562 ParseResult parseString(std::string *string) {
563 auto loc = getCurrentLocation();
564 if (parseOptionalString(string))
565 return emitError(loc, "expected string");
566 return success();
567 }
568
569 /// Parse a quoted string token if present.
570 virtual ParseResult parseOptionalString(std::string *string) = 0;
571
572 /// Parse a `(` token.
573 virtual ParseResult parseLParen() = 0;
574
575 /// Parse a `(` token if present.
576 virtual ParseResult parseOptionalLParen() = 0;
577
578 /// Parse a `)` token.
579 virtual ParseResult parseRParen() = 0;
580
581 /// Parse a `)` token if present.
582 virtual ParseResult parseOptionalRParen() = 0;
583
584 /// Parse a `[` token.
585 virtual ParseResult parseLSquare() = 0;
586
587 /// Parse a `[` token if present.
588 virtual ParseResult parseOptionalLSquare() = 0;
589
590 /// Parse a `]` token.
591 virtual ParseResult parseRSquare() = 0;
592
593 /// Parse a `]` token if present.
594 virtual ParseResult parseOptionalRSquare() = 0;
595
596 /// Parse a `...` token if present;
597 virtual ParseResult parseOptionalEllipsis() = 0;
598
599 /// Parse a floating point value from the stream.
600 virtual ParseResult parseFloat(double &result) = 0;
601
602 /// Parse an integer value from the stream.
603 template <typename IntT>
604 ParseResult parseInteger(IntT &result) {
605 auto loc = getCurrentLocation();
606 OptionalParseResult parseResult = parseOptionalInteger(result);
4
Calling 'AsmParser::parseOptionalInteger'
8
Returning from 'AsmParser::parseOptionalInteger'
607 if (!parseResult.has_value())
9
Taking true branch
608 return emitError(loc, "expected integer value");
10
Returning without writing to 'result'
609 return *parseResult;
610 }
611
612 /// Parse an optional integer value from the stream.
613 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
614
615 template <typename IntT>
616 OptionalParseResult parseOptionalInteger(IntT &result) {
617 auto loc = getCurrentLocation();
618
619 // Parse the unsigned variant.
620 APInt uintResult;
621 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
622 if (!parseResult.has_value() || failed(*parseResult))
5
Assuming the condition is true
6
Taking true branch
623 return parseResult;
7
Returning without writing to 'result'
624
625 // Try to convert to the provided integer type. sextOrTrunc is correct even
626 // for unsigned types because parseOptionalInteger ensures the sign bit is
627 // zero for non-negated integers.
628 result =
629 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
630 if (APInt(uintResult.getBitWidth(), result) != uintResult)
631 return emitError(loc, "integer value too large");
632 return success();
633 }
634
635 /// These are the supported delimiters around operand lists and region
636 /// argument lists, used by parseOperandList.
637 enum class Delimiter {
638 /// Zero or more operands with no delimiters.
639 None,
640 /// Parens surrounding zero or more operands.
641 Paren,
642 /// Square brackets surrounding zero or more operands.
643 Square,
644 /// <> brackets surrounding zero or more operands.
645 LessGreater,
646 /// {} brackets surrounding zero or more operands.
647 Braces,
648 /// Parens supporting zero or more operands, or nothing.
649 OptionalParen,
650 /// Square brackets supporting zero or more ops, or nothing.
651 OptionalSquare,
652 /// <> brackets supporting zero or more ops, or nothing.
653 OptionalLessGreater,
654 /// {} brackets surrounding zero or more operands, or nothing.
655 OptionalBraces,
656 };
657
658 /// Parse a list of comma-separated items with an optional delimiter. If a
659 /// delimiter is provided, then an empty list is allowed. If not, then at
660 /// least one element will be parsed.
661 ///
662 /// contextMessage is an optional message appended to "expected '('" sorts of
663 /// diagnostics when parsing the delimeters.
664 virtual ParseResult
665 parseCommaSeparatedList(Delimiter delimiter,
666 function_ref<ParseResult()> parseElementFn,
667 StringRef contextMessage = StringRef()) = 0;
668
669 /// Parse a comma separated list of elements that must have at least one entry
670 /// in it.
671 ParseResult
672 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
673 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
674 }
675
676 //===--------------------------------------------------------------------===//
677 // Keyword Parsing
678 //===--------------------------------------------------------------------===//
679
680 /// This class represents a StringSwitch like class that is useful for parsing
681 /// expected keywords. On construction, it invokes `parseKeyword` and
682 /// processes each of the provided cases statements until a match is hit. The
683 /// provided `ResultT` must be assignable from `failure()`.
684 template <typename ResultT = ParseResult>
685 class KeywordSwitch {
686 public:
687 KeywordSwitch(AsmParser &parser)
688 : parser(parser), loc(parser.getCurrentLocation()) {
689 if (failed(parser.parseKeywordOrCompletion(&keyword)))
690 result = failure();
691 }
692
693 /// Case that uses the provided value when true.
694 KeywordSwitch &Case(StringLiteral str, ResultT value) {
695 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
696 }
697 KeywordSwitch &Default(ResultT value) {
698 return Default([&](StringRef, SMLoc) { return std::move(value); });
699 }
700 /// Case that invokes the provided functor when true. The parameters passed
701 /// to the functor are the keyword, and the location of the keyword (in case
702 /// any errors need to be emitted).
703 template <typename FnT>
704 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
705 Case(StringLiteral str, FnT &&fn) {
706 if (result)
707 return *this;
708
709 // If the word was empty, record this as a completion.
710 if (keyword.empty())
711 parser.codeCompleteExpectedTokens(str);
712 else if (keyword == str)
713 result.emplace(std::move(fn(keyword, loc)));
714 return *this;
715 }
716 template <typename FnT>
717 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
718 Default(FnT &&fn) {
719 if (!result)
720 result.emplace(fn(keyword, loc));
721 return *this;
722 }
723
724 /// Returns true if this switch has a value yet.
725 bool hasValue() const { return result.has_value(); }
726
727 /// Return the result of the switch.
728 [[nodiscard]] operator ResultT() {
729 if (!result)
730 return parser.emitError(loc, "unexpected keyword: ") << keyword;
731 return std::move(*result);
732 }
733
734 private:
735 /// The parser used to construct this switch.
736 AsmParser &parser;
737
738 /// The location of the keyword, used to emit errors as necessary.
739 SMLoc loc;
740
741 /// The parsed keyword itself.
742 StringRef keyword;
743
744 /// The result of the switch statement or none if currently unknown.
745 Optional<ResultT> result;
746 };
747
748 /// Parse a given keyword.
749 ParseResult parseKeyword(StringRef keyword) {
750 return parseKeyword(keyword, "");
751 }
752 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
753
754 /// Parse a keyword into 'keyword'.
755 ParseResult parseKeyword(StringRef *keyword) {
756 auto loc = getCurrentLocation();
757 if (parseOptionalKeyword(keyword))
758 return emitError(loc, "expected valid keyword");
759 return success();
760 }
761
762 /// Parse the given keyword if present.
763 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
764
765 /// Parse a keyword, if present, into 'keyword'.
766 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
767
768 /// Parse a keyword, if present, and if one of the 'allowedValues',
769 /// into 'keyword'
770 virtual ParseResult
771 parseOptionalKeyword(StringRef *keyword,
772 ArrayRef<StringRef> allowedValues) = 0;
773
774 /// Parse a keyword or a quoted string.
775 ParseResult parseKeywordOrString(std::string *result) {
776 if (failed(parseOptionalKeywordOrString(result)))
777 return emitError(getCurrentLocation())
778 << "expected valid keyword or string";
779 return success();
780 }
781
782 /// Parse an optional keyword or string.
783 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
784
785 //===--------------------------------------------------------------------===//
786 // Attribute/Type Parsing
787 //===--------------------------------------------------------------------===//
788
789 /// Invoke the `getChecked` method of the given Attribute or Type class, using
790 /// the provided location to emit errors in the case of failure. Note that
791 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
792 /// context parameter.
793 template <typename T, typename... ParamsT>
794 auto getChecked(SMLoc loc, ParamsT &&...params) {
795 return T::getChecked([&] { return emitError(loc); },
796 std::forward<ParamsT>(params)...);
797 }
798 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
799 /// errors.
800 template <typename T, typename... ParamsT>
801 auto getChecked(ParamsT &&...params) {
802 return T::getChecked([&] { return emitError(getNameLoc()); },
803 std::forward<ParamsT>(params)...);
804 }
805
806 //===--------------------------------------------------------------------===//
807 // Attribute Parsing
808 //===--------------------------------------------------------------------===//
809
810 /// Parse an arbitrary attribute of a given type and return it in result.
811 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
812
813 /// Parse a custom attribute with the provided callback, unless the next
814 /// token is `#`, in which case the generic parser is invoked.
815 virtual ParseResult parseCustomAttributeWithFallback(
816 Attribute &result, Type type,
817 function_ref<ParseResult(Attribute &result, Type type)>
818 parseAttribute) = 0;
819
820 /// Parse an attribute of a specific kind and type.
821 template <typename AttrType>
822 ParseResult parseAttribute(AttrType &result, Type type = {}) {
823 SMLoc loc = getCurrentLocation();
824
825 // Parse any kind of attribute.
826 Attribute attr;
827 if (parseAttribute(attr, type))
828 return failure();
829
830 // Check for the right kind of attribute.
831 if (!(result = attr.dyn_cast<AttrType>()))
832 return emitError(loc, "invalid kind of attribute specified");
833
834 return success();
835 }
836
837 /// Parse an arbitrary attribute and return it in result. This also adds the
838 /// attribute to the specified attribute list with the specified name.
839 ParseResult parseAttribute(Attribute &result, StringRef attrName,
840 NamedAttrList &attrs) {
841 return parseAttribute(result, Type(), attrName, attrs);
842 }
843
844 /// Parse an attribute of a specific kind and type.
845 template <typename AttrType>
846 ParseResult parseAttribute(AttrType &result, StringRef attrName,
847 NamedAttrList &attrs) {
848 return parseAttribute(result, Type(), attrName, attrs);
849 }
850
851 /// Parse an arbitrary attribute of a given type and populate it in `result`.
852 /// This also adds the attribute to the specified attribute list with the
853 /// specified name.
854 template <typename AttrType>
855 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
856 NamedAttrList &attrs) {
857 SMLoc loc = getCurrentLocation();
858
859 // Parse any kind of attribute.
860 Attribute attr;
861 if (parseAttribute(attr, type))
862 return failure();
863
864 // Check for the right kind of attribute.
865 result = attr.dyn_cast<AttrType>();
866 if (!result)
867 return emitError(loc, "invalid kind of attribute specified");
868
869 attrs.append(attrName, result);
870 return success();
871 }
872
873 /// Trait to check if `AttrType` provides a `parse` method.
874 template <typename AttrType>
875 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
876 std::declval<Type>()));
877 template <typename AttrType>
878 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
879
880 /// Parse a custom attribute of a given type unless the next token is `#`, in
881 /// which case the generic parser is invoked. The parsed attribute is
882 /// populated in `result` and also added to the specified attribute list with
883 /// the specified name.
884 template <typename AttrType>
885 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
886 parseCustomAttributeWithFallback(AttrType &result, Type type,
887 StringRef attrName, NamedAttrList &attrs) {
888 SMLoc loc = getCurrentLocation();
889
890 // Parse any kind of attribute.
891 Attribute attr;
892 if (parseCustomAttributeWithFallback(
893 attr, type, [&](Attribute &result, Type type) -> ParseResult {
894 result = AttrType::parse(*this, type);
895 if (!result)
896 return failure();
897 return success();
898 }))
899 return failure();
900
901 // Check for the right kind of attribute.
902 result = attr.dyn_cast<AttrType>();
903 if (!result)
904 return emitError(loc, "invalid kind of attribute specified");
905
906 attrs.append(attrName, result);
907 return success();
908 }
909
910 /// SFINAE parsing method for Attribute that don't implement a parse method.
911 template <typename AttrType>
912 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
913 parseCustomAttributeWithFallback(AttrType &result, Type type,
914 StringRef attrName, NamedAttrList &attrs) {
915 return parseAttribute(result, type, attrName, attrs);
916 }
917
918 /// Parse a custom attribute of a given type unless the next token is `#`, in
919 /// which case the generic parser is invoked. The parsed attribute is
920 /// populated in `result`.
921 template <typename AttrType>
922 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
923 parseCustomAttributeWithFallback(AttrType &result) {
924 SMLoc loc = getCurrentLocation();
925
926 // Parse any kind of attribute.
927 Attribute attr;
928 if (parseCustomAttributeWithFallback(
929 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
930 result = AttrType::parse(*this, type);
931 return success(!!result);
932 }))
933 return failure();
934
935 // Check for the right kind of attribute.
936 result = attr.dyn_cast<AttrType>();
937 if (!result)
938 return emitError(loc, "invalid kind of attribute specified");
939 return success();
940 }
941
942 /// SFINAE parsing method for Attribute that don't implement a parse method.
943 template <typename AttrType>
944 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
945 parseCustomAttributeWithFallback(AttrType &result) {
946 return parseAttribute(result);
947 }
948
949 /// Parse an arbitrary optional attribute of a given type and return it in
950 /// result.
951 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
952 Type type = {}) = 0;
953
954 /// Parse an optional array attribute and return it in result.
955 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
956 Type type = {}) = 0;
957
958 /// Parse an optional string attribute and return it in result.
959 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
960 Type type = {}) = 0;
961
962 /// Parse an optional attribute of a specific type and add it to the list with
963 /// the specified name.
964 template <typename AttrType>
965 OptionalParseResult parseOptionalAttribute(AttrType &result,
966 StringRef attrName,
967 NamedAttrList &attrs) {
968 return parseOptionalAttribute(result, Type(), attrName, attrs);
969 }
970
971 /// Parse an optional attribute of a specific type and add it to the list with
972 /// the specified name.
973 template <typename AttrType>
974 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
975 StringRef attrName,
976 NamedAttrList &attrs) {
977 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
978 if (parseResult.has_value() && succeeded(*parseResult))
979 attrs.append(attrName, result);
980 return parseResult;
981 }
982
983 /// Parse a named dictionary into 'result' if it is present.
984 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
985
986 /// Parse a named dictionary into 'result' if the `attributes` keyword is
987 /// present.
988 virtual ParseResult
989 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
990
991 /// Parse an affine map instance into 'map'.
992 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
993
994 /// Parse an integer set instance into 'set'.
995 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
996
997 //===--------------------------------------------------------------------===//
998 // Identifier Parsing
999 //===--------------------------------------------------------------------===//
1000
1001 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1002 /// attribute named 'attrName'.
1003 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1004 NamedAttrList &attrs) {
1005 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1006 return emitError(getCurrentLocation())
1007 << "expected valid '@'-identifier for symbol name";
1008 return success();
1009 }
1010
1011 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1012 /// string attribute named 'attrName'.
1013 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1014 StringRef attrName,
1015 NamedAttrList &attrs) = 0;
1016
1017 //===--------------------------------------------------------------------===//
1018 // Resource Parsing
1019 //===--------------------------------------------------------------------===//
1020
1021 /// Parse a handle to a resource within the assembly format.
1022 template <typename ResourceT>
1023 FailureOr<ResourceT> parseResourceHandle() {
1024 SMLoc handleLoc = getCurrentLocation();
1025
1026 // Try to load the dialect that owns the handle.
1027 auto *dialect =
1028 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1029 if (!dialect) {
1030 return emitError(handleLoc)
1031 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1032 << "' is unknown";
1033 }
1034
1035 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1036 if (failed(handle))
1037 return failure();
1038 if (auto *result = dyn_cast<ResourceT>(&*handle))
1039 return std::move(*result);
1040 return emitError(handleLoc) << "provided resource handle differs from the "
1041 "expected resource type";
1042 }
1043
1044 //===--------------------------------------------------------------------===//
1045 // Type Parsing
1046 //===--------------------------------------------------------------------===//
1047
1048 /// Parse a type.
1049 virtual ParseResult parseType(Type &result) = 0;
1050
1051 /// Parse a custom type with the provided callback, unless the next
1052 /// token is `#`, in which case the generic parser is invoked.
1053 virtual ParseResult parseCustomTypeWithFallback(
1054 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1055
1056 /// Parse an optional type.
1057 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1058
1059 /// Parse a type of a specific type.
1060 template <typename TypeT>
1061 ParseResult parseType(TypeT &result) {
1062 SMLoc loc = getCurrentLocation();
1063
1064 // Parse any kind of type.
1065 Type type;
1066 if (parseType(type))
1067 return failure();
1068
1069 // Check for the right kind of type.
1070 result = type.dyn_cast<TypeT>();
1071 if (!result)
1072 return emitError(loc, "invalid kind of type specified");
1073
1074 return success();
1075 }
1076
1077 /// Trait to check if `TypeT` provides a `parse` method.
1078 template <typename TypeT>
1079 using type_has_parse_method =
1080 decltype(TypeT::parse(std::declval<AsmParser &>()));
1081 template <typename TypeT>
1082 using detect_type_has_parse_method =
1083 llvm::is_detected<type_has_parse_method, TypeT>;
1084
1085 /// Parse a custom Type of a given type unless the next token is `#`, in
1086 /// which case the generic parser is invoked. The parsed Type is
1087 /// populated in `result`.
1088 template <typename TypeT>
1089 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1090 parseCustomTypeWithFallback(TypeT &result) {
1091 SMLoc loc = getCurrentLocation();
1092
1093 // Parse any kind of Type.
1094 Type type;
1095 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1096 result = TypeT::parse(*this);
1097 return success(!!result);
1098 }))
1099 return failure();
1100
1101 // Check for the right kind of Type.
1102 result = type.dyn_cast<TypeT>();
1103 if (!result)
1104 return emitError(loc, "invalid kind of Type specified");
1105 return success();
1106 }
1107
1108 /// SFINAE parsing method for Type that don't implement a parse method.
1109 template <typename TypeT>
1110 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1111 parseCustomTypeWithFallback(TypeT &result) {
1112 return parseType(result);
1113 }
1114
1115 /// Parse a type list.
1116 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1117 return parseCommaSeparatedList(
1118 [&]() { return parseType(result.emplace_back()); });
1119 }
1120
1121 /// Parse an arrow followed by a type list.
1122 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1123
1124 /// Parse an optional arrow followed by a type list.
1125 virtual ParseResult
1126 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1127
1128 /// Parse a colon followed by a type.
1129 virtual ParseResult parseColonType(Type &result) = 0;
1130
1131 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1132 template <typename TypeType>
1133 ParseResult parseColonType(TypeType &result) {
1134 SMLoc loc = getCurrentLocation();
1135
1136 // Parse any kind of type.
1137 Type type;
1138 if (parseColonType(type))
1139 return failure();
1140
1141 // Check for the right kind of type.
1142 result = type.dyn_cast<TypeType>();
1143 if (!result)
1144 return emitError(loc, "invalid kind of type specified");
1145
1146 return success();
1147 }
1148
1149 /// Parse a colon followed by a type list, which must have at least one type.
1150 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1151
1152 /// Parse an optional colon followed by a type list, which if present must
1153 /// have at least one type.
1154 virtual ParseResult
1155 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1156
1157 /// Parse a keyword followed by a type.
1158 ParseResult parseKeywordType(const char *keyword, Type &result) {
1159 return failure(parseKeyword(keyword) || parseType(result));
1160 }
1161
1162 /// Add the specified type to the end of the specified type list and return
1163 /// success. This is a helper designed to allow parse methods to be simple
1164 /// and chain through || operators.
1165 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1166 result.push_back(type);
1167 return success();
1168 }
1169
1170 /// Add the specified types to the end of the specified type list and return
1171 /// success. This is a helper designed to allow parse methods to be simple
1172 /// and chain through || operators.
1173 ParseResult addTypesToList(ArrayRef<Type> types,
1174 SmallVectorImpl<Type> &result) {
1175 result.append(types.begin(), types.end());
1176 return success();
1177 }
1178
  /// Parse a dimension list of a tensor or memref type. This populates the
  /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
  /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
  ///
  ///   dimension-list ::= eps | dimension (`x` dimension)*
  ///   dimension-list-with-trailing-x ::= (dimension `x`)*
  ///   dimension ::= `?` | decimal-literal
  ///
  /// When `allowDynamic` is not set, this is used to parse:
  ///
  ///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
  ///   static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
  ///
  /// \param dimensions     Populated with the parsed dimension sizes.
  /// \param allowDynamic   Whether `?` is accepted (recorded as -1).
  /// \param withTrailingX  Whether the `-with-trailing-x` grammar forms above
  ///                       are parsed (each dimension followed by `x`).
  virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                         bool allowDynamic = true,
                                         bool withTrailingX = true) = 0;

  /// Parse an 'x' token in a dimension list, handling the case where the x is
  /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
  /// next token.
  virtual ParseResult parseXInDimensionList() = 0;
1199
protected:
  /// Parse a handle to a resource within the assembly format for the given
  /// dialect. This is the non-template hook used by the public
  /// `parseResourceHandle<ResourceT>()`, which loads the owning dialect first
  /// and checks the resulting handle's kind afterwards.
  virtual FailureOr<AsmDialectResourceHandle>
  parseResourceHandle(Dialect *dialect) = 0;

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// Parse a keyword into `keyword`, or produce an empty string if the current
  /// location signals a code completion.
  virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;

  /// Signal the code completion of a set of expected tokens.
  virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1216
private:
  // AsmParser is a polymorphic interface (it declares pure virtual methods)
  // and is intentionally non-copyable: copy construction and copy assignment
  // are deleted.
  AsmParser(const AsmParser &) = delete;
  void operator=(const AsmParser &) = delete;
1220};
1221
1222//===----------------------------------------------------------------------===//
1223// OpAsmParser
1224//===----------------------------------------------------------------------===//
1225
1226/// The OpAsmParser has methods for interacting with the asm parser: parsing
1227/// things from it, emitting errors etc. It has an intentionally high-level API
1228/// that is designed to reduce/constrain syntax innovation in individual
1229/// operations.
1230///
1231/// For example, consider an op like this:
1232///
1233/// %x = load %p[%1, %2] : memref<...>
1234///
1235/// The "%x = load" tokens are already parsed and therefore invisible to the
1236/// custom op parser. This can be supported by calling `parseOperandList` to
1237/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1238/// parse the indices, then calling `parseColonTypeList` to parse the result
1239/// type.
1240///
1241class OpAsmParser : public AsmParser {
1242public:
1243 using AsmParser::AsmParser;
1244 ~OpAsmParser() override;
1245
1246 /// Parse a loc(...) specifier if present, filling in result if so.
1247 /// Location for BlockArgument and Operation may be deferred with an alias, in
1248 /// which case an OpaqueLoc is set and will be resolved when parsing
1249 /// completes.
1250 virtual ParseResult
1251 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1252
1253 /// Return the name of the specified result in the specified syntax, as well
1254 /// as the sub-element in the name. It returns an empty string and ~0U for
1255 /// invalid result numbers. For example, in this operation:
1256 ///
1257 /// %x, %y:2, %z = foo.op
1258 ///
1259 /// getResultName(0) == {"x", 0 }
1260 /// getResultName(1) == {"y", 0 }
1261 /// getResultName(2) == {"y", 1 }
1262 /// getResultName(3) == {"z", 0 }
1263 /// getResultName(4) == {"", ~0U }
1264 virtual std::pair<StringRef, unsigned>
1265 getResultName(unsigned resultNo) const = 0;
1266
1267 /// Return the number of declared SSA results. This returns 4 for the foo.op
1268 /// example in the comment for `getResultName`.
1269 virtual size_t getNumResults() const = 0;
1270
1271 // These methods emit an error and return failure or success. This allows
1272 // these to be chained together into a linear sequence of || expressions in
1273 // many cases.
1274
1275 /// Parse an operation in its generic form.
1276 /// The parsed operation is parsed in the current context and inserted in the
1277 /// provided block and insertion point. The results produced by this operation
1278 /// aren't mapped to any named value in the parser. Returns nullptr on
1279 /// failure.
1280 virtual Operation *parseGenericOperation(Block *insertBlock,
1281 Block::iterator insertPt) = 0;
1282
1283 /// Parse the name of an operation, in the custom form. On success, return a
1284 /// an object of type 'OperationName'. Otherwise, failure is returned.
1285 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1286
1287 //===--------------------------------------------------------------------===//
1288 // Operand Parsing
1289 //===--------------------------------------------------------------------===//
1290
1291 /// This is the representation of an operand reference.
1292 struct UnresolvedOperand {
1293 SMLoc location; // Location of the token.
1294 StringRef name; // Value name, e.g. %42 or %abc
1295 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1296 };
1297
1298 /// Parse different components, viz., use-info of operand(s), successor(s),
1299 /// region(s), attribute(s) and function-type, of the generic form of an
1300 /// operation instance and populate the input operation-state 'result' with
1301 /// those components. If any of the components is explicitly provided, then
1302 /// skip parsing that component.
1303 virtual ParseResult parseGenericOperationAfterOpName(
1304 OperationState &result,
1305 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1306 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1307 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1308 llvm::None,
1309 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1310 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1311
1312 /// Parse a single SSA value operand name along with a result number if
1313 /// `allowResultNumber` is true.
1314 virtual ParseResult parseOperand(UnresolvedOperand &result,
1315 bool allowResultNumber = true) = 0;
1316
1317 /// Parse a single operand if present.
1318 virtual OptionalParseResult
1319 parseOptionalOperand(UnresolvedOperand &result,
1320 bool allowResultNumber = true) = 0;
1321
1322 /// Parse zero or more SSA comma-separated operand references with a specified
1323 /// surrounding delimiter, and an optional required operand count.
1324 virtual ParseResult
1325 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1326 Delimiter delimiter = Delimiter::None,
1327 bool allowResultNumber = true,
1328 int requiredOperandCount = -1) = 0;
1329
1330 /// Parse a specified number of comma separated operands.
1331 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1332 int requiredOperandCount,
1333 Delimiter delimiter = Delimiter::None) {
1334 return parseOperandList(result, delimiter,
1335 /*allowResultNumber=*/true, requiredOperandCount);
1336 }
1337
1338 /// Parse zero or more trailing SSA comma-separated trailing operand
1339 /// references with a specified surrounding delimiter, and an optional
1340 /// required operand count. A leading comma is expected before the
1341 /// operands.
1342 ParseResult
1343 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1344 Delimiter delimiter = Delimiter::None) {
1345 if (failed(parseOptionalComma()))
1346 return success(); // The comma is optional.
1347 return parseOperandList(result, delimiter);
1348 }
1349
1350 /// Resolve an operand to an SSA value, emitting an error on failure.
1351 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1352 Type type,
1353 SmallVectorImpl<Value> &result) = 0;
1354
1355 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1356 /// appending the results to the list on success. This method should be used
1357 /// when all operands have the same type.
1358 template <typename Operands = ArrayRef<UnresolvedOperand>>
1359 ParseResult resolveOperands(Operands &&operands, Type type,
1360 SmallVectorImpl<Value> &result) {
1361 for (const UnresolvedOperand &operand : operands)
1362 if (resolveOperand(operand, type, result))
1363 return failure();
1364 return success();
1365 }
1366 template <typename Operands = ArrayRef<UnresolvedOperand>>
1367 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1368 SmallVectorImpl<Value> &result) {
1369 return resolveOperands(std::forward<Operands>(operands), type, result);
1370 }
1371
1372 /// Resolve a list of operands and a list of operand types to SSA values,
1373 /// emitting an error and returning failure, or appending the results
1374 /// to the list on success.
1375 template <typename Operands = ArrayRef<UnresolvedOperand>,
1376 typename Types = ArrayRef<Type>>
1377 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1378 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1379 SmallVectorImpl<Value> &result) {
1380 size_t operandSize = std::distance(operands.begin(), operands.end());
1381 size_t typeSize = std::distance(types.begin(), types.end());
1382 if (operandSize != typeSize)
1383 return emitError(loc)
1384 << operandSize << " operands present, but expected " << typeSize;
1385
1386 for (auto [operand, type] : llvm::zip(operands, types))
1387 if (resolveOperand(operand, type, result))
1388 return failure();
1389 return success();
1390 }
1391
1392 /// Parses an affine map attribute where dims and symbols are SSA operands.
1393 /// Operand values must come from single-result sources, and be valid
1394 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1395 virtual ParseResult
1396 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1397 Attribute &map, StringRef attrName,
1398 NamedAttrList &attrs,
1399 Delimiter delimiter = Delimiter::Square) = 0;
1400
1401 /// Parses an affine expression where dims and symbols are SSA operands.
1402 /// Operand values must come from single-result sources, and be valid
1403 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1404 virtual ParseResult
1405 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1406 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1407 AffineExpr &expr) = 0;
1408
1409 //===--------------------------------------------------------------------===//
1410 // Argument Parsing
1411 //===--------------------------------------------------------------------===//
1412
1413 struct Argument {
1414 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1415 Type type; // Type.
1416 DictionaryAttr attrs; // Attributes if present.
1417 Optional<Location> sourceLoc; // Source location specifier if present.
1418 };
1419
1420 /// Parse a single argument with the following syntax:
1421 ///
1422 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1423 ///
1424 /// If `allowType` is false or `allowAttrs` are false then the respective
1425 /// parts of the grammar are not parsed.
1426 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1427 bool allowAttrs = false) = 0;
1428
1429 /// Parse a single argument if present.
1430 virtual OptionalParseResult
1431 parseOptionalArgument(Argument &result, bool allowType = false,
1432 bool allowAttrs = false) = 0;
1433
1434 /// Parse zero or more arguments with a specified surrounding delimiter.
1435 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1436 Delimiter delimiter = Delimiter::None,
1437 bool allowType = false,
1438 bool allowAttrs = false) = 0;
1439
1440 //===--------------------------------------------------------------------===//
1441 // Region Parsing
1442 //===--------------------------------------------------------------------===//
1443
1444 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1445 /// moved to the op regions after the op is created. The first block of the
1446 /// region takes 'arguments'.
1447 ///
1448 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1449 /// shadow the names of other existing SSA values defined above the region
1450 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1451 /// to operations that are 'IsolatedFromAbove'.
1452 virtual ParseResult parseRegion(Region &region,
1453 ArrayRef<Argument> arguments = {},
1454 bool enableNameShadowing = false) = 0;
1455
1456 /// Parses a region if present.
1457 virtual OptionalParseResult
1458 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1459 bool enableNameShadowing = false) = 0;
1460
1461 /// Parses a region if present. If the region is present, a new region is
1462 /// allocated and placed in `region`. If no region is present or on failure,
1463 /// `region` remains untouched.
1464 virtual OptionalParseResult
1465 parseOptionalRegion(std::unique_ptr<Region> &region,
1466 ArrayRef<Argument> arguments = {},
1467 bool enableNameShadowing = false) = 0;
1468
1469 //===--------------------------------------------------------------------===//
1470 // Successor Parsing
1471 //===--------------------------------------------------------------------===//
1472
1473 /// Parse a single operation successor.
1474 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1475
1476 /// Parse an optional operation successor.
1477 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1478
1479 /// Parse a single operation successor and its operand list.
1480 virtual ParseResult
1481 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1482
1483 //===--------------------------------------------------------------------===//
1484 // Type Parsing
1485 //===--------------------------------------------------------------------===//
1486
1487 /// Parse a list of assignments of the form
1488 /// (%x1 = %y1, %x2 = %y2, ...)
1489 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1490 SmallVectorImpl<UnresolvedOperand> &rhs) {
1491 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1492 if (!result.has_value())
1493 return emitError(getCurrentLocation(), "expected '('");
1494 return result.value();
1495 }
1496
1497 virtual OptionalParseResult
1498 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1499 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1500};
1501
1502//===--------------------------------------------------------------------===//
1503// Dialect OpAsm interface.
1504//===--------------------------------------------------------------------===//
1505
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
/// Invoked with the value whose name is being set and the name to use.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;

/// A functor used to set the name of blocks in regions directly nested under
/// an operation. Invoked with the block being named and the name to use.
using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1513
1514class OpAsmDialectInterface
1515 : public DialectInterface::Base<OpAsmDialectInterface> {
1516public:
1517 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1518
1519 //===------------------------------------------------------------------===//
1520 // Aliases
1521 //===------------------------------------------------------------------===//
1522
1523 /// Holds the result of `getAlias` hook call.
1524 enum class AliasResult {
1525 /// The object (type or attribute) is not supported by the hook
1526 /// and an alias was not provided.
1527 NoAlias,
1528 /// An alias was provided, but it might be overriden by other hook.
1529 OverridableAlias,
1530 /// An alias was provided and it should be used
1531 /// (no other hooks will be checked).
1532 FinalAlias
1533 };
1534
1535 /// Hooks for getting an alias identifier alias for a given symbol, that is
1536 /// not necessarily a part of this dialect. The identifier is used in place of
1537 /// the symbol when printing textual IR. These aliases must not contain `.` or
1538 /// end with a numeric digit([0-9]+).
1539 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1540 return AliasResult::NoAlias;
1541 }
1542 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1543 return AliasResult::NoAlias;
1544 }
1545
1546 //===--------------------------------------------------------------------===//
1547 // Resources
1548 //===--------------------------------------------------------------------===//
1549
1550 /// Declare a resource with the given key, returning a handle to use for any
1551 /// references of this resource key within the IR during parsing. The result
1552 /// of `getResourceKey` on the returned handle is permitted to be different
1553 /// than `key`.
1554 virtual FailureOr<AsmDialectResourceHandle>
1555 declareResource(StringRef key) const {
1556 return failure();
1557 }
1558
1559 /// Return a key to use for the given resource. This key should uniquely
1560 /// identify this resource within the dialect.
1561 virtual std::string
1562 getResourceKey(const AsmDialectResourceHandle &handle) const {
1563 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1564)
1564 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1564)
;
1565 }
1566
1567 /// Hook for parsing resource entries. Returns failure if the entry was not
1568 /// valid, or could otherwise not be processed correctly. Any necessary errors
1569 /// can be emitted via the provided entry.
1570 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1571
1572 /// Hook for building resources to use during printing. The given `op` may be
1573 /// inspected to help determine what information to include.
1574 /// `referencedResources` contains all of the resources detected when printing
1575 /// 'op'.
1576 virtual void
1577 buildResources(Operation *op,
1578 const SetVector<AsmDialectResourceHandle> &referencedResources,
1579 AsmResourceBuilder &builder) const {}
1580};
1581} // namespace mlir
1582
1583//===--------------------------------------------------------------------===//
1584// Operation OpAsm interface.
1585//===--------------------------------------------------------------------===//
1586
1587/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1588#include "mlir/IR/OpAsmInterface.h.inc"
1589
1590namespace llvm {
1591template <>
1592struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1593 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1594 return {DenseMapInfo<void *>::getEmptyKey(),
1595 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1596 }
1597 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1598 return {DenseMapInfo<void *>::getTombstoneKey(),
1599 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1600 }
1601 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1602 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1603 }
1604 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1605 const mlir::AsmDialectResourceHandle &rhs) {
1606 return lhs.getResource() == rhs.getResource();
1607 }
1608};
1609} // namespace llvm
1610
1611#endif