Bug Summary

File:build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Warning: line 951, column 9
1st function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name LinalgTransformOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Linalg/TransformOps -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Linalg/TransformOps -I include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/llvm/include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem 
/usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-10-03-140002-15933-1 -x c++ /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10
11#include "mlir/AsmParser/AsmParser.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Dialect/PDL/IR/PDL.h"
18#include "mlir/Dialect/PDL/IR/PDLTypes.h"
19#include "mlir/Dialect/Transform/IR/TransformDialect.h"
20#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
21#include "mlir/Interfaces/TilingInterface.h"
22#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
23#include "llvm/ADT/StringSet.h"
24
25using namespace mlir;
26using namespace mlir::linalg;
27using namespace mlir::transform;
28
29/// Extracts a vector of unsigned from an array attribute. Asserts if the
30/// attribute contains values other than intergers. May truncate.
31static SmallVector<unsigned> extractUIntArray(ArrayAttr attr) {
32 SmallVector<unsigned> result;
33 result.reserve(attr.size());
34 for (APInt value : attr.getAsValueRange<IntegerAttr>())
35 result.push_back(value.getZExtValue());
36 return result;
37}
38
39namespace {
40/// A simple pattern rewriter that implements no special logic.
41class SimpleRewriter : public PatternRewriter {
42public:
43 SimpleRewriter(MLIRContext *context) : PatternRewriter(context) {}
44};
45} // namespace
46
47/// Attempts to apply the pattern specified as template argument to the given
48/// operation. The pattern is expected to have a `returningMatchAndRewrite`
49/// function that returns the "main" result or failure. Returns failure if the
50/// pattern failed to apply. Extra arguments are forwarded to the pattern
51/// constructor.
52template <typename PatternTy, typename... Args>
53static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
54 // Check if the given operation has the type expected by the pattern.
55 using OpTy = typename llvm::function_traits<
56 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
57 auto op = dyn_cast<OpTy>(operation);
58 if (!op)
59 return failure();
60
61 // Apply the pattern directly to the op.
62 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
63 SimpleRewriter rewriter(operation->getContext());
64 rewriter.setInsertionPoint(operation);
65 auto result = pattern.returningMatchAndRewrite(op, rewriter);
66 if (failed(result))
67 return failure();
68 return cast<LinalgOp>(result->getOperation());
69}
70
71//===----------------------------------------------------------------------===//
72// DecomposeOp
73//===----------------------------------------------------------------------===//
74
75DiagnosedSilenceableFailure
76transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
77 SmallVectorImpl<Operation *> &results,
78 transform::TransformState &state) {
79 FailureOr<LinalgOp> windowedNhwc =
80 tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
81 Conv1DNwcWcfOp>>(target);
82 if (succeeded(windowedNhwc)) {
83 results.push_back(*windowedNhwc);
84 return DiagnosedSilenceableFailure(success());
85 }
86 FailureOr<LinalgOp> windowedNchw =
87 tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
88 Conv1DNcwFcwOp>>(target);
89 if (succeeded(windowedNchw)) {
90 results.push_back(*windowedNchw);
91 return DiagnosedSilenceableFailure(success());
92 }
93 FailureOr<LinalgOp> depthwise =
94 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
95 if (succeeded(depthwise)) {
96 results.push_back(*depthwise);
97 return DiagnosedSilenceableFailure(success());
98 }
99 results.assign(1, nullptr);
100 return emitDefaultSilenceableFailure(target);
101}
102
103//===----------------------------------------------------------------------===//
104// FuseOp
105//===----------------------------------------------------------------------===//
106
107/// Apply a tiling transformation to all payload ops and store both the
108/// tiled operation as well as the created tile loops.
109static LogicalResult
110applyTilingToAll(Operation *transformOp, ArrayRef<Operation *> payloadOps,
111 unsigned numLoops,
112 transform::TransformResults &transformResults,
113 function_ref<FailureOr<TiledLinalgOp>(LinalgOp)> applyFn) {
114 SmallVector<Operation *> tiledLinalgOps;
115 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
116 for (unsigned int i = 0; i < numLoops; ++i)
117 loopOps[i].reserve(payloadOps.size());
118
119 for (Operation *target : payloadOps) {
120 auto linalgOp = dyn_cast<linalg::LinalgOp>(target);
121 if (!linalgOp)
122 return transformOp->emitError("only LinalgOps are supported");
123
124 FailureOr<TiledLinalgOp> tiled = applyFn(linalgOp);
125 if (failed(tiled))
126 return failure();
127
128 tiledLinalgOps.push_back(tiled->op);
129 if (tiled->loops.size() != numLoops)
130 // Not enough loops were generated. This usually means that the input size
131 // was smaller than the tiling size.
132 // TODO: LinalgTilingPattern should return failure().
133 return failure();
134 for (unsigned int i = 0; i < numLoops; ++i)
135 loopOps[i].push_back(tiled->loops[i]);
136 }
137
138 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
139 for (unsigned int i = 0; i < numLoops; ++i)
140 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
141 return success();
142}
143
144/// Parse a tiling-like operation that returns the tiled op as well as the
145/// created tile loops. The function counts the non-zero tile sizes to compute
146/// the number of results.
147static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
148 StringRef sizesAttrName) {
149 OpAsmParser::UnresolvedOperand targetOperand;
150 SMLoc opLoc = parser.getCurrentLocation();
151 if (parser.parseOperand(targetOperand) ||
152 parser.parseOptionalAttrDict(result.attributes))
153 return failure();
154 Attribute sizesAttr = result.attributes.get(sizesAttrName);
155 if (!sizesAttr)
156 return parser.emitError(opLoc)
157 << "expected '" << sizesAttrName << "' attribute";
158 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
159 if (!sizesArrayAttr)
160 return parser.emitError(opLoc)
161 << "'" << sizesAttrName << "' attribute must be an array";
162 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
163 size_t numExpectedLoops =
164 sizesArrayAttr.size() -
165 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
166 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
167 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
168 return failure();
169 return success();
170}
171
172DiagnosedSilenceableFailure
173transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
174 mlir::transform::TransformState &state) {
175 LinalgTilingAndFusionOptions fusionOptions;
176 fusionOptions.tileSizes = extractFromI64ArrayAttr(getTileSizes());
177 fusionOptions.tileInterchange = extractFromI64ArrayAttr(getTileInterchange());
178
179 LogicalResult result = applyTilingToAll(
180 getOperation(), state.getPayloadOps(getTarget()),
181 fusionOptions.tileSizes.size() - llvm::count(fusionOptions.tileSizes, 0),
182 transformResults, [&](LinalgOp linalgOp) -> FailureOr<TiledLinalgOp> {
183 LinalgTileAndFuseTensorOpsPattern pattern(getContext(), fusionOptions);
184 SimpleRewriter rewriter(getContext());
185 rewriter.setInsertionPoint(linalgOp);
186 FailureOr<TileLoopNest> tileLoopNest =
187 pattern.returningMatchAndRewrite(linalgOp, rewriter);
188 if (failed(tileLoopNest))
189 return failure();
190
191 TiledLinalgOp tiledLinalgOp;
192 tiledLinalgOp.op = tileLoopNest->getRootOp();
193 tiledLinalgOp.loops = {tileLoopNest->getLoopOps().begin(),
194 tileLoopNest->getLoopOps().end()};
195 return tiledLinalgOp;
196 });
197 return DiagnosedSilenceableFailure(result);
198}
199
200ParseResult transform::FuseOp::parse(OpAsmParser &parser,
201 OperationState &result) {
202 return parseTileLikeOp(
203 parser, result,
204 transform::FuseOp::getTileSizesAttrName(result.name).getValue());
205}
206
207void transform::FuseOp::print(OpAsmPrinter &p) {
208 p << ' ';
209 p << getTarget();
210 p.printOptionalAttrDict((*this)->getAttrs());
211}
212
213LogicalResult transform::FuseOp::verify() {
214 SmallVector<int64_t> permutation =
215 extractFromI64ArrayAttr(getTileInterchange());
216 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
217 if (!std::is_permutation(sequence.begin(), sequence.end(),
218 permutation.begin(), permutation.end())) {
219 return emitOpError() << "expects interchange to be a permutation, found "
220 << getTileInterchange();
221 }
222 return success();
223}
224
225//===----------------------------------------------------------------------===//
226// FuseIntoContainingOp
227//===----------------------------------------------------------------------===//
228
229/// Find the first "extract" user of `producerOp` and tile it right before its
230/// use. The tiled op is fused under the `containingOp`.
231/// Return this fused op on success or nullptr if anything fails.
232static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
233 Diagnostic &diag,
234 Operation *producerOp,
235 Operation *containingOp) {
236 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
237 if (!tileableProducer) {
238 diag.attachNote(producerOp->getLoc())
239 << "producer is not a TileableInterface: " << *producerOp;
240 return nullptr;
241 }
242
243 // Search the producer slices accessed within the containing operation.
244 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
245 // evolve into an interface.
246 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
247 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
248 return sliceOp && containingOp->isProperAncestor(sliceOp);
249 });
250
251 // Find a fusion opportunity.
252 if (it == tileableProducer->getUsers().end()) {
253 diag.attachNote(tileableProducer->getLoc())
254 << "could not find fusion opportunity for: " << *tileableProducer;
255 return nullptr;
256 }
257 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
258
259 // Try to fuse the producer in-place.
260 OpBuilder::InsertionGuard guard(rewriter);
261 rewriter.setInsertionPoint(sliceOpToTile);
262
263 // Tile the producer.
264 FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
265 rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
266 sliceOpToTile.getMixedSizes());
267 if (failed(tiledProducer)) {
268 diag.attachNote(tileableProducer->getLoc())
269 << "failed to tile producer op: " << *tileableProducer;
270 return nullptr;
271 }
272
273 // Replace the extract op.
274 Operation *fusedOp = tiledProducer->getDefiningOp();
275 rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
276 return fusedOp;
277}
278
279/// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure
280/// it is exactly the `containingOp`, otherwise bail.
281/// Then, find the first "extract" user of the tied block argument and tile it
282/// right before its "extract" use. The tiled op is fused under the
283/// `containingOp`.
284/// Return this fused op on success or nullptr if anything fails.
285static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
286 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
287 Operation *containingOp) {
288
289 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
290 if (!tileableProducer) {
291 diag.attachNote(producerOp->getLoc())
292 << "producer is not a TileableInterface: " << *producerOp;
293 return nullptr;
294 }
295
296 // Ensure `tileableProducer` has exactly one destination operand that we can
297 // replace the ForeachThreadOp bbArg with.
298 auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
299 if (destinationOperands.size() != 1) {
300 diag.attachNote(tileableProducer->getLoc())
301 << "tileableProducer must have exactly one destination operand: "
302 << *tileableProducer;
303 return nullptr;
304 }
305
306 // Search the first use by a "scf::ForeachThreadOp" user.
307 scf::ForeachThreadOp foreachThreadOp;
308 auto itProducerUses =
309 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
310 foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner());
311 return foreachThreadOp;
312 });
313 // If it's not from the containing op, return.
314 if (!foreachThreadOp || foreachThreadOp != containingOp) {
315 diag.attachNote(tileableProducer->getLoc())
316 << "could not find a use by the containing op: " << *tileableProducer;
317 return nullptr;
318 }
319
320 // Search the producer slices accessed within the containing
321 // operation.
322 // TODO: Generalize to more extract/insert/parallel_insert triples.
323 // Maybe evolve into an interface.
324 OpOperand *pUse = &(*itProducerUses);
325 BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse);
326
327 // Search the producer slices accessed within the containing operation.
328 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
329 // evolve into an interface.
330 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
331 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
332 return sliceOp && containingOp->isProperAncestor(sliceOp);
333 });
334
335 // Find a fusion opportunity.
336 if (itBBArgUsers == bbArg.getUsers().end()) {
337 diag.attachNote(containingOp->getLoc())
338 << "could not find fusion opportunity for bbArg: " << bbArg;
339 return nullptr;
340 }
341 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
342
343 // Try to fuse the producer in-place.
344 OpBuilder::InsertionGuard guard(rewriter);
345 rewriter.setInsertionPoint(sliceOpToTile);
346
347 // Replace the use in the tileableProducer before tiling: clone, replace and
348 // then tile.
349 BlockAndValueMapping bvm;
350 bvm.map(destinationOperands.front(), bbArg);
351 auto tileableProducerClone =
352 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
353 auto scopeGuard =
354 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
355
356 // Tile the producer.
357 FailureOr<Value> tiledProducer =
358 tileableProducerClone.generateResultTileValue(
359 rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
360 sliceOpToTile.getMixedSizes());
361 if (failed(tiledProducer)) {
362 diag.attachNote(tileableProducer->getLoc())
363 << "failed to tile producer op: " << *tileableProducer;
364 return nullptr;
365 }
366
367 // Replace the extract op.
368 Operation *fusedOp = tiledProducer->getDefiningOp();
369 rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
370
371 // Replace the use in containingOp.
372 rewriter.updateRootInPlace(containingOp, [&]() {
373 containingOp->setOperand(pUse->getOperandNumber(),
374 destinationOperands.front());
375 });
376
377 return fusedOp;
378}
379
380static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
381 Operation *producerOp,
382 Operation *containingOp) {
383 // Gather all uses inside the containing op.
384 SmallVector<OpOperand *> uses;
385 for (OpResult result : producerOp->getOpResults()) {
386 for (OpOperand &use : result.getUses()) {
387 if (containingOp->isProperAncestor(use.getOwner())) {
388 uses.push_back(&use);
389 continue;
390 }
391 // Cannot clone and fuse if the use is by the containing op itself: fail
392 // immediately.
393 if (containingOp == use.getOwner()) {
394 diag.attachNote(producerOp->getLoc())
395 << "producer op use by containing op cannot be fused by cloning";
396 return nullptr;
397 }
398 }
399 }
400
401 // Check for a non-empty list of fusion opportunities.
402 if (uses.empty()) {
403 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
404 return nullptr;
405 }
406
407 // Clone and fuse inside the containing op.
408 Operation *fusedOp = nullptr;
409 OpOperand *use = uses.front();
410 // Parallel insert slice is not a valid clone destination.
411 // TODO: Generalize to other type of ops.
412 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp
>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"
) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 413, __extension__ __PRETTY_FUNCTION__))
413 "Parallel insert slice is not a valid clone destination")(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp
>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"
) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 413, __extension__ __PRETTY_FUNCTION__))
;
414 unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
415 OpBuilder::InsertionGuard guard(rewriter);
416 rewriter.setInsertionPoint(use->getOwner());
417 fusedOp = rewriter.clone(*producerOp);
418 rewriter.updateRootInPlace(
419 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
420
421 return fusedOp;
422}
423
424DiagnosedSilenceableFailure
425transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
426 transform::TransformState &state) {
427 SmallVector<Operation *> fusedOps;
428 ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
429 // If nothing to fuse, propagate success.
430 if (producerOps.empty()) {
431 results.set(getResult().cast<OpResult>(), SmallVector<mlir::Operation *>{});
432 return DiagnosedSilenceableFailure::success();
433 }
434 for (Operation *producerOp : producerOps) {
435 if (producerOp->getNumResults() != 1) {
436 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
437 diag << "op with != 1 results not supported";
438 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
439 }
440 }
441 ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
442 if (containingOps.size() != 1)
443 return DiagnosedSilenceableFailure(
444 this->emitOpError("requires exactly one containing_op handle (got ")
445 << containingOps.size() << ")");
446 Operation *containingOp = containingOps.front();
447
448 // Helper function to find the next producer that should be fused. Take any
449 // producer that has a use inside the containing op.
450 SmallVector<Operation *> remainingProducers(producerOps.begin(),
451 producerOps.end());
452 auto getNextProducer = [&]() -> FailureOr<Operation *> {
453 for (const auto &it : enumerate(remainingProducers)) {
454 Operation *producerOp = it.value();
455 // The containing op may be a user of producerOp: use isAncestor.
456 int64_t numUsesInContainingOp =
457 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
458 return containingOp->isAncestor(op);
459 });
460 // TODO: When resolving the TODO below (no duplicate ops), take an op
461 // that has no use among the remaining producers. This is a topological
462 // sorting.
463 if (numUsesInContainingOp > 0) {
464 if (numUsesInContainingOp == 1)
465 remainingProducers.erase(remainingProducers.begin() + it.index());
466 return producerOp;
467 }
468 }
469 return failure();
470 };
471
472 IRRewriter rewriter(getContext());
473 while (!remainingProducers.empty()) {
474 auto nextProducer = getNextProducer();
475 if (failed(nextProducer)) {
476 Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
477 diag << "could not find next producer to fuse into container";
478 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
479 }
480
481 Operation *producerOp = *nextProducer;
482
483 // Detaul diagnostic, to be complemented with more failure information.
484 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
485 diag << "could not fuse " << *producerOp << " into " << *containingOp;
486
487 // TODO: If there are multiple uses of the producer in the containing op,
488 // we currently tile/clone the op multiple times (once per use). In some
489 // cases, we can tile/clone once and reuse the value for each use.
490 // Futhermore, producers should then be traversed according to a
491 // topological sorting.
492 Operation *tiled =
493 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
494 if (tiled) {
495 fusedOps.push_back(tiled);
496 continue;
497 }
498
499 Operation *tiledContainingOpOperand =
500 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
501 rewriter, diag, producerOp, containingOp);
502 if (tiledContainingOpOperand) {
503 fusedOps.push_back(tiledContainingOpOperand);
504 continue;
505 }
506
507 Operation *cloned =
508 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
509 if (cloned) {
510 fusedOps.push_back(cloned);
511 continue;
512 }
513
514 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
515 }
516
517 results.set(getFusedOp().cast<OpResult>(), fusedOps);
518 return DiagnosedSilenceableFailure::success();
519}
520
521//===----------------------------------------------------------------------===//
522// GeneralizeOp
523//===----------------------------------------------------------------------===//
524
525DiagnosedSilenceableFailure
526transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
527 SmallVectorImpl<Operation *> &results,
528 transform::TransformState &state) {
529 // Exit early if no transformation is needed.
530 if (isa<GenericOp>(target)) {
531 results.push_back(target);
532 return DiagnosedSilenceableFailure(success());
533 }
534 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
535 if (succeeded(generic)) {
536 results.push_back(generic->getOperation());
537 return DiagnosedSilenceableFailure(success());
538 }
539 results.assign(1, nullptr);
540 return emitDefaultSilenceableFailure(target);
541}
542
543//===----------------------------------------------------------------------===//
544// InterchangeOp
545//===----------------------------------------------------------------------===//
546
547DiagnosedSilenceableFailure
548transform::InterchangeOp::applyToOne(linalg::GenericOp target,
549 SmallVectorImpl<Operation *> &results,
550 transform::TransformState &state) {
551 SmallVector<unsigned> interchangeVector =
552 extractUIntArray(getIteratorInterchange());
553 // Exit early if no transformation is needed.
554 if (interchangeVector.empty()) {
555 results.push_back(target);
556 return DiagnosedSilenceableFailure(success());
557 }
558 SimpleRewriter rewriter(target->getContext());
559 FailureOr<GenericOp> res =
560 interchangeGenericOp(rewriter, target, interchangeVector);
561 if (failed(res))
562 return DiagnosedSilenceableFailure::definiteFailure();
563 results.push_back(res->getOperation());
564 return DiagnosedSilenceableFailure(success());
565}
566
567LogicalResult transform::InterchangeOp::verify() {
568 SmallVector<unsigned> permutation =
569 extractUIntArray(getIteratorInterchange());
570 auto sequence = llvm::to_vector(llvm::seq<unsigned>(0, permutation.size()));
571 if (!std::is_permutation(sequence.begin(), sequence.end(),
572 permutation.begin(), permutation.end())) {
573 return emitOpError()
574 << "expects iterator_interchange to be a permutation, found "
575 << getIteratorInterchange();
576 }
577 return success();
578}
579
580//===---------------------------------------------------------------------===//
581// MatchOp
582//===---------------------------------------------------------------------===//
583
584DiagnosedSilenceableFailure
585transform::MatchOp::apply(transform::TransformResults &results,
586 transform::TransformState &state) {
587 llvm::StringSet<> strs;
588 if (getOps().has_value())
589 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
590 getOps()->getAsValueRange<StringAttr>().end());
591
592 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
593 if (payloadOps.size() != 1)
594 return DiagnosedSilenceableFailure(
595 this->emitOpError("requires exactly one target handle"));
596
597 SmallVector<Operation *> res;
598 auto matchFun = [&](Operation *op) {
599 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
600 return;
601
602 // Interfaces cannot be matched by name, just by ID.
603 // So we specifically encode the interfaces we care about for this op.
604 if (getInterface().has_value()) {
605 auto iface = getInterface().value();
606 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
607 !isa<linalg::LinalgOp>(op))
608 return;
609 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
610 isa<TilingInterface>(op))
611 return;
612 }
613
614 // Check if all specified attributes match.
615 if (getOpAttrs().has_value()) {
616 DictionaryAttr opAttrs = getOpAttrs().value();
617 for (NamedAttribute attr : opAttrs) {
618 if (attr.getName() == getInterfaceAttrName() ||
619 attr.getName() == getOpsAttrName())
620 continue;
621 if (!op->hasAttr(attr.getName()))
622 return;
623 if (op->getAttr(attr.getName()) != attr.getValue())
624 return;
625 }
626 }
627
628 if (getFilterResultType().has_value()) {
629 Type t = getFilterResultType().value();
630 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
631 return;
632 }
633
634 // All constraints are satisfied.
635 res.push_back(op);
636 return;
637 };
638
639 payloadOps.front()->walk(matchFun);
640 results.set(getResult().cast<OpResult>(), res);
641 return DiagnosedSilenceableFailure(success());
642}
643
644//===---------------------------------------------------------------------===//
645// MultiTileSizesOp
646//===---------------------------------------------------------------------===//
647
648DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
649 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
650 OpBuilder builder(target.getContext());
651 builder.setInsertionPoint(target);
652 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
653 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
654 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
655 builder, target, getDimension(), targetSize, divisor);
656 if (failed(spec)) {
657 return emitSilenceableError() << "could not generate tile size computation";
658 }
659
660 AffineExpr s0 = builder.getAffineSymbolExpr(0);
661 AffineExpr s1 = builder.getAffineSymbolExpr(1);
662 Operation *splitPoint =
663 makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
664 {spec->lowTileSize, spec->lowTripCount});
665 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
666 Operation *highTileSize = spec->highTileSize.getDefiningOp();
667 assert(lowTileSize && highTileSize && splitPoint &&(static_cast <bool> (lowTileSize && highTileSize
&& splitPoint && "tile sizes are not produced by operations"
) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 668, __extension__ __PRETTY_FUNCTION__))
668 "tile sizes are not produced by operations")(static_cast <bool> (lowTileSize && highTileSize
&& splitPoint && "tile sizes are not produced by operations"
) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 668, __extension__ __PRETTY_FUNCTION__))
;
669 results.reserve(results.size() + 3);
670 results.push_back(lowTileSize);
671 results.push_back(highTileSize);
672 results.push_back(splitPoint);
673 return DiagnosedSilenceableFailure::success();
674}
675
676void transform::MultiTileSizesOp::getEffects(
677 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
678 onlyReadsHandle(getTarget(), effects);
679 producesHandle(getResults(), effects);
680 modifiesPayload(effects);
681}
682
683//===---------------------------------------------------------------------===//
684// PadOp
685//===---------------------------------------------------------------------===//
686
687DiagnosedSilenceableFailure
688transform::PadOp::applyToOne(linalg::LinalgOp target,
689 SmallVectorImpl<Operation *> &results,
690 transform::TransformState &state) {
691 // Convert the integer packing flags to booleans.
692 SmallVector<bool> packPaddings;
693 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
694 packPaddings.push_back(static_cast<bool>(packPadding));
695
696 // Convert the padding values to attributes.
697 SmallVector<Attribute> paddingValues;
698 for (auto const &it :
699 llvm::zip(getPaddingValues(), target->getOperandTypes())) {
700 auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
701 if (!attr) {
702 emitOpError("expects padding values to be typed attributes");
703 return DiagnosedSilenceableFailure::definiteFailure();
704 }
705 Type elementType = getElementTypeOrSelf(std::get<1>(it));
706 // Try to parse string attributes to obtain an attribute of element type.
707 if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
708 paddingValues.push_back(
709 parseAttribute(attr.cast<StringAttr>(), elementType));
710 if (!paddingValues.back()) {
711 auto diag = this->emitOpError("expects a padding that parses to ")
712 << elementType << ", got " << std::get<0>(it);
713 diag.attachNote(target.getLoc()) << "when applied to this op";
714 return DiagnosedSilenceableFailure::definiteFailure();
715 }
716 continue;
717 }
718 // Otherwise, add the attribute directly.
719 if (attr.getType() != elementType) {
720 auto diag = this->emitOpError("expects a padding value of type ")
721 << elementType << ", got " << attr;
722 diag.attachNote(target.getLoc()) << "when applied to this op";
723 return DiagnosedSilenceableFailure::definiteFailure();
724 }
725 paddingValues.push_back(attr);
726 }
727
728 // Extract the transpose vectors.
729 SmallVector<SmallVector<int64_t>> transposePaddings;
730 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
731 transposePaddings.push_back(
732 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
733
734 LinalgPaddingOptions paddingOptions;
735 paddingOptions.setPaddingValues(paddingValues);
736 paddingOptions.setPaddingDimensions(
737 extractFromI64ArrayAttr(getPaddingDimensions()));
738 paddingOptions.setPackPaddings(packPaddings);
739 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
740 paddingOptions.setTransposePaddings(transposePaddings);
741
742 FailureOr<LinalgOp> result =
743 tryApply<LinalgPaddingPattern>(target, paddingOptions);
744 if (succeeded(result)) {
745 results.push_back(result->getOperation());
746 return DiagnosedSilenceableFailure(success());
747 }
748
749 results.assign(1, nullptr);
750 return emitDefaultSilenceableFailure(target);
751}
752
753LogicalResult transform::PadOp::verify() {
754 SmallVector<int64_t> packPaddings =
755 extractFromI64ArrayAttr(getPackPaddings());
756 if (any_of(packPaddings, [](int64_t packPadding) {
757 return packPadding != 0 && packPadding != 1;
758 })) {
759 return emitOpError()
760 << "expects pack_paddings to contain booleans (0/1), found "
761 << getPackPaddings();
762 }
763
764 SmallVector<int64_t> paddingDimensions =
765 extractFromI64ArrayAttr(getPaddingDimensions());
766 if (any_of(paddingDimensions,
767 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
768 return emitOpError() << "expects padding_dimensions to contain positive "
769 "integers, found "
770 << getPaddingDimensions();
771 }
772
773 SmallVector<int64_t> hoistPaddings =
774 extractFromI64ArrayAttr(getHoistPaddings());
775 if (any_of(hoistPaddings,
776 [](int64_t hoistPadding) { return hoistPadding < 0; })) {
777 return emitOpError()
778 << "expects hoist_paddings to contain positive integers, found "
779 << getHoistPaddings();
780 }
781
782 ArrayAttr transposes = getTransposePaddings();
783 for (Attribute attr : transposes) {
784 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
785 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
786 if (!std::is_permutation(sequence.begin(), sequence.end(),
787 transpose.begin(), transpose.end())) {
788 return emitOpError()
789 << "expects transpose_paddings to be a permutation, found "
790 << attr;
791 }
792 }
793 return success();
794}
795
796//===----------------------------------------------------------------------===//
797// PromoteOp
798//===----------------------------------------------------------------------===//
799
800DiagnosedSilenceableFailure
801transform::PromoteOp::applyToOne(linalg::LinalgOp target,
802 SmallVectorImpl<Operation *> &results,
803 transform::TransformState &state) {
804 LinalgPromotionOptions promotionOptions;
805 if (!getOperandsToPromote().empty())
806 promotionOptions = promotionOptions.setOperandsToPromote(
807 extractFromI64ArrayAttr(getOperandsToPromote()));
808 if (getUseFullTilesByDefault())
809 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
810 getUseFullTilesByDefault());
811 if (getUseAlloca())
812 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
813 if (!getUseFullTileBuffers().empty())
814 promotionOptions = promotionOptions.setUseFullTileBuffers(
815 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
816 if (getAlignment().has_value())
817 promotionOptions = promotionOptions.setAlignment(*getAlignment());
818
819 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
820 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
821
822 SimpleRewriter rewriter(target->getContext());
823 rewriter.setInsertionPoint(target);
824 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
825 if (failed(res))
826 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
827 results.push_back(target);
828 return DiagnosedSilenceableFailure(success());
829}
830
831//===----------------------------------------------------------------------===//
832// ScalarizeOp
833//===----------------------------------------------------------------------===//
834
835DiagnosedSilenceableFailure
836transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
837 SmallVectorImpl<Operation *> &results,
838 transform::TransformState &state) {
839 LinalgTilingOptions tilingOptions;
840 tilingOptions.scalarizeDynamicDims();
841 // Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
842 // tile sizes and asserts that it is not already set.
843 SmallVector<int64_t> emptyTileSizes;
844 LinalgTilingPattern pattern(getContext(), tilingOptions);
845 SimpleRewriter rewriter(getContext());
846 rewriter.setInsertionPoint(target);
847 FailureOr<TiledLinalgOp> result =
848 pattern.returningMatchAndRewrite(target, rewriter);
849 if (failed(result))
850 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
851
852 results.push_back(result->op);
853 return DiagnosedSilenceableFailure(success());
854}
855
856//===----------------------------------------------------------------------===//
857// SplitOp
858//===----------------------------------------------------------------------===//
859
860DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
861 TransformState &state) {
862 // Collect the dynamic split points if provided.
863 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
864 SimpleRewriter rewriter(getContext());
865 SmallVector<OpFoldResult> splitPoints;
866 splitPoints.reserve(payload.size());
867 if (getDynamicSplitPoint()) {
868 auto diag = DiagnosedSilenceableFailure::success();
869 splitPoints = llvm::to_vector(llvm::map_range(
870 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
871 if (op->getNumResults() != 1 ||
872 !op->getResult(0).getType().isIndex()) {
873 diag = emitSilenceableError()
874 << "expected dynamic split point handle to point to a "
875 "single-result index-typed op";
876 diag.attachNote(op->getLoc()) << "dynamic split point";
877 }
878 return OpFoldResult(op->getResult(0));
879 }));
880 if (!diag.succeeded())
881 return diag;
882
883 if (splitPoints.size() != payload.size()) {
884 emitError() << "expected the dynamic split point handle to point to as "
885 "many operations ("
886 << splitPoints.size() << ") as the target handle ("
887 << payload.size() << ")";
888 return DiagnosedSilenceableFailure::definiteFailure();
889 }
890 } else {
891 splitPoints.resize(payload.size(),
892 rewriter.getIndexAttr(getStaticSplitPoint()));
893 }
894
895 // Split each target operation.
896 SmallVector<Operation *> first, second;
897 for (const auto &pair : llvm::zip(payload, splitPoints)) {
898 Operation *target = std::get<0>(pair);
899 auto linalgOp = dyn_cast<LinalgOp>(target);
900 if (!linalgOp) {
901 auto diag = emitSilenceableError() << "only applies to structured ops";
902 diag.attachNote(target->getLoc()) << "target op";
903 return diag;
904 }
905
906 if (getDimension() >= linalgOp.getNumLoops()) {
907 auto diag = emitSilenceableError() << "dimension " << getDimension()
908 << " does not exist in target op";
909 diag.attachNote(target->getLoc()) << "target op";
910 return diag;
911 }
912
913 rewriter.setInsertionPoint(linalgOp);
914 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
915 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
916 getDimension(), std::get<1>(pair));
917 }
918
919 results.set(getFirst().cast<OpResult>(), first);
920 results.set(getSecond().cast<OpResult>(), second);
921 return DiagnosedSilenceableFailure::success();
922}
923
924void SplitOp::getEffects(
925 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
926 consumesHandle(getTarget(), effects);
927 if (getDynamicSplitPoint())
928 onlyReadsHandle(getDynamicSplitPoint(), effects);
929 producesHandle(getResults(), effects);
930 modifiesPayload(effects);
931}
932
933ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
934 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
935 IntegerAttr staticSplitPoint;
936 auto pdlOperationType =
937 pdl::OperationType::get(parser.getBuilder().getContext());
938 if (parser.parseOperand(target) ||
1
Taking false branch
939 parser.resolveOperand(target, pdlOperationType, result.operands) ||
940 parser.parseKeyword("after"))
941 return failure();
942
943 OptionalParseResult dynamicPointParseResult =
944 parser.parseOptionalOperand(dynamicSplitPoint);
945 if (!dynamicPointParseResult.has_value()) {
2
Assuming the condition is true
3
Taking true branch
946 int64_t staticSplitPointValue;
4
'staticSplitPointValue' declared without an initial value
947 if (failed(parser.parseInteger(staticSplitPointValue)))
5
Calling 'AsmParser::parseInteger'
13
Returning from 'AsmParser::parseInteger'
14
Taking false branch
948 return failure();
949
950 staticSplitPoint =
951 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
15
1st function call argument is an uninitialized value
952 } else {
953 if (failed(*dynamicPointParseResult) ||
954 parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
955 result.operands)) {
956 return failure();
957 }
958
959 staticSplitPoint =
960 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamicSize);
961 }
962
963 result.addAttribute(
964 SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
965 staticSplitPoint);
966 if (failed(parser.parseOptionalAttrDict(result.attributes)))
967 return failure();
968
969 result.addTypes({pdlOperationType, pdlOperationType});
970 return success();
971}
972
973void SplitOp::print(OpAsmPrinter &printer) {
974 printer << " " << getTarget() << " after ";
975 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
976 if (staticSplitSize != ShapedType::kDynamicSize)
977 printer << staticSplitSize;
978 else
979 printer << getDynamicSplitPoint();
980 printer << " ";
981 printer.printOptionalAttrDict(getOperation()->getAttrs(),
982 {getStaticSplitPointAttrName()});
983}
984
985LogicalResult SplitOp::verify() {
986 if ((static_cast<int64_t>(getStaticSplitPoint()) !=
987 ShapedType::kDynamicSize) ^
988 (getDynamicSplitPoint() == nullptr)) {
989 return emitOpError() << "expects either a dynamic or a static split "
990 "point to be provided";
991 }
992 return success();
993}
994
995//===----------------------------------------------------------------------===//
996// SplitReductionOp
997//===----------------------------------------------------------------------===//
998
999DiagnosedSilenceableFailure
1000transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
1001 SmallVectorImpl<Operation *> &results,
1002 transform::TransformState &state) {
1003 ControlSplitReductionFn splitFn = [&](LinalgOp) {
1004 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
1005 unsigned(getInsertSplitDimension()),
1006 /*innerParallel=*/false};
1007 };
1008 SimpleRewriter rewriter(getContext());
1009 rewriter.setInsertionPoint(target);
1010 FailureOr<SplitReductionResult> splitResult =
1011 (getUseScalingAlgorithm())
1012 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
1013 : splitReduction(rewriter, target, splitFn, getUseAlloc());
1014 if (failed(splitResult))
1015 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1016
1017 results.push_back(splitResult->initOrAlloc);
1018 results.push_back(splitResult->fillOp);
1019 results.push_back(splitResult->splitLinalgOp);
1020 results.push_back(splitResult->resultCombiningLinalgOp);
1021 return DiagnosedSilenceableFailure(success());
1022}
1023
1024//===----------------------------------------------------------------------===//
1025// TileOp
1026//===----------------------------------------------------------------------===//
1027
1028DiagnosedSilenceableFailure
1029transform::TileOp::apply(TransformResults &transformResults,
1030 TransformState &state) {
1031 LinalgTilingOptions tilingOptions;
1032 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
1033
1034 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1035 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
1036 dynamicSizeProducers.reserve(getDynamicSizes().size());
1037 for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
1038 dynamicSizeProducers.push_back(
1039 state.getPayloadOps(dynamicSizeProducerHandle));
1040
1041 if (dynamicSizeProducers.back().size() != targets.size()) {
1042 DiagnosedSilenceableFailure diag =
1043 emitSilenceableError()
1044 << "expected as many dynamic size-producing operations ("
1045 << dynamicSizeProducers.back().size() << ") as target ops ("
1046 << targets.size() << ")";
1047 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1048 return diag;
1049 }
1050
1051 for (Operation *op : dynamicSizeProducers.back()) {
1052 if (op->getNumResults() == 1 &&
1053 op->getResult(0).getType().isa<IndexType>())
1054 continue;
1055 DiagnosedSilenceableFailure diag =
1056 emitSilenceableError() << "expected sizes to be produced by ops "
1057 "with a single index-type result";
1058 diag.attachNote(op->getLoc()) << "size producer op";
1059 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1060 return diag;
1061 }
1062 }
1063
1064 SmallVector<Operation *> tiled;
1065 SmallVector<SmallVector<Operation *, 4>, 4> loops;
1066 loops.resize(getLoops().size());
1067 for (auto &en : llvm::enumerate(targets)) {
1068 auto linalgOp = dyn_cast<LinalgOp>(en.value());
1069 if (!linalgOp) {
1070 DiagnosedSilenceableFailure diag = emitSilenceableError()
1071 << "only linalg ops are supported";
1072 diag.attachNote(en.value()->getLoc()) << "target op";
1073 return diag;
1074 }
1075
1076 unsigned index = en.index();
1077 if (!tileSizes.empty()) {
1078 tilingOptions.setTileSizeComputationFunction(
1079 [&, index](OpBuilder &b, Operation *) {
1080 SmallVector<Value, 4> sizes;
1081 sizes.reserve(tileSizes.size());
1082 unsigned dynamicIdx = 0;
1083 for (OpFoldResult ofr : getMixedSizes()) {
1084 if (auto attr = ofr.dyn_cast<Attribute>()) {
1085 sizes.push_back(b.create<arith::ConstantIndexOp>(
1086 getLoc(), attr.cast<IntegerAttr>().getInt()));
1087 } else {
1088 sizes.push_back(
1089 dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
1090 }
1091 }
1092 return sizes;
1093 });
1094 }
1095
1096 tilingOptions.setInterchange(extractUIntArray(getInterchange()));
1097 LinalgTilingPattern pattern(getContext(), tilingOptions);
1098 SimpleRewriter rewriter(linalgOp.getContext());
1099 FailureOr<TiledLinalgOp> tiledOp =
1100 pattern.returningMatchAndRewrite(linalgOp, rewriter);
1101 if (failed(tiledOp))
1102 return DiagnosedSilenceableFailure::definiteFailure();
1103
1104 tiled.push_back(tiledOp->op);
1105 for (const auto &en2 : llvm::enumerate(tiledOp->loops))
1106 loops[en2.index()].push_back(en2.value());
1107 }
1108
1109 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
1110 for (const auto &en : llvm::enumerate(loops))
1111 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
1112
1113 return DiagnosedSilenceableFailure::success();
1114}
1115
1116SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
1117 ValueRange dynamic = getDynamicSizes();
1118 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getStaticSizes());
1119 SmallVector<OpFoldResult> results;
1120 results.reserve(tileSizes.size());
1121 unsigned dynamicPos = 0;
1122 Builder builder(getContext());
1123 for (int64_t size : tileSizes) {
1124 if (size == ShapedType::kDynamicSize) {
1125 results.push_back(dynamic[dynamicPos++]);
1126 } else {
1127 results.push_back(builder.getIndexAttr(size));
1128 }
1129 }
1130 return results;
1131}
1132
1133ParseResult transform::TileOp::parse(OpAsmParser &parser,
1134 OperationState &result) {
1135 OpAsmParser::UnresolvedOperand target;
1136 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
1137 ArrayAttr staticSizes;
1138 auto pdlOperationType = pdl::OperationType::get(parser.getContext());
1139 if (parser.parseOperand(target) ||
1140 parser.resolveOperand(target, pdlOperationType, result.operands) ||
1141 parseDynamicIndexList(parser, dynamicSizes, staticSizes,
1142 ShapedType::kDynamicSize) ||
1143 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands) ||
1144 parser.parseOptionalAttrDict(result.attributes))
1145 return ParseResult::failure();
1146
1147 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
1148 size_t numExpectedLoops =
1149 staticSizes.size() - llvm::count(extractFromI64ArrayAttr(staticSizes), 0);
1150 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1151 return success();
1152}
1153
1154void TileOp::print(OpAsmPrinter &p) {
1155 p << ' ' << getTarget();
1156 printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes(),
1157 ShapedType::kDynamicSize);
1158 p.printOptionalAttrDict((*this)->getAttrs(), {getStaticSizesAttrName()});
1159}
1160
1161void transform::TileOp::getEffects(
1162 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1163 consumesHandle(getTarget(), effects);
1164 onlyReadsHandle(getDynamicSizes(), effects);
1165 producesHandle(getTiledLinalgOp(), effects);
1166 producesHandle(getLoops(), effects);
1167 modifiesPayload(effects);
1168}
1169
1170//===----------------------------------------------------------------------===//
1171// MapNestedForeachThreadToGpuThreads
1172//===----------------------------------------------------------------------===//
1173
1174/// Searches `scf.foreach_thread` ops nested under `target` and maps each such
1175/// op to GPU threads. Mapping is one-to-one and the induction variables of
1176/// `scf.foreach_thread` are rewritten to gpu.thread_id according to the
1177/// thread_dim_apping attribute. Sibling `scf.foreach_thread` are supported in
1178/// which case, the union of the number of threads is computed and may result in
1179/// predication. Dynamic, `scf.foreach_thread` trip counts are currently not
1180/// supported. Dynamic block dim sizes are currently not supported.
1181static FailureOr<SmallVector<OpFoldResult>> rewriteOneForeachThreadToGpuThreads(
1182 RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
1183 const SmallVector<int64_t> &globalBlockDims, bool syncAfterDistribute) {
1184 if (foreachThreadOp.getNumResults() > 0)
1185 return foreachThreadOp->emitError(
1186 "only bufferized scf.foreach_thread lowers to gpu.thread");
1187 if (foreachThreadOp.getNumThreads().size() > 3)
1188 return foreachThreadOp->emitError(
1189 "scf.foreach_thread with rank > 3 does not lower to gpu.thread");
1190
1191 auto potentialBlockDim = foreachThreadOp.getPermutedNumThreads(rewriter);
1192 if (failed(potentialBlockDim) ||
1193 llvm::any_of(*potentialBlockDim, [](OpFoldResult ofr) {
1194 return !getConstantIntValue(ofr).has_value();
1195 }))
1196 return foreachThreadOp->emitError("unsupported dynamic blockdim size");
1197
1198 SmallVector<int64_t> blockDim =
1199 llvm::to_vector(llvm::map_range(*potentialBlockDim, [](OpFoldResult ofr) {
1200 return getConstantIntValue(ofr).value();
1201 }));
1202
1203 // Step 1. Create the gpu.thread ops
1204 Location loc = foreachThreadOp.getLoc();
1205 IndexType indexType = rewriter.getIndexType();
1206
1207 SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
1208 gpu::Dimension::z};
1209 SmallVector<Value> threadOps;
1210 for (int64_t idx : llvm::seq<int64_t>(0, blockDim.size())) {
1211 threadOps.push_back(
1212 rewriter.create<gpu::ThreadIdOp>(loc, indexType, gpuDims[idx]));
1213 }
1214 // Step 2. Maybe create conditionals to predicate the region.
1215 Value predicate;
1216 for (auto [threadId, blockDim, globalBlockDim] :
1217 llvm::zip(threadOps, blockDim, globalBlockDims)) {
1218 if (blockDim > globalBlockDim) {
1219 return foreachThreadOp.emitOpError("blockDim size overflow: ")
1220 << blockDim << " > " << globalBlockDim;
1221 }
1222 if (blockDim == globalBlockDim)
1223 continue;
1224 Value tmpPredicate = rewriter.create<arith::CmpIOp>(
1225 loc, arith::CmpIPredicate::ult, threadId,
1226 rewriter.create<arith::ConstantIndexOp>(loc, blockDim));
1227 predicate =
1228 predicate ? rewriter.create<arith::AndIOp>(loc, predicate, tmpPredicate)
1229 : tmpPredicate;
1230 }
1231
1232 // Step 3. Move the body of foreachThreadOp.
1233 // Erase the terminator first, it will not be used.
1234 rewriter.eraseOp(foreachThreadOp.getTerminator());
1235 Block *targetBlock;
1236 Block::iterator insertionPoint;
1237 if (predicate) {
1238 // Step 3.a. If predicated, move at the beginning.
1239 auto ifOp =
1240 rewriter.create<scf::IfOp>(loc, predicate, /*withElseRegion=*/false);
1241 targetBlock = ifOp.thenBlock();
1242 insertionPoint = ifOp.thenBlock()->begin();
1243 } else {
1244 // Step 3.a. Otherwise, move inline just before foreachThreadOp.
1245 targetBlock = foreachThreadOp->getBlock();
1246 insertionPoint = Block::iterator(foreachThreadOp);
1247 }
1248 Block &sourceBlock = foreachThreadOp.getRegion().front();
1249 targetBlock->getOperations().splice(insertionPoint,
1250 sourceBlock.getOperations());
1251
1252 // Step 4. RAUW thread indices to thread ops.
1253 SmallVector<Value> threadIndices =
1254 *foreachThreadOp.getPermutedThreadIndices();
1255 for (auto it : llvm::zip(threadIndices, threadOps)) {
1256 Value val = std::get<0>(it);
1257 if (!val)
1258 continue;
1259 for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
1260 rewriter.updateRootInPlace(
1261 user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
1262 }
1263 }
1264
1265 // Step 5. syncthreads.
1266 // TODO: Need warpsync
1267 if (syncAfterDistribute)
1268 rewriter.create<gpu::BarrierOp>(loc);
1269
1270 // Step 6. Erase old op.
1271 rewriter.eraseOp(foreachThreadOp);
1272
1273 return *potentialBlockDim;
1274}
1275
1276mlir::WalkResult mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
1277 RewriterBase &rewriter, Operation *target,
1278 const SmallVector<int64_t> &blockDim, bool syncAfterDistribute) {
1279 auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
1280 rewriter.setInsertionPoint(foreachThreadOp);
1281 if (failed(rewriteOneForeachThreadToGpuThreads(
1282 rewriter, foreachThreadOp, blockDim, syncAfterDistribute)))
1283 return WalkResult::interrupt();
1284 return WalkResult::advance();
1285 });
1286 return walkResult;
1287}
1288
1289static LogicalResult
1290checkGpuLimits(Optional<int64_t> gridDimX, Optional<int64_t> gridDimY,
1291 Optional<int64_t> gridDimZ, Optional<int64_t> blockDimX,
1292 Optional<int64_t> blockDimY, Optional<int64_t> blockDimZ) {
1293 // TODO The limits should live in the gpu dialect, but it's not like that
1294 // right now. Read them in the common gpu dialect
1295 if ((blockDimX.value_or(1) * blockDimY.value_or(1) * blockDimZ.value_or(1)) >
1296 1024 ||
1297 gridDimY.value_or(1) > 65535 || gridDimZ.value_or(1) > 65535 ||
1298 gridDimX.value_or(1) > 2147483647)
1299 return failure();
1300 return success();
1301}
1302
1303/// Alter grid or block dimensions of the given kernel
1304static LogicalResult alterGpuLaunch(SimpleRewriter &rewriter,
1305 gpu::LaunchOp gpuLaunch,
1306 Optional<int64_t> gridDimX = llvm::None,
1307 Optional<int64_t> gridDimY = llvm::None,
1308 Optional<int64_t> gridDimZ = llvm::None,
1309 Optional<int64_t> blockDimX = llvm::None,
1310 Optional<int64_t> blockDimY = llvm::None,
1311 Optional<int64_t> blockDimZ = llvm::None) {
1312 if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
1313 blockDimZ))) {
1314 gpuLaunch->emitError(
1315 "Requested kernel thread configuration is larger than the limits");
1316 return failure();
1317 }
1318
1319 gpu::KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
1320 OpBuilder::InsertionGuard guard(rewriter);
1321 rewriter.setInsertionPointAfterValue(currentBlockdim.x);
1322 auto createConstValue = [&](int dim) {
1323 return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
1324 dim);
1325 };
1326
1327 if (gridDimX.has_value())
1328 gpuLaunch.getGridSizeXMutable().assign(createConstValue(gridDimX.value()));
1329 if (gridDimY.has_value())
1330 gpuLaunch.getGridSizeYMutable().assign(createConstValue(gridDimY.value()));
1331 if (gridDimZ.has_value())
1332 gpuLaunch.getGridSizeZMutable().assign(createConstValue(gridDimZ.value()));
1333 if (blockDimX.has_value())
1334 gpuLaunch.getBlockSizeXMutable().assign(
1335 createConstValue(blockDimX.value()));
1336 if (blockDimY.has_value())
1337 gpuLaunch.getBlockSizeYMutable().assign(
1338 createConstValue(blockDimY.value()));
1339 if (blockDimZ.has_value())
1340 gpuLaunch.getBlockSizeZMutable().assign(
1341 createConstValue(blockDimZ.value()));
1342 return success();
1343}
1344
1345DiagnosedSilenceableFailure
1346transform::MapNestedForeachThreadToGpuThreads::applyToOne(
1347 Operation *target, SmallVectorImpl<Operation *> &results,
1348 transform::TransformState &state) {
1349
1350 gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
1351 if (!gpuLaunch) {
1352 target->emitError("Given target is not gpu.launch");
1353 return DiagnosedSilenceableFailure::definiteFailure();
1354 }
1355
1356 SmallVector<int64_t> blockDim = extractFromI64ArrayAttr(getBlockDim());
1357 blockDim.resize(/*size=*/3, /*value=*/1);
1358 SimpleRewriter rewriter(getContext());
1359 rewriter.setInsertionPoint(target);
1360 auto walkResult = mlir::linalg::rewriteMapNestedForeachThreadToGpuThreads(
1361 rewriter, target, blockDim, getSyncAfterDistribute());
1362 if (walkResult.wasInterrupted())
1363 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1364
1365 LogicalResult result =
1366 alterGpuLaunch(rewriter, gpuLaunch, llvm::None, llvm::None, llvm::None,
1367 blockDim[0], blockDim[1], blockDim[2]);
1368 if (failed(result))
1369 return DiagnosedSilenceableFailure::definiteFailure();
1370
1371 results.assign({target});
1372 return DiagnosedSilenceableFailure(success());
1373}
1374
1375//===----------------------------------------------------------------------===//
1376// MapNestedForeachThreadToGpuBlocks
1377//===----------------------------------------------------------------------===//
1378
1379LogicalResult mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
1380 RewriterBase &rewriter, scf::ForeachThreadOp foreachThreadOp,
1381 function_ref<void(RewriterBase &, scf::ForeachThreadOp,
1382 SmallVector<Value> &)>
1383 blockIdGenerator,
1384 SmallVector<int64_t> &gridDims) {
1385 if (foreachThreadOp.getNumResults() > 0)
1386 return foreachThreadOp->emitError(
1387 "only bufferized scf.foreach_thread lowers to gpu.block_id");
1388 if (foreachThreadOp.getNumThreads().size() > 3)
1389 return foreachThreadOp->emitError(
1390 "scf.foreach_thread with rank > 3 does not lower to gpu.block_id");
1391
1392 // Step 0. Outline the compute workload region and set up the workload
1393 // operands.
1394 auto potentialGridDim = foreachThreadOp.getPermutedNumThreads(rewriter);
1395 if (failed(potentialGridDim) ||
1396 llvm::any_of(*potentialGridDim, [](OpFoldResult ofr) {
1397 return !getConstantIntValue(ofr).has_value();
1398 }))
1399 return foreachThreadOp->emitError("unsupported dynamic gridDim");
1400
1401 for (OpFoldResult ofr : *potentialGridDim)
1402 gridDims.push_back(getConstantIntValue(ofr).value());
1403
1404 SmallVector<Value> blockOps;
1405 blockIdGenerator(rewriter, foreachThreadOp, blockOps);
1406
1407 // Step 1. Move the body of foreachThreadOp.
1408 // Erase the terminator first, it will not be used since we are on buffers.
1409 rewriter.eraseOp(foreachThreadOp.getTerminator());
1410 Block *targetBlock = foreachThreadOp->getBlock();
1411 Block::iterator insertionPoint = Block::iterator(foreachThreadOp);
1412 Block &sourceBlock = foreachThreadOp.getRegion().front();
1413 targetBlock->getOperations().splice(insertionPoint,
1414 sourceBlock.getOperations());
1415
1416 // Step 2. RAUW thread indices to thread ops.
1417 SmallVector<Value> threadIndices =
1418 *foreachThreadOp.getPermutedThreadIndices();
1419 assert(blockOps.size() == 3 && "3 block id ops are required")(static_cast <bool> (blockOps.size() == 3 && "3 block id ops are required"
) ? void (0) : __assert_fail ("blockOps.size() == 3 && \"3 block id ops are required\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 1419, __extension__ __PRETTY_FUNCTION__))
;
1420 for (auto it : llvm::zip(threadIndices, blockOps)) {
1421 Value val = std::get<0>(it);
1422 if (!val)
1423 continue;
1424 for (Operation *user : llvm::make_early_inc_range(val.getUsers())) {
1425 rewriter.updateRootInPlace(
1426 user, [&]() { user->replaceUsesOfWith(val, std::get<1>(it)); });
1427 }
1428 }
1429
1430 // Step 3. Erase old op.
1431 rewriter.eraseOp(foreachThreadOp);
1432
1433 return success();
1434}
1435
1436FailureOr<scf::ForeachThreadOp>
1437mlir::linalg::findTopLevelForeachThreadOp(Operation *target) {
1438 scf::ForeachThreadOp topLevelForeachThreadOp;
1439 auto walkResult = target->walk([&](scf::ForeachThreadOp foreachThreadOp) {
1440 if (foreachThreadOp->getParentOfType<scf::ForeachThreadOp>())
1441 return WalkResult::advance();
1442 if (topLevelForeachThreadOp)
1443 // TODO Handle multiple foreach if there is no dependences between them
1444 return WalkResult::interrupt();
1445 topLevelForeachThreadOp = foreachThreadOp;
1446 return WalkResult::advance();
1447 });
1448
1449 if (walkResult.wasInterrupted())
1450 return target->emitError(
1451 "could not find a unique topLevel scf.foreach_thread");
1452
1453 return topLevelForeachThreadOp;
1454}
1455
1456/// Create gpuLauncOp with given kernel configurations
1457static FailureOr<gpu::LaunchOp>
1458createGpuLaunch(RewriterBase &rewriter, Location loc,
1459 Optional<int64_t> gridDimX = llvm::None,
1460 Optional<int64_t> gridDimY = llvm::None,
1461 Optional<int64_t> gridDimZ = llvm::None,
1462 Optional<int64_t> blockDimX = llvm::None,
1463 Optional<int64_t> blockDimY = llvm::None,
1464 Optional<int64_t> blockDimZ = llvm::None) {
1465 if (failed(checkGpuLimits(gridDimX, gridDimY, gridDimZ, blockDimX, blockDimY,
1466 blockDimZ)))
1467 return failure();
1468 auto createConstant = [&](int dim) {
1469 return rewriter.create<arith::ConstantIndexOp>(loc, dim);
1470 };
1471 Value one = createConstant(1);
1472 Value gridSizeX =
1473 gridDimX.has_value() ? createConstant(gridDimX.value()) : one;
1474 Value gridSizeY =
1475 gridDimY.has_value() ? createConstant(gridDimY.value()) : one;
1476 Value gridSizeZ =
1477 gridDimZ.has_value() ? createConstant(gridDimZ.value()) : one;
1478 Value blockSizeX =
1479 blockDimX.has_value() ? createConstant(blockDimX.value()) : one;
1480 Value blockSizeY =
1481 blockDimY.has_value() ? createConstant(blockDimY.value()) : one;
1482 Value blockSizeZ =
1483 blockDimZ.has_value() ? createConstant(blockDimZ.value()) : one;
1484 auto launchOp = rewriter.create<gpu::LaunchOp>(
1485 loc, gridSizeX, gridSizeY, gridSizeZ, blockSizeX, blockSizeY, blockSizeZ);
1486 rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
1487 rewriter.create<gpu::TerminatorOp>(loc);
1488 return launchOp;
1489}
1490
1491/// This is an helper that is only used in
1492/// rewriteTopLevelForeachThreadToGpuBlocks. It generates GPU dialects block_id
1493static void generateGpuBlockIds(RewriterBase &rewriter,
1494 scf::ForeachThreadOp foreachOp,
1495 SmallVector<Value> &blockOps) {
1496 Location loc = foreachOp->getLoc();
1497 OpBuilder::InsertionGuard guard(rewriter);
1498 rewriter.setInsertionPoint(foreachOp);
1499 IndexType indexType = rewriter.getIndexType();
1500 SmallVector<gpu::Dimension> gpuDims{gpu::Dimension::x, gpu::Dimension::y,
1501 gpu::Dimension::z};
1502 for (int64_t idx : llvm::seq<int64_t>(0, gpuDims.size())) {
1503 blockOps.push_back(
1504 rewriter.create<gpu::BlockIdOp>(loc, indexType, gpuDims[idx]));
1505 }
1506}
1507
1508DiagnosedSilenceableFailure
1509transform::MapNestedForeachThreadToGpuBlocks::applyToOne(
1510 Operation *target, SmallVectorImpl<Operation *> &results,
1511 transform::TransformState &state) {
1512 gpu::LaunchOp gpuLaunch = dyn_cast<gpu::LaunchOp>(target);
1513 SimpleRewriter rewriter(getContext());
1514
1515 if (!getGenerateGpuLaunch() && !gpuLaunch) {
1516 target->emitError("Given target is not gpu.launch, set "
1517 "`generate_gpu_launch` attribute");
1518 return DiagnosedSilenceableFailure::definiteFailure();
1519 }
1520
1521 auto res = mlir::linalg::findTopLevelForeachThreadOp(target);
1522 if (failed(res))
1523 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1524
1525 scf::ForeachThreadOp topLevelForeachThreadOp = *res;
1526 OpBuilder::InsertionGuard guard(rewriter);
1527 rewriter.setInsertionPoint(topLevelForeachThreadOp);
1528
1529 // Generate gpu launch here and move the foreach_thread inside
1530 if (getGenerateGpuLaunch()) {
1531 FailureOr<gpu::LaunchOp> maybeGpuLaunch =
1532 createGpuLaunch(rewriter, target->getLoc());
1533 if (failed(maybeGpuLaunch))
1534 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1535 gpuLaunch = *maybeGpuLaunch;
1536 rewriter.setInsertionPointToStart(&gpuLaunch.getBody().front());
1537 Operation *newForeachThreadOp = rewriter.clone(*topLevelForeachThreadOp);
1538 rewriter.eraseOp(topLevelForeachThreadOp);
1539 topLevelForeachThreadOp =
1540 dyn_cast<scf::ForeachThreadOp>(newForeachThreadOp);
1541 }
1542
1543 SmallVector<int64_t> gridDim = extractFromI64ArrayAttr(getGridDim());
1544 if (failed(mlir::linalg::rewriteTopLevelForeachThreadToGpuBlocks(
1545 rewriter, topLevelForeachThreadOp, generateGpuBlockIds, gridDim)))
1546 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1547
1548 if (failed(alterGpuLaunch(rewriter, gpuLaunch, gridDim[0], gridDim[1],
1549 gridDim[2])))
1550 return DiagnosedSilenceableFailure::definiteFailure();
1551
1552 results.assign({gpuLaunch});
1553 return DiagnosedSilenceableFailure(success());
1554}
1555
1556//===----------------------------------------------------------------------===//
1557// TileToForeachThreadOp
1558//===----------------------------------------------------------------------===//
1559
1560DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
1561 RewriterBase &rewriter, transform::TransformState &state,
1562 TransformOpInterface transformOp, ArrayRef<Operation *> targets,
1563 ArrayRef<OpFoldResult> mixedNumThreads,
1564 ArrayRef<OpFoldResult> mixedTileSizes, Optional<ArrayAttr> threadDimMapping,
1565 SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
1566
1567 if (targets.empty())
1568 return DiagnosedSilenceableFailure(success());
1569
1570 // Given a list of OpFoldResults that are either index attrs or op handles,
1571 // return a list of OpFoldResults where all op handles are replaced with the
1572 // first (and only) OpResult of that payload op. (There must be exactly one
1573 // mapped payload op and it must have exactly one index result.)
1574 auto getOpResultsOrIndexAttrs =
1575 [&](SmallVector<OpFoldResult> &result,
1576 ArrayRef<OpFoldResult> opHandlesOrIndexAttrs) {
1577 for (OpFoldResult ofr : opHandlesOrIndexAttrs) {
1578 if (ofr.is<Attribute>()) {
1579 result.push_back(ofr);
1580 continue;
1581 }
1582 ArrayRef<Operation *> dynamicNumThreads =
1583 state.getPayloadOps(ofr.get<Value>());
1584 if (dynamicNumThreads.size() != 1) {
1585 DiagnosedSilenceableFailure diag =
1586 transformOp.emitSilenceableError()
1587 << "handle must be mapped to exactly 1 payload op";
1588 diag.attachNote(ofr.get<Value>().getLoc())
1589 << "mapped to " << dynamicNumThreads.size() << " ops";
1590 return diag;
1591 }
1592 Operation *op = dynamicNumThreads[0];
1593 if (op->getNumResults() != 1 ||
1594 !op->getResult(0).getType().isIndex()) {
1595 DiagnosedSilenceableFailure diag =
1596 transformOp.emitSilenceableError()
1597 << "payload op must have exactly 1 index result";
1598 diag.attachNote(op->getLoc())
1599 << "has " << op->getNumResults() << " results";
1600 return diag;
1601 }
1602 result.push_back(op->getResult(0));
1603 }
1604
1605 return DiagnosedSilenceableFailure(success());
1606 };
1607
1608 // getMixedNumThreads are OpFoldResults[index attributes or PDL operation].
1609 // Convert to OpFoldResults[index attributes or payload op].
1610 SmallVector<OpFoldResult> numThreads;
1611 DiagnosedSilenceableFailure status =
1612 getOpResultsOrIndexAttrs(numThreads, mixedNumThreads);
1613 if (!status.succeeded())
1614 return status;
1615
1616 // getMixedTileSizes are OpFoldResults[index attributes or PDL operation].
1617 // Convert to OpFoldResults[index attributes or payload op].
1618 SmallVector<OpFoldResult> tileSizes;
1619 status = getOpResultsOrIndexAttrs(tileSizes, mixedTileSizes);
1620 if (!status.succeeded())
1621 return status;
1622
1623 // Transform all targets one by one.
1624 for (Operation *target : targets) {
1625 auto tilableOp = dyn_cast<TilingInterface>(target);
1626 if (!tilableOp) {
1627 DiagnosedSilenceableFailure diag =
1628 transformOp.emitSilenceableError()
1629 << "only TilingInterface ops are supported";
1630 diag.attachNote(target->getLoc()) << "target op";
1631 return diag;
1632 }
1633 rewriter.setInsertionPoint(tilableOp);
1634 auto maybeThreadDimMappingAttr = threadDimMapping;
1635 auto dimMapping = llvm::to_vector(
1636 maybeThreadDimMappingAttr
1637 ? extractFromI64ArrayAttr(*maybeThreadDimMappingAttr)
1638 : ArrayRef<int64_t>{});
1639
1640 FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
1641 if (!mixedNumThreads.empty()) {
1642 tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
1643 numThreads, dimMapping);
1644 } else {
1645 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
1646 rewriter, tilableOp, tileSizes, dimMapping);
1647 }
1648
1649 if (failed(tilingResult))
1650 return transformOp.emitDefaultSilenceableFailure(tilableOp);
1651 rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
1652
1653 tileOps.push_back(tilingResult->tileOp);
1654 tiledOps.push_back(tilingResult->tiledOp);
1655 }
1656 return DiagnosedSilenceableFailure(success());
1657}
1658
1659DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
1660 transform::TransformResults &transformResults,
1661 transform::TransformState &state) {
1662 IRRewriter rewriter(getContext());
1663 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1664
1665 // Result payload ops.
1666 SmallVector<Operation *> tileOps;
1667 SmallVector<Operation *> tiledOps;
1668
1669 DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
1670 rewriter, state, cast<TransformOpInterface>(getOperation()), targets,
1671 getMixedNumThreads(), getMixedTileSizes(), getThreadDimMapping(), tileOps,
1672 tiledOps);
1673
1674 if (!diag.succeeded())
1675 return diag;
1676
1677 transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
1678 transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
1679
1680 return DiagnosedSilenceableFailure(success());
1681}
1682
1683void transform::TileToForeachThreadOp::getEffects(
1684 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1685 consumesHandle(getTarget(), effects);
1686 onlyReadsHandle(getTileSizes(), effects);
1687 onlyReadsHandle(getNumThreads(), effects);
1688 producesHandle(getResults(), effects);
1689}
1690
1691SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
1692 return getMixedSizes(getStaticNumThreads(), getNumThreads());
1693}
1694
1695SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
1696 return getMixedSizes(getStaticTileSizes(), getTileSizes());
1697}
1698
1699LogicalResult TileToForeachThreadOp::verify() {
1700 if (getMixedNumThreads().empty() == getMixedTileSizes().empty())
1701 return emitOpError("either num_threads or tile_sizes must be specified");
1702 return success();
1703}
1704
1705//===----------------------------------------------------------------------===//
1706// VectorizeOp
1707//===----------------------------------------------------------------------===//
1708
1709namespace {
1710/// This is an helper only to call vectorize via a pattern inside of
1711/// VectorizeOp::applyToOne.
1712struct VectorizationPattern : public RewritePattern {
1713 explicit VectorizationPattern(MLIRContext *context)
1714 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
1715 LogicalResult matchAndRewrite(Operation *op,
1716 PatternRewriter &rewriter) const override {
1717 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
1718 if (!linalgOp)
1719 return failure();
1720 return vectorize(rewriter, linalgOp);
1721 }
1722};
1723} // namespace
1724
1725DiagnosedSilenceableFailure
1726transform::VectorizeOp::applyToOne(Operation *target,
1727 SmallVectorImpl<Operation *> &results,
1728 transform::TransformState &state) {
1729 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1730 auto diag = this->emitOpError("requires isolated-from-above targets");
1731 diag.attachNote(target->getLoc()) << "non-isolated target";
1732 return DiagnosedSilenceableFailure::definiteFailure();
1733 }
1734
1735 MLIRContext *ctx = getContext();
1736 RewritePatternSet patterns(ctx);
1737 patterns.add<VectorizationPattern>(ctx);
1738
1739 if (!getDisableTransferPermutationMapLoweringPatterns())
1740 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1741
1742 if (!getDisableMultiReductionToContractPatterns())
1743 vector::populateVectorReductionToContractPatterns(patterns);
1744
1745 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
1746 linalg::LinalgCopyVTWForwardingPattern>(ctx,
1747 /*benefit=*/2);
1748 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1749 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1750
1751 patterns.add<CopyVectorizationPattern>(ctx);
1752
1753 if (getVectorizePadding())
1754 linalg::populatePadOpVectorizationPatterns(patterns);
1755
1756 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1757 return DiagnosedSilenceableFailure(reportUnknownTransformError(target));
1758
1759 results.push_back(target);
1760 return DiagnosedSilenceableFailure(success());
1761}
1762
1763//===----------------------------------------------------------------------===//
1764// Transform op registration
1765//===----------------------------------------------------------------------===//
1766
1767namespace {
1768/// Registers new ops and declares PDL as dependent dialect since the
1769/// additional ops are using PDL types for operands and results.
1770class LinalgTransformDialectExtension
1771 : public transform::TransformDialectExtension<
1772 LinalgTransformDialectExtension> {
1773public:
1774 using Base::Base;
1775
1776 void init() {
1777 declareDependentDialect<pdl::PDLDialect>();
1778 declareDependentDialect<LinalgDialect>();
1779 declareGeneratedDialect<AffineDialect>();
1780 declareGeneratedDialect<arith::ArithDialect>();
1781 declareGeneratedDialect<scf::SCFDialect>();
1782 declareGeneratedDialect<vector::VectorDialect>();
1783 declareGeneratedDialect<gpu::GPUDialect>();
1784
1785 registerTransformOps<
1786#define GET_OP_LIST
1787#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1788 >();
1789 }
1790};
1791} // namespace
1792
1793#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
1794
1795#define GET_OP_CLASSES
1796#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
1797
1798void mlir::linalg::registerTransformDialectExtension(
1799 DialectRegistry &registry) {
1800 registry.addExtensions<LinalgTransformDialectExtension>();
1801}

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
218 AsmPrinter() {}
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a loc(...) specifier if printing debug info is enabled.
331 virtual void printOptionalLocationSpecifier(Location loc) = 0;
332
333 /// Print a newline and indent the printer to the start of the current
334 /// operation.
335 virtual void printNewline() = 0;
336
337 /// Print a block argument in the usual format of:
338 /// %ssaName : type {attr1=42} loc("here")
339 /// where location printing is controlled by the standard internal option.
340 /// You may pass omitType=true to not print a type, and pass an empty
341 /// attribute list if you don't care for attributes.
342 virtual void printRegionArgument(BlockArgument arg,
343 ArrayRef<NamedAttribute> argAttrs = {},
344 bool omitType = false) = 0;
345
346 /// Print implementations for various things an operation contains.
347 virtual void printOperand(Value value) = 0;
348 virtual void printOperand(Value value, raw_ostream &os) = 0;
349
350 /// Print a comma separated list of operands.
351 template <typename ContainerType>
352 void printOperands(const ContainerType &container) {
353 printOperands(container.begin(), container.end());
354 }
355
356 /// Print a comma separated list of operands.
357 template <typename IteratorType>
358 void printOperands(IteratorType it, IteratorType end) {
359 llvm::interleaveComma(llvm::make_range(it, end), getStream(),
360 [this](Value value) { printOperand(value); });
361 }
362
363 /// Print the given successor.
364 virtual void printSuccessor(Block *successor) = 0;
365
366 /// Print the successor and its operands.
367 virtual void printSuccessorAndUseList(Block *successor,
368 ValueRange succOperands) = 0;
369
370 /// If the specified operation has attributes, print out an attribute
371 /// dictionary with their values. elidedAttrs allows the client to ignore
372 /// specific well known attributes, commonly used if the attribute value is
373 /// printed some other way (like as a fixed operand).
374 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
375 ArrayRef<StringRef> elidedAttrs = {}) = 0;
376
377 /// If the specified operation has attributes, print out an attribute
378 /// dictionary prefixed with 'attributes'.
379 virtual void
380 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
381 ArrayRef<StringRef> elidedAttrs = {}) = 0;
382
383 /// Print the entire operation with the default generic assembly form.
384 /// If `printOpName` is true, then the operation name is printed (the default)
385 /// otherwise it is omitted and the print will start with the operand list.
386 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
387
388 /// Prints a region.
389 /// If 'printEntryBlockArgs' is false, the arguments of the
390 /// block are not printed. If 'printBlockTerminator' is false, the terminator
391 /// operation of the block is not printed. If printEmptyBlock is true, then
392 /// the block header is printed even if the block is empty.
393 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
394 bool printBlockTerminators = true,
395 bool printEmptyBlock = false) = 0;
396
397 /// Renumber the arguments for the specified region to the same names as the
398 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
399 /// operations. If any entry in namesToUse is null, the corresponding
400 /// argument name is left alone.
401 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
402
403 /// Prints an affine map of SSA ids, where SSA id names are used in place
404 /// of dims/symbols.
405 /// Operand values must come from single-result sources, and be valid
406 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
407 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
408 ValueRange operands) = 0;
409
410 /// Prints an affine expression of SSA ids with SSA id names used instead of
411 /// dims and symbols.
412 /// Operand values must come from single-result sources, and be valid
413 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
414 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
415 ValueRange symOperands) = 0;
416
417 /// Print the complete type of an operation in functional form.
418 void printFunctionalType(Operation *op);
419 using AsmPrinter::printFunctionalType;
420};
421
422// Make the implementations convenient to use.
423inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
424 p.printOperand(value);
425 return p;
426}
427
428template <typename T,
429 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
430 !std::is_convertible<T &, Value &>::value,
431 T> * = nullptr>
432inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
433 p.printOperands(values);
434 return p;
435}
436
437inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
438 p.printSuccessor(value);
439 return p;
440}
441
442//===----------------------------------------------------------------------===//
443// AsmParser
444//===----------------------------------------------------------------------===//
445
446/// This base class exposes generic asm parser hooks, usable across the various
447/// derived parsers.
448class AsmParser {
449public:
450 AsmParser() = default;
451 virtual ~AsmParser();
452
453 MLIRContext *getContext() const;
454
455 /// Return the location of the original name token.
456 virtual SMLoc getNameLoc() const = 0;
457
458 //===--------------------------------------------------------------------===//
459 // Utilities
460 //===--------------------------------------------------------------------===//
461
462 /// Emit a diagnostic at the specified location and return failure.
463 virtual InFlightDiagnostic emitError(SMLoc loc,
464 const Twine &message = {}) = 0;
465
466 /// Return a builder which provides useful access to MLIRContext, global
467 /// objects like types and attributes.
468 virtual Builder &getBuilder() const = 0;
469
470 /// Get the location of the next token and store it into the argument. This
471 /// always succeeds.
472 virtual SMLoc getCurrentLocation() = 0;
473 ParseResult getCurrentLocation(SMLoc *loc) {
474 *loc = getCurrentLocation();
475 return success();
476 }
477
478 /// Re-encode the given source location as an MLIR location and return it.
479 /// Note: This method should only be used when a `Location` is necessary, as
480 /// the encoding process is not efficient.
481 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
482
483 //===--------------------------------------------------------------------===//
484 // Token Parsing
485 //===--------------------------------------------------------------------===//
486
487 /// Parse a '->' token.
488 virtual ParseResult parseArrow() = 0;
489
490 /// Parse a '->' token if present
491 virtual ParseResult parseOptionalArrow() = 0;
492
493 /// Parse a `{` token.
494 virtual ParseResult parseLBrace() = 0;
495
496 /// Parse a `{` token if present.
497 virtual ParseResult parseOptionalLBrace() = 0;
498
499 /// Parse a `}` token.
500 virtual ParseResult parseRBrace() = 0;
501
502 /// Parse a `}` token if present.
503 virtual ParseResult parseOptionalRBrace() = 0;
504
505 /// Parse a `:` token.
506 virtual ParseResult parseColon() = 0;
507
508 /// Parse a `:` token if present.
509 virtual ParseResult parseOptionalColon() = 0;
510
511 /// Parse a `,` token.
512 virtual ParseResult parseComma() = 0;
513
514 /// Parse a `,` token if present.
515 virtual ParseResult parseOptionalComma() = 0;
516
517 /// Parse a `=` token.
518 virtual ParseResult parseEqual() = 0;
519
520 /// Parse a `=` token if present.
521 virtual ParseResult parseOptionalEqual() = 0;
522
523 /// Parse a '<' token.
524 virtual ParseResult parseLess() = 0;
525
526 /// Parse a '<' token if present.
527 virtual ParseResult parseOptionalLess() = 0;
528
529 /// Parse a '>' token.
530 virtual ParseResult parseGreater() = 0;
531
532 /// Parse a '>' token if present.
533 virtual ParseResult parseOptionalGreater() = 0;
534
535 /// Parse a '?' token.
536 virtual ParseResult parseQuestion() = 0;
537
538 /// Parse a '?' token if present.
539 virtual ParseResult parseOptionalQuestion() = 0;
540
541 /// Parse a '+' token.
542 virtual ParseResult parsePlus() = 0;
543
544 /// Parse a '+' token if present.
545 virtual ParseResult parseOptionalPlus() = 0;
546
547 /// Parse a '*' token.
548 virtual ParseResult parseStar() = 0;
549
550 /// Parse a '*' token if present.
551 virtual ParseResult parseOptionalStar() = 0;
552
553 /// Parse a '|' token.
554 virtual ParseResult parseVerticalBar() = 0;
555
556 /// Parse a '|' token if present.
557 virtual ParseResult parseOptionalVerticalBar() = 0;
558
559 /// Parse a quoted string token.
560 ParseResult parseString(std::string *string) {
561 auto loc = getCurrentLocation();
562 if (parseOptionalString(string))
563 return emitError(loc, "expected string");
564 return success();
565 }
566
567 /// Parse a quoted string token if present.
568 virtual ParseResult parseOptionalString(std::string *string) = 0;
569
570 /// Parse a `(` token.
571 virtual ParseResult parseLParen() = 0;
572
573 /// Parse a `(` token if present.
574 virtual ParseResult parseOptionalLParen() = 0;
575
576 /// Parse a `)` token.
577 virtual ParseResult parseRParen() = 0;
578
579 /// Parse a `)` token if present.
580 virtual ParseResult parseOptionalRParen() = 0;
581
582 /// Parse a `[` token.
583 virtual ParseResult parseLSquare() = 0;
584
585 /// Parse a `[` token if present.
586 virtual ParseResult parseOptionalLSquare() = 0;
587
588 /// Parse a `]` token.
589 virtual ParseResult parseRSquare() = 0;
590
591 /// Parse a `]` token if present.
592 virtual ParseResult parseOptionalRSquare() = 0;
593
594 /// Parse a `...` token.
595 virtual ParseResult parseEllipsis() = 0;
596
597 /// Parse a `...` token if present;
598 virtual ParseResult parseOptionalEllipsis() = 0;
599
600 /// Parse a floating point value from the stream.
601 virtual ParseResult parseFloat(double &result) = 0;
602
603 /// Parse an integer value from the stream.
604 template <typename IntT>
605 ParseResult parseInteger(IntT &result) {
606 auto loc = getCurrentLocation();
607 OptionalParseResult parseResult = parseOptionalInteger(result);
6
Calling 'AsmParser::parseOptionalInteger'
10
Returning from 'AsmParser::parseOptionalInteger'
608 if (!parseResult.has_value())
11
Taking true branch
609 return emitError(loc, "expected integer value");
12
Returning without writing to 'result'
610 return *parseResult;
611 }
612
613 /// Parse an optional integer value from the stream.
614 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
615
616 template <typename IntT>
617 OptionalParseResult parseOptionalInteger(IntT &result) {
618 auto loc = getCurrentLocation();
619
620 // Parse the unsigned variant.
621 APInt uintResult;
622 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
623 if (!parseResult.has_value() || failed(*parseResult))
7
Assuming the condition is true
8
Taking true branch
624 return parseResult;
9
Returning without writing to 'result'
625
626 // Try to convert to the provided integer type. sextOrTrunc is correct even
627 // for unsigned types because parseOptionalInteger ensures the sign bit is
628 // zero for non-negated integers.
629 result =
630 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
631 if (APInt(uintResult.getBitWidth(), result) != uintResult)
632 return emitError(loc, "integer value too large");
633 return success();
634 }
635
636 /// These are the supported delimiters around operand lists and region
637 /// argument lists, used by parseOperandList.
638 enum class Delimiter {
639 /// Zero or more operands with no delimiters.
640 None,
641 /// Parens surrounding zero or more operands.
642 Paren,
643 /// Square brackets surrounding zero or more operands.
644 Square,
645 /// <> brackets surrounding zero or more operands.
646 LessGreater,
647 /// {} brackets surrounding zero or more operands.
648 Braces,
649 /// Parens supporting zero or more operands, or nothing.
650 OptionalParen,
651 /// Square brackets supporting zero or more ops, or nothing.
652 OptionalSquare,
653 /// <> brackets supporting zero or more ops, or nothing.
654 OptionalLessGreater,
655 /// {} brackets surrounding zero or more operands, or nothing.
656 OptionalBraces,
657 };
658
659 /// Parse a list of comma-separated items with an optional delimiter. If a
660 /// delimiter is provided, then an empty list is allowed. If not, then at
661 /// least one element will be parsed.
662 ///
663 /// contextMessage is an optional message appended to "expected '('" sorts of
664 /// diagnostics when parsing the delimeters.
665 virtual ParseResult
666 parseCommaSeparatedList(Delimiter delimiter,
667 function_ref<ParseResult()> parseElementFn,
668 StringRef contextMessage = StringRef()) = 0;
669
670 /// Parse a comma separated list of elements that must have at least one entry
671 /// in it.
672 ParseResult
673 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
674 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
675 }
676
677 //===--------------------------------------------------------------------===//
678 // Keyword Parsing
679 //===--------------------------------------------------------------------===//
680
681 /// This class represents a StringSwitch like class that is useful for parsing
682 /// expected keywords. On construction, it invokes `parseKeyword` and
683 /// processes each of the provided cases statements until a match is hit. The
684 /// provided `ResultT` must be assignable from `failure()`.
685 template <typename ResultT = ParseResult>
686 class KeywordSwitch {
687 public:
688 KeywordSwitch(AsmParser &parser)
689 : parser(parser), loc(parser.getCurrentLocation()) {
690 if (failed(parser.parseKeywordOrCompletion(&keyword)))
691 result = failure();
692 }
693
694 /// Case that uses the provided value when true.
695 KeywordSwitch &Case(StringLiteral str, ResultT value) {
696 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
697 }
698 KeywordSwitch &Default(ResultT value) {
699 return Default([&](StringRef, SMLoc) { return std::move(value); });
700 }
701 /// Case that invokes the provided functor when true. The parameters passed
702 /// to the functor are the keyword, and the location of the keyword (in case
703 /// any errors need to be emitted).
704 template <typename FnT>
705 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
706 Case(StringLiteral str, FnT &&fn) {
707 if (result)
708 return *this;
709
710 // If the word was empty, record this as a completion.
711 if (keyword.empty())
712 parser.codeCompleteExpectedTokens(str);
713 else if (keyword == str)
714 result.emplace(std::move(fn(keyword, loc)));
715 return *this;
716 }
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Default(FnT &&fn) {
720 if (!result)
721 result.emplace(fn(keyword, loc));
722 return *this;
723 }
724
725 /// Returns true if this switch has a value yet.
726 bool hasValue() const { return result.has_value(); }
727
728 /// Return the result of the switch.
729 [[nodiscard]] operator ResultT() {
730 if (!result)
731 return parser.emitError(loc, "unexpected keyword: ") << keyword;
732 return std::move(*result);
733 }
734
735 private:
736 /// The parser used to construct this switch.
737 AsmParser &parser;
738
739 /// The location of the keyword, used to emit errors as necessary.
740 SMLoc loc;
741
742 /// The parsed keyword itself.
743 StringRef keyword;
744
745 /// The result of the switch statement or none if currently unknown.
746 Optional<ResultT> result;
747 };
748
749 /// Parse a given keyword.
750 ParseResult parseKeyword(StringRef keyword) {
751 return parseKeyword(keyword, "");
752 }
753 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
754
755 /// Parse a keyword into 'keyword'.
756 ParseResult parseKeyword(StringRef *keyword) {
757 auto loc = getCurrentLocation();
758 if (parseOptionalKeyword(keyword))
759 return emitError(loc, "expected valid keyword");
760 return success();
761 }
762
763 /// Parse the given keyword if present.
764 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
765
766 /// Parse a keyword, if present, into 'keyword'.
767 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
768
769 /// Parse a keyword, if present, and if one of the 'allowedValues',
770 /// into 'keyword'
771 virtual ParseResult
772 parseOptionalKeyword(StringRef *keyword,
773 ArrayRef<StringRef> allowedValues) = 0;
774
775 /// Parse a keyword or a quoted string.
776 ParseResult parseKeywordOrString(std::string *result) {
777 if (failed(parseOptionalKeywordOrString(result)))
778 return emitError(getCurrentLocation())
779 << "expected valid keyword or string";
780 return success();
781 }
782
783 /// Parse an optional keyword or string.
784 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
785
786 //===--------------------------------------------------------------------===//
787 // Attribute/Type Parsing
788 //===--------------------------------------------------------------------===//
789
790 /// Invoke the `getChecked` method of the given Attribute or Type class, using
791 /// the provided location to emit errors in the case of failure. Note that
792 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
793 /// context parameter.
794 template <typename T, typename... ParamsT>
795 auto getChecked(SMLoc loc, ParamsT &&...params) {
796 return T::getChecked([&] { return emitError(loc); },
797 std::forward<ParamsT>(params)...);
798 }
799 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800 /// errors.
801 template <typename T, typename... ParamsT>
802 auto getChecked(ParamsT &&...params) {
803 return T::getChecked([&] { return emitError(getNameLoc()); },
804 std::forward<ParamsT>(params)...);
805 }
806
807 //===--------------------------------------------------------------------===//
808 // Attribute Parsing
809 //===--------------------------------------------------------------------===//
810
811 /// Parse an arbitrary attribute of a given type and return it in result.
812 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
813
814 /// Parse a custom attribute with the provided callback, unless the next
815 /// token is `#`, in which case the generic parser is invoked.
816 virtual ParseResult parseCustomAttributeWithFallback(
817 Attribute &result, Type type,
818 function_ref<ParseResult(Attribute &result, Type type)>
819 parseAttribute) = 0;
820
821 /// Parse an attribute of a specific kind and type.
822 template <typename AttrType>
823 ParseResult parseAttribute(AttrType &result, Type type = {}) {
824 SMLoc loc = getCurrentLocation();
825
826 // Parse any kind of attribute.
827 Attribute attr;
828 if (parseAttribute(attr, type))
829 return failure();
830
831 // Check for the right kind of attribute.
832 if (!(result = attr.dyn_cast<AttrType>()))
833 return emitError(loc, "invalid kind of attribute specified");
834
835 return success();
836 }
837
838 /// Parse an arbitrary attribute and return it in result. This also adds the
839 /// attribute to the specified attribute list with the specified name.
840 ParseResult parseAttribute(Attribute &result, StringRef attrName,
841 NamedAttrList &attrs) {
842 return parseAttribute(result, Type(), attrName, attrs);
843 }
844
845 /// Parse an attribute of a specific kind and type.
846 template <typename AttrType>
847 ParseResult parseAttribute(AttrType &result, StringRef attrName,
848 NamedAttrList &attrs) {
849 return parseAttribute(result, Type(), attrName, attrs);
850 }
851
852 /// Parse an arbitrary attribute of a given type and populate it in `result`.
853 /// This also adds the attribute to the specified attribute list with the
854 /// specified name.
855 template <typename AttrType>
856 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
857 NamedAttrList &attrs) {
858 SMLoc loc = getCurrentLocation();
859
860 // Parse any kind of attribute.
861 Attribute attr;
862 if (parseAttribute(attr, type))
863 return failure();
864
865 // Check for the right kind of attribute.
866 result = attr.dyn_cast<AttrType>();
867 if (!result)
868 return emitError(loc, "invalid kind of attribute specified");
869
870 attrs.append(attrName, result);
871 return success();
872 }
873
874 /// Trait to check if `AttrType` provides a `parse` method.
875 template <typename AttrType>
876 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
877 std::declval<Type>()));
878 template <typename AttrType>
879 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
880
881 /// Parse a custom attribute of a given type unless the next token is `#`, in
882 /// which case the generic parser is invoked. The parsed attribute is
883 /// populated in `result` and also added to the specified attribute list with
884 /// the specified name.
885 template <typename AttrType>
886 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
887 parseCustomAttributeWithFallback(AttrType &result, Type type,
888 StringRef attrName, NamedAttrList &attrs) {
889 SMLoc loc = getCurrentLocation();
890
891 // Parse any kind of attribute.
892 Attribute attr;
893 if (parseCustomAttributeWithFallback(
894 attr, type, [&](Attribute &result, Type type) -> ParseResult {
895 result = AttrType::parse(*this, type);
896 if (!result)
897 return failure();
898 return success();
899 }))
900 return failure();
901
902 // Check for the right kind of attribute.
903 result = attr.dyn_cast<AttrType>();
904 if (!result)
905 return emitError(loc, "invalid kind of attribute specified");
906
907 attrs.append(attrName, result);
908 return success();
909 }
910
911 /// SFINAE parsing method for Attribute that don't implement a parse method.
912 template <typename AttrType>
913 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
914 parseCustomAttributeWithFallback(AttrType &result, Type type,
915 StringRef attrName, NamedAttrList &attrs) {
916 return parseAttribute(result, type, attrName, attrs);
917 }
918
919 /// Parse a custom attribute of a given type unless the next token is `#`, in
920 /// which case the generic parser is invoked. The parsed attribute is
921 /// populated in `result`.
922 template <typename AttrType>
923 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
924 parseCustomAttributeWithFallback(AttrType &result) {
925 SMLoc loc = getCurrentLocation();
926
927 // Parse any kind of attribute.
928 Attribute attr;
929 if (parseCustomAttributeWithFallback(
930 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
931 result = AttrType::parse(*this, type);
932 return success(!!result);
933 }))
934 return failure();
935
936 // Check for the right kind of attribute.
937 result = attr.dyn_cast<AttrType>();
938 if (!result)
939 return emitError(loc, "invalid kind of attribute specified");
940 return success();
941 }
942
943 /// SFINAE parsing method for Attribute that don't implement a parse method.
944 template <typename AttrType>
945 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
946 parseCustomAttributeWithFallback(AttrType &result) {
947 return parseAttribute(result);
948 }
949
950 /// Parse an arbitrary optional attribute of a given type and return it in
951 /// result.
952 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
953 Type type = {}) = 0;
954
955 /// Parse an optional array attribute and return it in result.
956 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
957 Type type = {}) = 0;
958
959 /// Parse an optional string attribute and return it in result.
960 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
961 Type type = {}) = 0;
962
963 /// Parse an optional attribute of a specific type and add it to the list with
964 /// the specified name.
965 template <typename AttrType>
966 OptionalParseResult parseOptionalAttribute(AttrType &result,
967 StringRef attrName,
968 NamedAttrList &attrs) {
969 return parseOptionalAttribute(result, Type(), attrName, attrs);
970 }
971
972 /// Parse an optional attribute of a specific type and add it to the list with
973 /// the specified name.
974 template <typename AttrType>
975 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
976 StringRef attrName,
977 NamedAttrList &attrs) {
978 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
979 if (parseResult.has_value() && succeeded(*parseResult))
980 attrs.append(attrName, result);
981 return parseResult;
982 }
983
984 /// Parse a named dictionary into 'result' if it is present.
985 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
986
987 /// Parse a named dictionary into 'result' if the `attributes` keyword is
988 /// present.
989 virtual ParseResult
990 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
991
992 /// Parse an affine map instance into 'map'.
993 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
994
995 /// Parse an integer set instance into 'set'.
/// NOTE(review): despite its `print` prefix, this is a parse hook. It takes
/// the IntegerSet by non-const reference and returns a ParseResult, like
/// `parseAffineMap`. The name looks like a typo for `parseIntegerSet`, but
/// renaming a pure virtual breaks every override, so confirm before changing.
996 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
997
998 //===--------------------------------------------------------------------===//
999 // Identifier Parsing
1000 //===--------------------------------------------------------------------===//
1001
1002 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1003 /// attribute named 'attrName'.
1004 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1005 NamedAttrList &attrs) {
1006 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1007 return emitError(getCurrentLocation())
1008 << "expected valid '@'-identifier for symbol name";
1009 return success();
1010 }
1011
1012 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1013 /// string attribute named 'attrName'.
1014 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1015 StringRef attrName,
1016 NamedAttrList &attrs) = 0;
1017
1018 //===--------------------------------------------------------------------===//
1019 // Resource Parsing
1020 //===--------------------------------------------------------------------===//
1021
1022 /// Parse a handle to a resource within the assembly format.
1023 template <typename ResourceT>
1024 FailureOr<ResourceT> parseResourceHandle() {
1025 SMLoc handleLoc = getCurrentLocation();
1026
1027 // Try to load the dialect that owns the handle.
1028 auto *dialect =
1029 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1030 if (!dialect) {
1031 return emitError(handleLoc)
1032 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1033 << "' is unknown";
1034 }
1035
1036 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1037 if (failed(handle))
1038 return failure();
1039 if (auto *result = dyn_cast<ResourceT>(&*handle))
1040 return std::move(*result);
1041 return emitError(handleLoc) << "provided resource handle differs from the "
1042 "expected resource type";
1043 }
1044
1045 //===--------------------------------------------------------------------===//
1046 // Type Parsing
1047 //===--------------------------------------------------------------------===//
1048
1049 /// Parse a type.
1050 virtual ParseResult parseType(Type &result) = 0;
1051
1052 /// Parse a custom type with the provided callback, unless the next
1053 /// token is `#`, in which case the generic parser is invoked.
1054 virtual ParseResult parseCustomTypeWithFallback(
1055 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1056
1057 /// Parse an optional type.
1058 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1059
1060 /// Parse a type of a specific type.
1061 template <typename TypeT>
1062 ParseResult parseType(TypeT &result) {
1063 SMLoc loc = getCurrentLocation();
1064
1065 // Parse any kind of type.
1066 Type type;
1067 if (parseType(type))
1068 return failure();
1069
1070 // Check for the right kind of type.
1071 result = type.dyn_cast<TypeT>();
1072 if (!result)
1073 return emitError(loc, "invalid kind of type specified");
1074
1075 return success();
1076 }
1077
1078 /// Trait to check if `TypeT` provides a `parse` method.
1079 template <typename TypeT>
1080 using type_has_parse_method =
1081 decltype(TypeT::parse(std::declval<AsmParser &>()));
1082 template <typename TypeT>
1083 using detect_type_has_parse_method =
1084 llvm::is_detected<type_has_parse_method, TypeT>;
1085
1086 /// Parse a custom Type of a given type unless the next token is `#`, in
1087 /// which case the generic parser is invoked. The parsed Type is
1088 /// populated in `result`.
1089 template <typename TypeT>
1090 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1091 parseCustomTypeWithFallback(TypeT &result) {
1092 SMLoc loc = getCurrentLocation();
1093
1094 // Parse any kind of Type.
1095 Type type;
1096 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1097 result = TypeT::parse(*this);
1098 return success(!!result);
1099 }))
1100 return failure();
1101
1102 // Check for the right kind of Type.
1103 result = type.dyn_cast<TypeT>();
1104 if (!result)
1105 return emitError(loc, "invalid kind of Type specified");
1106 return success();
1107 }
1108
1109 /// SFINAE parsing method for Type that don't implement a parse method.
1110 template <typename TypeT>
1111 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1112 parseCustomTypeWithFallback(TypeT &result) {
1113 return parseType(result);
1114 }
1115
1116 /// Parse a type list.
1117 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1118 return parseCommaSeparatedList(
1119 [&]() { return parseType(result.emplace_back()); });
1120 }
1121
1122 /// Parse an arrow followed by a type list.
1123 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1124
1125 /// Parse an optional arrow followed by a type list.
1126 virtual ParseResult
1127 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1128
1129 /// Parse a colon followed by a type.
1130 virtual ParseResult parseColonType(Type &result) = 0;
1131
1132 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1133 template <typename TypeType>
1134 ParseResult parseColonType(TypeType &result) {
1135 SMLoc loc = getCurrentLocation();
1136
1137 // Parse any kind of type.
1138 Type type;
1139 if (parseColonType(type))
1140 return failure();
1141
1142 // Check for the right kind of type.
1143 result = type.dyn_cast<TypeType>();
1144 if (!result)
1145 return emitError(loc, "invalid kind of type specified");
1146
1147 return success();
1148 }
1149
1150 /// Parse a colon followed by a type list, which must have at least one type.
1151 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1152
1153 /// Parse an optional colon followed by a type list, which if present must
1154 /// have at least one type.
1155 virtual ParseResult
1156 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1157
1158 /// Parse a keyword followed by a type.
1159 ParseResult parseKeywordType(const char *keyword, Type &result) {
1160 return failure(parseKeyword(keyword) || parseType(result));
1161 }
1162
1163 /// Add the specified type to the end of the specified type list and return
1164 /// success. This is a helper designed to allow parse methods to be simple
1165 /// and chain through || operators.
1166 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1167 result.push_back(type);
1168 return success();
1169 }
1170
1171 /// Add the specified types to the end of the specified type list and return
1172 /// success. This is a helper designed to allow parse methods to be simple
1173 /// and chain through || operators.
1174 ParseResult addTypesToList(ArrayRef<Type> types,
1175 SmallVectorImpl<Type> &result) {
1176 result.append(types.begin(), types.end());
1177 return success();
1178 }
1179
1180 /// Parse a dimension list of a tensor or memref type. This populates the
1181 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1182 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1183 ///
1184 /// dimension-list ::= eps | dimension (`x` dimension)*
1185 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1186 /// dimension ::= `?` | decimal-literal
1187 ///
1188 /// When `allowDynamic` is not set, this is used to parse:
1189 ///
1190 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1191 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
1192 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1193 bool allowDynamic = true,
1194 bool withTrailingX = true) = 0;
1195
1196 /// Parse an 'x' token in a dimension list, handling the case where the x is
1197 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1198 /// next token.
1199 virtual ParseResult parseXInDimensionList() = 0;
1200
1201protected:
1202 /// Parse a handle to a resource within the assembly format for the given
1203 /// dialect.
1204 virtual FailureOr<AsmDialectResourceHandle>
1205 parseResourceHandle(Dialect *dialect) = 0;
1206
1207 //===--------------------------------------------------------------------===//
1208 // Code Completion
1209 //===--------------------------------------------------------------------===//
1210
1211 /// Parse a keyword, or an empty string if the current location signals a code
1212 /// completion.
1213 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1214
1215 /// Signal the code completion of a set of expected tokens.
1216 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1217
1218private:
1219 AsmParser(const AsmParser &) = delete;
1220 void operator=(const AsmParser &) = delete;
1221};
1222
1223//===----------------------------------------------------------------------===//
1224// OpAsmParser
1225//===----------------------------------------------------------------------===//
1226
1227/// The OpAsmParser has methods for interacting with the asm parser: parsing
1228/// things from it, emitting errors etc. It has an intentionally high-level API
1229/// that is designed to reduce/constrain syntax innovation in individual
1230/// operations.
1231///
1232/// For example, consider an op like this:
1233///
1234/// %x = load %p[%1, %2] : memref<...>
1235///
1236/// The "%x = load" tokens are already parsed and therefore invisible to the
1237/// custom op parser. This can be supported by calling `parseOperandList` to
1238/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1239/// parse the indices, then calling `parseColonTypeList` to parse the result
1240/// type.
1241///
1242class OpAsmParser : public AsmParser {
1243public:
1244 using AsmParser::AsmParser;
1245 ~OpAsmParser() override;
1246
1247 /// Parse a loc(...) specifier if present, filling in result if so.
1248 /// Location for BlockArgument and Operation may be deferred with an alias, in
1249 /// which case an OpaqueLoc is set and will be resolved when parsing
1250 /// completes.
1251 virtual ParseResult
1252 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1253
1254 /// Return the name of the specified result in the specified syntax, as well
1255 /// as the sub-element in the name. It returns an empty string and ~0U for
1256 /// invalid result numbers. For example, in this operation:
1257 ///
1258 /// %x, %y:2, %z = foo.op
1259 ///
1260 /// getResultName(0) == {"x", 0 }
1261 /// getResultName(1) == {"y", 0 }
1262 /// getResultName(2) == {"y", 1 }
1263 /// getResultName(3) == {"z", 0 }
1264 /// getResultName(4) == {"", ~0U }
1265 virtual std::pair<StringRef, unsigned>
1266 getResultName(unsigned resultNo) const = 0;
1267
1268 /// Return the number of declared SSA results. This returns 4 for the foo.op
1269 /// example in the comment for `getResultName`.
1270 virtual size_t getNumResults() const = 0;
1271
1272 // These methods emit an error and return failure or success. This allows
1273 // these to be chained together into a linear sequence of || expressions in
1274 // many cases.
1275
1276 /// Parse an operation in its generic form.
1277 /// The parsed operation is parsed in the current context and inserted in the
1278 /// provided block and insertion point. The results produced by this operation
1279 /// aren't mapped to any named value in the parser. Returns nullptr on
1280 /// failure.
1281 virtual Operation *parseGenericOperation(Block *insertBlock,
1282 Block::iterator insertPt) = 0;
1283
1284 /// Parse the name of an operation, in the custom form. On success, return a
1285 /// an object of type 'OperationName'. Otherwise, failure is returned.
1286 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1287
1288 //===--------------------------------------------------------------------===//
1289 // Operand Parsing
1290 //===--------------------------------------------------------------------===//
1291
1292 /// This is the representation of an operand reference.
1293 struct UnresolvedOperand {
1294 SMLoc location; // Location of the token.
1295 StringRef name; // Value name, e.g. %42 or %abc
1296 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1297 };
1298
1299 /// Parse different components, viz., use-info of operand(s), successor(s),
1300 /// region(s), attribute(s) and function-type, of the generic form of an
1301 /// operation instance and populate the input operation-state 'result' with
1302 /// those components. If any of the components is explicitly provided, then
1303 /// skip parsing that component.
1304 virtual ParseResult parseGenericOperationAfterOpName(
1305 OperationState &result,
1306 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1307 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1308 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1309 llvm::None,
1310 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1311 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1312
1313 /// Parse a single SSA value operand name along with a result number if
1314 /// `allowResultNumber` is true.
1315 virtual ParseResult parseOperand(UnresolvedOperand &result,
1316 bool allowResultNumber = true) = 0;
1317
1318 /// Parse a single operand if present.
1319 virtual OptionalParseResult
1320 parseOptionalOperand(UnresolvedOperand &result,
1321 bool allowResultNumber = true) = 0;
1322
1323 /// Parse zero or more SSA comma-separated operand references with a specified
1324 /// surrounding delimiter, and an optional required operand count.
1325 virtual ParseResult
1326 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1327 Delimiter delimiter = Delimiter::None,
1328 bool allowResultNumber = true,
1329 int requiredOperandCount = -1) = 0;
1330
1331 /// Parse a specified number of comma separated operands.
1332 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1333 int requiredOperandCount,
1334 Delimiter delimiter = Delimiter::None) {
1335 return parseOperandList(result, delimiter,
1336 /*allowResultNumber=*/true, requiredOperandCount);
1337 }
1338
1339 /// Parse zero or more trailing SSA comma-separated trailing operand
1340 /// references with a specified surrounding delimiter, and an optional
1341 /// required operand count. A leading comma is expected before the
1342 /// operands.
1343 ParseResult
1344 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1345 Delimiter delimiter = Delimiter::None) {
1346 if (failed(parseOptionalComma()))
1347 return success(); // The comma is optional.
1348 return parseOperandList(result, delimiter);
1349 }
1350
1351 /// Resolve an operand to an SSA value, emitting an error on failure.
1352 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1353 Type type,
1354 SmallVectorImpl<Value> &result) = 0;
1355
1356 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1357 /// appending the results to the list on success. This method should be used
1358 /// when all operands have the same type.
1359 template <typename Operands = ArrayRef<UnresolvedOperand>>
1360 ParseResult resolveOperands(Operands &&operands, Type type,
1361 SmallVectorImpl<Value> &result) {
1362 for (const UnresolvedOperand &operand : operands)
1363 if (resolveOperand(operand, type, result))
1364 return failure();
1365 return success();
1366 }
1367 template <typename Operands = ArrayRef<UnresolvedOperand>>
1368 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1369 SmallVectorImpl<Value> &result) {
1370 return resolveOperands(std::forward<Operands>(operands), type, result);
1371 }
1372
1373 /// Resolve a list of operands and a list of operand types to SSA values,
1374 /// emitting an error and returning failure, or appending the results
1375 /// to the list on success.
1376 template <typename Operands = ArrayRef<UnresolvedOperand>,
1377 typename Types = ArrayRef<Type>>
1378 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1379 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1380 SmallVectorImpl<Value> &result) {
1381 size_t operandSize = std::distance(operands.begin(), operands.end());
1382 size_t typeSize = std::distance(types.begin(), types.end());
1383 if (operandSize != typeSize)
1384 return emitError(loc)
1385 << operandSize << " operands present, but expected " << typeSize;
1386
1387 for (auto [operand, type] : llvm::zip(operands, types))
1388 if (resolveOperand(operand, type, result))
1389 return failure();
1390 return success();
1391 }
1392
1393 /// Parses an affine map attribute where dims and symbols are SSA operands.
1394 /// Operand values must come from single-result sources, and be valid
1395 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1396 virtual ParseResult
1397 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1398 Attribute &map, StringRef attrName,
1399 NamedAttrList &attrs,
1400 Delimiter delimiter = Delimiter::Square) = 0;
1401
1402 /// Parses an affine expression where dims and symbols are SSA operands.
1403 /// Operand values must come from single-result sources, and be valid
1404 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1405 virtual ParseResult
1406 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1407 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1408 AffineExpr &expr) = 0;
1409
1410 //===--------------------------------------------------------------------===//
1411 // Argument Parsing
1412 //===--------------------------------------------------------------------===//
1413
1414 struct Argument {
1415 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1416 Type type; // Type.
1417 DictionaryAttr attrs; // Attributes if present.
1418 Optional<Location> sourceLoc; // Source location specifier if present.
1419 };
1420
1421 /// Parse a single argument with the following syntax:
1422 ///
1423 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1424 ///
1425 /// If `allowType` is false or `allowAttrs` are false then the respective
1426 /// parts of the grammar are not parsed.
1427 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1428 bool allowAttrs = false) = 0;
1429
1430 /// Parse a single argument if present.
1431 virtual OptionalParseResult
1432 parseOptionalArgument(Argument &result, bool allowType = false,
1433 bool allowAttrs = false) = 0;
1434
1435 /// Parse zero or more arguments with a specified surrounding delimiter.
1436 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1437 Delimiter delimiter = Delimiter::None,
1438 bool allowType = false,
1439 bool allowAttrs = false) = 0;
1440
1441 //===--------------------------------------------------------------------===//
1442 // Region Parsing
1443 //===--------------------------------------------------------------------===//
1444
1445 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1446 /// moved to the op regions after the op is created. The first block of the
1447 /// region takes 'arguments'.
1448 ///
1449 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1450 /// shadow the names of other existing SSA values defined above the region
1451 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1452 /// to operations that are 'IsolatedFromAbove'.
1453 virtual ParseResult parseRegion(Region &region,
1454 ArrayRef<Argument> arguments = {},
1455 bool enableNameShadowing = false) = 0;
1456
1457 /// Parses a region if present.
1458 virtual OptionalParseResult
1459 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1460 bool enableNameShadowing = false) = 0;
1461
1462 /// Parses a region if present. If the region is present, a new region is
1463 /// allocated and placed in `region`. If no region is present or on failure,
1464 /// `region` remains untouched.
1465 virtual OptionalParseResult
1466 parseOptionalRegion(std::unique_ptr<Region> &region,
1467 ArrayRef<Argument> arguments = {},
1468 bool enableNameShadowing = false) = 0;
1469
1470 //===--------------------------------------------------------------------===//
1471 // Successor Parsing
1472 //===--------------------------------------------------------------------===//
1473
1474 /// Parse a single operation successor.
1475 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1476
1477 /// Parse an optional operation successor.
1478 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1479
1480 /// Parse a single operation successor and its operand list.
1481 virtual ParseResult
1482 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1483
1484 //===--------------------------------------------------------------------===//
1485 // Type Parsing
1486 //===--------------------------------------------------------------------===//
1487
1488 /// Parse a list of assignments of the form
1489 /// (%x1 = %y1, %x2 = %y2, ...)
1490 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1491 SmallVectorImpl<UnresolvedOperand> &rhs) {
1492 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1493 if (!result.has_value())
1494 return emitError(getCurrentLocation(), "expected '('");
1495 return result.value();
1496 }
1497
1498 virtual OptionalParseResult
1499 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1500 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1501};
1502
1503//===--------------------------------------------------------------------===//
1504// Dialect OpAsm interface.
1505//===--------------------------------------------------------------------===//
1506
1507/// A functor used to set the name of the start of a result group of an
1508/// operation. See 'getAsmResultNames' below for more details.
1509using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1510
1511/// A functor used to set the name of blocks in regions directly nested under
1512/// an operation.
1513using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1514
1515class OpAsmDialectInterface
1516 : public DialectInterface::Base<OpAsmDialectInterface> {
1517public:
1518 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1519
1520 //===------------------------------------------------------------------===//
1521 // Aliases
1522 //===------------------------------------------------------------------===//
1523
1524 /// Holds the result of `getAlias` hook call.
1525 enum class AliasResult {
1526 /// The object (type or attribute) is not supported by the hook
1527 /// and an alias was not provided.
1528 NoAlias,
1529 /// An alias was provided, but it might be overriden by other hook.
1530 OverridableAlias,
1531 /// An alias was provided and it should be used
1532 /// (no other hooks will be checked).
1533 FinalAlias
1534 };
1535
1536 /// Hooks for getting an alias identifier alias for a given symbol, that is
1537 /// not necessarily a part of this dialect. The identifier is used in place of
1538 /// the symbol when printing textual IR. These aliases must not contain `.` or
1539 /// end with a numeric digit([0-9]+).
1540 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1541 return AliasResult::NoAlias;
1542 }
1543 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1544 return AliasResult::NoAlias;
1545 }
1546
1547 //===--------------------------------------------------------------------===//
1548 // Resources
1549 //===--------------------------------------------------------------------===//
1550
1551 /// Declare a resource with the given key, returning a handle to use for any
1552 /// references of this resource key within the IR during parsing. The result
1553 /// of `getResourceKey` on the returned handle is permitted to be different
1554 /// than `key`.
1555 virtual FailureOr<AsmDialectResourceHandle>
1556 declareResource(StringRef key) const {
1557 return failure();
1558 }
1559
1560 /// Return a key to use for the given resource. This key should uniquely
1561 /// identify this resource within the dialect.
1562 virtual std::string
1563 getResourceKey(const AsmDialectResourceHandle &handle) const {
1564 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1565)
1565 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1565)
;
1566 }
1567
1568 /// Hook for parsing resource entries. Returns failure if the entry was not
1569 /// valid, or could otherwise not be processed correctly. Any necessary errors
1570 /// can be emitted via the provided entry.
1571 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1572
1573 /// Hook for building resources to use during printing. The given `op` may be
1574 /// inspected to help determine what information to include.
1575 /// `referencedResources` contains all of the resources detected when printing
1576 /// 'op'.
1577 virtual void
1578 buildResources(Operation *op,
1579 const SetVector<AsmDialectResourceHandle> &referencedResources,
1580 AsmResourceBuilder &builder) const {}
1581};
1582} // namespace mlir
1583
1584//===--------------------------------------------------------------------===//
1585// Operation OpAsm interface.
1586//===--------------------------------------------------------------------===//
1587
1588/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1589#include "mlir/IR/OpAsmInterface.h.inc"
1590
1591namespace llvm {
1592template <>
1593struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1594 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1595 return {DenseMapInfo<void *>::getEmptyKey(),
1596 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1597 }
1598 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1599 return {DenseMapInfo<void *>::getTombstoneKey(),
1600 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1601 }
1602 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1603 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1604 }
1605 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1606 const mlir::AsmDialectResourceHandle &rhs) {
1607 return lhs.getResource() == rhs.getResource();
1608 }
1609};
1610} // namespace llvm
1611
1612#endif