Bug Summary

File: build/source/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Warning: line 1090, column 9
1st function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name LinalgTransformOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16 -I tools/mlir/lib/Dialect/Linalg/TransformOps -I /build/source/mlir/lib/Dialect/Linalg/TransformOps -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include 
-internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1671487667 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-12-20-010714-16201-1 -x c++ /build/source/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

/build/source/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

1//===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h"
10
11#include "mlir/AsmParser/AsmParser.h"
12#include "mlir/Dialect/Affine/IR/AffineOps.h"
13#include "mlir/Dialect/Arith/IR/Arith.h"
14#include "mlir/Dialect/GPU/IR/GPUDialect.h"
15#include "mlir/Dialect/Linalg/IR/Linalg.h"
16#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17#include "mlir/Dialect/PDL/IR/PDL.h"
18#include "mlir/Dialect/PDL/IR/PDLTypes.h"
19#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
20#include "mlir/Dialect/Transform/IR/TransformDialect.h"
21#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
22#include "mlir/Dialect/Transform/IR/TransformUtils.h"
23#include "mlir/Dialect/Transform/Utils/Utils.h"
24#include "mlir/IR/BuiltinTypes.h"
25#include "mlir/IR/Matchers.h"
26#include "mlir/IR/OpDefinition.h"
27#include "mlir/Interfaces/TilingInterface.h"
28#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
29#include "llvm/ADT/StringSet.h"
30#include "llvm/Support/Debug.h"
31
32using namespace mlir;
33using namespace mlir::linalg;
34using namespace mlir::transform;
35
// Debug category used by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "linalg-transforms"
37
38/// Attempts to apply the pattern specified as template argument to the given
39/// operation. The pattern is expected to have a `returningMatchAndRewrite`
40/// function that returns the "main" result or failure. Returns failure if the
41/// pattern failed to apply. Extra arguments are forwarded to the pattern
42/// constructor.
43template <typename PatternTy, typename... Args>
44static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
45 // Check if the given operation has the type expected by the pattern.
46 using OpTy = typename llvm::function_traits<
47 decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
48 auto op = dyn_cast<OpTy>(operation);
49 if (!op)
50 return failure();
51
52 // Apply the pattern directly to the op.
53 PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
54 TrivialPatternRewriter rewriter(operation->getContext());
55 rewriter.setInsertionPoint(operation);
56 auto result = pattern.returningMatchAndRewrite(op, rewriter);
57 if (failed(result))
58 return failure();
59 return cast<LinalgOp>(result->getOperation());
60}
61
62//===----------------------------------------------------------------------===//
63// DecomposeOp
64//===----------------------------------------------------------------------===//
65
66DiagnosedSilenceableFailure
67transform::DecomposeOp::applyToOne(linalg::LinalgOp target,
68 SmallVectorImpl<Operation *> &results,
69 transform::TransformState &state) {
70 FailureOr<LinalgOp> windowedNhwc =
71 tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp,
72 Conv1DNwcWcfOp>>(target);
73 if (succeeded(windowedNhwc)) {
74 results.push_back(*windowedNhwc);
75 return DiagnosedSilenceableFailure::success();
76 }
77 FailureOr<LinalgOp> windowedNchw =
78 tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp,
79 Conv1DNcwFcwOp>>(target);
80 if (succeeded(windowedNchw)) {
81 results.push_back(*windowedNchw);
82 return DiagnosedSilenceableFailure::success();
83 }
84 FailureOr<LinalgOp> depthwise =
85 tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
86 if (succeeded(depthwise)) {
87 results.push_back(*depthwise);
88 return DiagnosedSilenceableFailure::success();
89 }
90 results.assign(1, nullptr);
91 return emitDefaultSilenceableFailure(target);
92}
93//===----------------------------------------------------------------------===//
94// FuseOp
95//===----------------------------------------------------------------------===//
96
97/// Apply a tiling transformation to all payload ops and store both the
98/// tiled operation as well as the created tile loops.
99static LogicalResult applyTilingToAll(
100 Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops,
101 transform::TransformResults &transformResults,
102 function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)>
103 applyFn) {
104 SmallVector<Operation *> tiledLinalgOps;
105 SmallVector<SmallVector<Operation *>> loopOps(numLoops);
106 for (unsigned int i = 0; i < numLoops; ++i)
107 loopOps[i].reserve(payloadOps.size());
108
109 for (Operation *target : payloadOps) {
110 auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
111 if (!tilingInterfaceOp)
112 return transformOp->emitError("only TilingInterface ops are supported");
113
114 TrivialPatternRewriter rewriter(target->getContext());
115 rewriter.setInsertionPoint(target);
116 FailureOr<scf::SCFTileAndFuseResult> tiledResults =
117 applyFn(tilingInterfaceOp);
118 if (failed(tiledResults))
119 return failure();
120
121 // Perform the replacement of tiled and fused values.
122 SmallVector<Operation *> opsToReplace{target};
123 llvm::append_range(opsToReplace, tiledResults->fusedProducers);
124 for (Operation *toReplace : opsToReplace) {
125 SmallVector<Value> replacements;
126 replacements.reserve(toReplace->getNumResults());
127 for (OpResult res : toReplace->getResults()) {
128 auto it = tiledResults->replacements.find(res);
129 if (it == tiledResults->replacements.end())
130 replacements.push_back(res);
131 else
132 replacements.push_back(it->getSecond());
133 }
134 rewriter.replaceOp(toReplace, replacements);
135 }
136
137 // Report back the relevant handles to the transform op.
138 tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front());
139 assert(tiledResults->loops.size() == numLoops &&(static_cast <bool> (tiledResults->loops.size() == numLoops
&& "Mismatched number of loops, tile and fuse transform should have "
"failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 141, __extension__ __PRETTY_FUNCTION__))
140 "Mismatched number of loops, tile and fuse transform should have "(static_cast <bool> (tiledResults->loops.size() == numLoops
&& "Mismatched number of loops, tile and fuse transform should have "
"failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 141, __extension__ __PRETTY_FUNCTION__))
141 "failed")(static_cast <bool> (tiledResults->loops.size() == numLoops
&& "Mismatched number of loops, tile and fuse transform should have "
"failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 141, __extension__ __PRETTY_FUNCTION__))
;
142 for (unsigned int i = 0; i < numLoops; ++i)
143 loopOps[i].push_back(tiledResults->loops[i]);
144 }
145
146 transformResults.set(transformOp->getOpResult(0), tiledLinalgOps);
147 for (unsigned int i = 0; i < numLoops; ++i)
148 transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
149
150 return success();
151}
152
153/// Parse a tiling-like operation that returns the tiled op as well as the
154/// created tile loops. The function counts the non-zero tile sizes to compute
155/// the number of results.
156static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result,
157 StringRef sizesAttrName) {
158 OpAsmParser::UnresolvedOperand targetOperand;
159 SMLoc opLoc = parser.getCurrentLocation();
160 if (parser.parseOperand(targetOperand) ||
161 parser.parseOptionalAttrDict(result.attributes))
162 return failure();
163 Attribute sizesAttr = result.attributes.get(sizesAttrName);
164 if (!sizesAttr)
165 return parser.emitError(opLoc)
166 << "expected '" << sizesAttrName << "' attribute";
167 auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>();
168 if (!sizesArrayAttr)
169 return parser.emitError(opLoc)
170 << "'" << sizesAttrName << "' attribute must be an array";
171 Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>();
172 size_t numExpectedLoops =
173 sizesArrayAttr.size() -
174 llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0);
175 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType));
176 if (parser.resolveOperand(targetOperand, pdlOpType, result.operands))
177 return failure();
178 return success();
179}
180
181DiagnosedSilenceableFailure
182transform::FuseOp::apply(mlir::transform::TransformResults &transformResults,
183 mlir::transform::TransformState &state) {
184 SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes());
185 SmallVector<int64_t> tileInterchange =
186 extractFromI64ArrayAttr(getTileInterchange());
187
188 scf::SCFTilingOptions tilingOptions;
189 tilingOptions.interchangeVector = tileInterchange;
190 tilingOptions = tilingOptions.setTileSizes(tileSizes);
191 scf::SCFTileAndFuseOptions tileAndFuseOptions;
192 tileAndFuseOptions.tilingOptions = tilingOptions;
193 LogicalResult result = applyTilingToAll(
194 getOperation(), state.getPayloadOps(getTarget()),
195 tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
196 [&](TilingInterface tilingInterfaceOp)
197 -> FailureOr<scf::SCFTileAndFuseResult> {
198 TrivialPatternRewriter rewriter(getContext());
199 return tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
200 rewriter, tilingInterfaceOp, tileAndFuseOptions);
201 });
202 return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
203 : DiagnosedSilenceableFailure::success();
204}
205
206ParseResult transform::FuseOp::parse(OpAsmParser &parser,
207 OperationState &result) {
208 return parseTileLikeOp(
209 parser, result,
210 transform::FuseOp::getTileSizesAttrName(result.name).getValue());
211}
212
213void transform::FuseOp::print(OpAsmPrinter &p) {
214 p << ' ';
215 p << getTarget();
216 p.printOptionalAttrDict((*this)->getAttrs());
217}
218
219LogicalResult transform::FuseOp::verify() {
220 SmallVector<int64_t> permutation =
221 extractFromI64ArrayAttr(getTileInterchange());
222 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
223 if (!std::is_permutation(sequence.begin(), sequence.end(),
224 permutation.begin(), permutation.end())) {
225 return emitOpError() << "expects interchange to be a permutation, found "
226 << getTileInterchange();
227 }
228 return success();
229}
230
231//===----------------------------------------------------------------------===//
232// FuseIntoContainingOp
233//===----------------------------------------------------------------------===//
234
235void transform::FuseIntoContainingOp::build(OpBuilder &builder,
236 OperationState &result,
237 Value producerOp,
238 Value containingOp) {
239 result.addOperands({producerOp, containingOp});
240 result.addTypes(pdl::OperationType::get(builder.getContext()));
241}
242
243/// Find the first "extract" user of `producerOp` and tile it right before its
244/// use. The tiled op is fused under the `containingOp`.
245/// Return this fused op on success or nullptr if anything fails.
246static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter,
247 Diagnostic &diag,
248 Operation *producerOp,
249 Operation *containingOp) {
250 LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "Try to fuse a direct extract use\n"
; } } while (false)
;
251 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
252 if (!tileableProducer) {
253 diag.attachNote(producerOp->getLoc())
254 << "producer is not a TileableInterface: " << *producerOp;
255 return nullptr;
256 }
257
258 // Search the producer slices accessed within the containing operation.
259 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
260 // evolve into an interface.
261 auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
262 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
263 return sliceOp && containingOp->isProperAncestor(sliceOp);
264 });
265
266 // Find a fusion opportunity.
267 if (it == tileableProducer->getUsers().end()) {
268 diag.attachNote(tileableProducer->getLoc())
269 << "could not find fusion opportunity for: " << *tileableProducer;
270 return nullptr;
271 }
272 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
273
274 // Try to fuse the producer in-place.
275 OpBuilder::InsertionGuard guard(rewriter);
276 rewriter.setInsertionPoint(sliceOpToTile);
277
278 // Tile the producer.
279 int64_t resultNumber =
280 sliceOpToTile.getSource().cast<OpResult>().getResultNumber();
281 LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "resultNumber: "
<< resultNumber << "\n"; } } while (false)
;
282
283 FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
284 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
285 sliceOpToTile.getMixedSizes());
286 if (failed(tiledProducer)) {
287 diag.attachNote(tileableProducer->getLoc())
288 << "failed to tile producer op: " << *tileableProducer;
289 return nullptr;
290 }
291 LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "tiledProducer: "
<< *tiledProducer << "\n"; } } while (false)
;
292
293 // Replace the extract op.
294 Operation *fusedOp = tiledProducer->getDefiningOp();
295 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
296 rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
297 sliceOpToTile->getResult(0)
298 .getType()
299 .cast<RankedTensorType>()
300 .getShape());
301 assert(succeeded(maybeRankReduced) && "unexpected shape")(static_cast <bool> (succeeded(maybeRankReduced) &&
"unexpected shape") ? void (0) : __assert_fail ("succeeded(maybeRankReduced) && \"unexpected shape\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 301, __extension__ __PRETTY_FUNCTION__))
;
302 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
303 return fusedOp;
304}
305
306/// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure
307/// it is exactly the `containingOp`, otherwise bail.
308/// Then, find the first "extract" user of the tied block argument and tile it
309/// right before its "extract" use. The tiled op is fused under the
310/// `containingOp`.
311/// Return this fused op on success or nullptr if anything fails.
312static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
313 RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp,
314 Operation *containingOp) {
315 LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "Try to fuse an extract use through block argument\n"
; } } while (false)
316 llvm::dbgs() << "Try to fuse an extract use through block argument\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "Try to fuse an extract use through block argument\n"
; } } while (false)
;
317
318 auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
319 if (!tileableProducer) {
320 diag.attachNote(producerOp->getLoc())
321 << "producer is not a TileableInterface: " << *producerOp;
322 return nullptr;
323 }
324
325 // Search the first use by a "scf::ForeachThreadOp" user.
326 scf::ForeachThreadOp foreachThreadOp;
327 auto itProducerUses =
328 llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) {
329 foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner());
330 return foreachThreadOp;
331 });
332 // If it's not from the containing op, return.
333 if (!foreachThreadOp || foreachThreadOp != containingOp) {
334 diag.attachNote(tileableProducer->getLoc())
335 << "could not find a use by the containing op: " << *tileableProducer;
336 return nullptr;
337 }
338
339 // Search the producer slices accessed within the containing
340 // operation.
341 // TODO: Generalize to more extract/insert/parallel_insert triples.
342 // Maybe evolve into an interface.
343 OpOperand *pUse = &(*itProducerUses);
344 BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse);
345
346 // Search the producer slices accessed within the containing operation.
347 // TODO: Generalize to more extract/insert/parallel_insert triples, maybe
348 // evolve into an interface.
349 auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) {
350 auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
351 return sliceOp && containingOp->isProperAncestor(sliceOp);
352 });
353
354 // Find a fusion opportunity.
355 if (itBBArgUsers == bbArg.getUsers().end()) {
356 diag.attachNote(containingOp->getLoc())
357 << "could not find fusion opportunity for bbArg: " << bbArg;
358 return nullptr;
359 }
360 auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers);
361
362 // Try to fuse the producer in-place.
363 OpBuilder::InsertionGuard guard(rewriter);
364 rewriter.setInsertionPoint(sliceOpToTile);
365
366 // Replace the use in the tileableProducer before tiling: clone, replace and
367 // then tile.
368 int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber();
369 LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "resultNumber: "
<< resultNumber << "\n"; } } while (false)
;
370
371 // Gather destination tensors.
372 SmallVector<Value> destinationTensors;
373 if (failed(tensor::getOrCreateDestinations(
374 rewriter, tileableProducer->getLoc(), tileableProducer,
375 destinationTensors))) {
376 diag.attachNote(tileableProducer->getLoc())
377 << "failed to get destination tensors for: " << *tileableProducer;
378 return nullptr;
379 }
380
381 BlockAndValueMapping bvm;
382 bvm.map(destinationTensors[resultNumber], bbArg);
383 auto tileableProducerClone =
384 cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
385 auto scopeGuard =
386 llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });
387
388 // Tile the producer.
389 FailureOr<Value> tiledProducer =
390 tileableProducerClone.generateResultTileValue(
391 rewriter, resultNumber, sliceOpToTile.getMixedOffsets(),
392 sliceOpToTile.getMixedSizes());
393 if (failed(tiledProducer)) {
394 diag.attachNote(tileableProducer->getLoc())
395 << "failed to tile producer op: " << *tileableProducer;
396 return nullptr;
397 }
398 LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "tiledProducer: "
<< *tiledProducer << "\n"; } } while (false)
;
399
400 // Replace the extract op.
401 Operation *fusedOp = tiledProducer->getDefiningOp();
402 auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded(
403 rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber),
404 sliceOpToTile->getResult(0)
405 .getType()
406 .cast<RankedTensorType>()
407 .getShape());
408 assert(succeeded(maybeRankReduced) && "unexpected shape")(static_cast <bool> (succeeded(maybeRankReduced) &&
"unexpected shape") ? void (0) : __assert_fail ("succeeded(maybeRankReduced) && \"unexpected shape\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 408, __extension__ __PRETTY_FUNCTION__))
;
409 rewriter.replaceOp(sliceOpToTile, *maybeRankReduced);
410
411 // Replace the use in containingOp.
412 rewriter.updateRootInPlace(containingOp, [&]() {
413 containingOp->setOperand(pUse->getOperandNumber(),
414 destinationTensors.front());
415 });
416
417 return fusedOp;
418}
419
420static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag,
421 Operation *producerOp,
422 Operation *containingOp) {
423 LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "Try to fuse an use by cloning\n"
; } } while (false)
;
424
425 // Gather all uses inside the containing op.
426 SmallVector<OpOperand *> uses;
427 for (OpResult result : producerOp->getOpResults()) {
428 for (OpOperand &use : result.getUses()) {
429 if (containingOp->isProperAncestor(use.getOwner())) {
430 uses.push_back(&use);
431 continue;
432 }
433 // Cannot clone and fuse if the use is by the containing op itself: fail
434 // immediately.
435 if (containingOp == use.getOwner()) {
436 diag.attachNote(producerOp->getLoc())
437 << "producer op use by containing op cannot be fused by cloning";
438 return nullptr;
439 }
440 }
441 }
442
443 // Check for a non-empty list of fusion opportunities.
444 if (uses.empty()) {
445 diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning";
446 return nullptr;
447 }
448
449 // Clone and fuse inside the containing op.
450 Operation *fusedOp = nullptr;
451 OpOperand *use = uses.front();
452 // Parallel insert slice is not a valid clone destination.
453 // TODO: Generalize to other type of ops.
454 assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp
>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"
) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 455, __extension__ __PRETTY_FUNCTION__))
455 "Parallel insert slice is not a valid clone destination")(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp
>(use->getOwner()) && "Parallel insert slice is not a valid clone destination"
) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 455, __extension__ __PRETTY_FUNCTION__))
;
456 unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
457 LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "resultNumber: "
<< resultNumber << "\n"; } } while (false)
;
458
459 OpBuilder::InsertionGuard guard(rewriter);
460 rewriter.setInsertionPoint(use->getOwner());
461 fusedOp = rewriter.clone(*producerOp);
462 rewriter.updateRootInPlace(
463 use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
464
465 return fusedOp;
466}
467
468DiagnosedSilenceableFailure
469transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
470 transform::TransformState &state) {
471 SmallVector<Operation *> fusedOps;
472 ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp());
473 // If nothing to fuse, propagate success.
474 if (producerOps.empty()) {
475 results.set(getFusedOp().cast<OpResult>(),
476 SmallVector<mlir::Operation *>{});
477 return DiagnosedSilenceableFailure::success();
478 }
479 ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp());
480 if (containingOps.size() != 1) {
481 return emitDefiniteFailure()
482 << "requires exactly one containing_op handle (got "
483 << containingOps.size() << ")";
484 }
485 Operation *containingOp = containingOps.front();
486
487 // Helper function to find the next producer that should be fused. Take any
488 // producer that has a use inside the containing op.
489 SmallVector<Operation *> remainingProducers(producerOps.begin(),
490 producerOps.end());
491 auto getNextProducer = [&]() -> FailureOr<Operation *> {
492 for (const auto &it : enumerate(remainingProducers)) {
493 Operation *producerOp = it.value();
494 // The containing op may be a user of producerOp: use isAncestor.
495 int64_t numUsesInContainingOp =
496 llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
497 return containingOp->isAncestor(op);
498 });
499 // TODO: When resolving the TODO below (no duplicate ops), take an op
500 // that has no use among the remaining producers. This is a topological
501 // sorting.
502 if (numUsesInContainingOp > 0) {
503 if (numUsesInContainingOp == 1)
504 remainingProducers.erase(remainingProducers.begin() + it.index());
505 return producerOp;
506 }
507 }
508 return failure();
509 };
510
511 IRRewriter rewriter(getContext());
512 while (!remainingProducers.empty()) {
513 auto nextProducer = getNextProducer();
514 if (failed(nextProducer)) {
515 results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
516 Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
517 diag << "could not find next producer to fuse into container";
518 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
519 }
520
521 Operation *producerOp = *nextProducer;
522
523 // Default diagnostic, to be complemented with more failure information.
524 Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
525 diag << "could not fuse " << *producerOp << " into " << *containingOp;
526
527 // TODO: If there are multiple uses of the producer in the containing op,
528 // we currently tile/clone the op multiple times (once per use). In some
529 // cases, we can tile/clone once and reuse the value for each use.
530 // Futhermore, producers should then be traversed according to a
531 // topological sorting.
532 Operation *tiled =
533 tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp);
534 if (tiled) {
535 LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused a direct extract use\n"
<< *containingOp; } } while (false)
536 << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused a direct extract use\n"
<< *containingOp; } } while (false)
;
537 fusedOps.push_back(tiled);
538 continue;
539 }
540
541 Operation *tiledContainingOpOperand =
542 tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
543 rewriter, diag, producerOp, containingOp);
544 if (tiledContainingOpOperand) {
545 LLVM_DEBUG(llvm::dbgs()do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n"
<< *containingOp; } } while (false)
546 << "\nFused an extract use through block argument\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n"
<< *containingOp; } } while (false)
547 << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n"
<< *containingOp; } } while (false)
;
548 fusedOps.push_back(tiledContainingOpOperand);
549 continue;
550 }
551
552 Operation *cloned =
553 cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp);
554 if (cloned) {
555 LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused an use by cloning\n"
<< *containingOp; } } while (false)
556 << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("linalg-transforms")) { llvm::dbgs() << "\nFused an use by cloning\n"
<< *containingOp; } } while (false)
;
557 fusedOps.push_back(cloned);
558 continue;
559 }
560 results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>());
561 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
562 }
563
564 results.set(getFusedOp().cast<OpResult>(), fusedOps);
565 return DiagnosedSilenceableFailure::success();
566}
567
568//===----------------------------------------------------------------------===//
569// GeneralizeOp
570//===----------------------------------------------------------------------===//
571
572DiagnosedSilenceableFailure
573transform::GeneralizeOp::applyToOne(linalg::LinalgOp target,
574 SmallVectorImpl<Operation *> &results,
575 transform::TransformState &state) {
576 // Exit early if no transformation is needed.
577 if (isa<GenericOp>(target)) {
578 results.push_back(target);
579 return DiagnosedSilenceableFailure::success();
580 }
581 FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
582 if (succeeded(generic)) {
583 results.push_back(generic->getOperation());
584 return DiagnosedSilenceableFailure::success();
585 }
586 results.assign(1, nullptr);
587 return emitDefaultSilenceableFailure(target);
588}
589
590//===----------------------------------------------------------------------===//
591// InterchangeOp
592//===----------------------------------------------------------------------===//
593
594DiagnosedSilenceableFailure
595transform::InterchangeOp::applyToOne(linalg::GenericOp target,
596 SmallVectorImpl<Operation *> &results,
597 transform::TransformState &state) {
598 ArrayRef<int64_t> interchangeVector = getIteratorInterchange();
599 // Exit early if no transformation is needed.
600 if (interchangeVector.empty()) {
601 results.push_back(target);
602 return DiagnosedSilenceableFailure::success();
603 }
604 TrivialPatternRewriter rewriter(target->getContext());
605 FailureOr<GenericOp> res =
606 interchangeGenericOp(rewriter, target,
607 SmallVector<unsigned>(interchangeVector.begin(),
608 interchangeVector.end()));
609 if (failed(res))
610 return DiagnosedSilenceableFailure::definiteFailure();
611 results.push_back(res->getOperation());
612 return DiagnosedSilenceableFailure::success();
613}
614
615LogicalResult transform::InterchangeOp::verify() {
616 ArrayRef<int64_t> permutation = getIteratorInterchange();
617 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size()));
618 if (!std::is_permutation(sequence.begin(), sequence.end(),
619 permutation.begin(), permutation.end())) {
620 return emitOpError()
621 << "expects iterator_interchange to be a permutation, found "
622 << getIteratorInterchange();
623 }
624 return success();
625}
626
627//===---------------------------------------------------------------------===//
628// MatchOp
629//===---------------------------------------------------------------------===//
630
631void transform::MatchOp::build(OpBuilder &builder, OperationState &result,
632 Value target, ArrayRef<StringRef> opNames) {
633 result.addOperands(target);
634 result.addAttribute(MatchOp::getOpsAttrName(result.name),
635 builder.getStrArrayAttr(opNames));
636 result.addTypes(pdl::OperationType::get(builder.getContext()));
637}
638
639DiagnosedSilenceableFailure
640transform::MatchOp::apply(transform::TransformResults &results,
641 transform::TransformState &state) {
642 llvm::StringSet<> strs;
643 if (getOps().has_value())
644 strs.insert(getOps()->getAsValueRange<StringAttr>().begin(),
645 getOps()->getAsValueRange<StringAttr>().end());
646
647 ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget());
648 if (payloadOps.size() != 1) {
649 results.set(getResult().cast<OpResult>(), {});
650 return emitDefiniteFailure("requires exactly one target handle");
651 }
652
653 SmallVector<Operation *> res;
654 auto matchFun = [&](Operation *op) {
655 if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
656 return;
657
658 // Interfaces cannot be matched by name, just by ID.
659 // So we specifically encode the interfaces we care about for this op.
660 if (getInterface().has_value()) {
661 auto iface = getInterface().value();
662 if (iface == transform::MatchInterfaceEnum::LinalgOp &&
663 !isa<linalg::LinalgOp>(op))
664 return;
665 if (iface == transform::MatchInterfaceEnum::TilingInterface &&
666 isa<TilingInterface>(op))
667 return;
668 }
669
670 // Check if all specified attributes match.
671 if (getOpAttrs().has_value()) {
672 DictionaryAttr opAttrs = getOpAttrs().value();
673 for (NamedAttribute attr : opAttrs) {
674 if (attr.getName() == getInterfaceAttrName() ||
675 attr.getName() == getOpsAttrName())
676 continue;
677 if (!op->hasAttr(attr.getName()))
678 return;
679 if (op->getAttr(attr.getName()) != attr.getValue())
680 return;
681 }
682 }
683
684 if (getFilterResultType().has_value()) {
685 Type t = getFilterResultType().value();
686 if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
687 return;
688 }
689
690 // All constraints are satisfied.
691 res.push_back(op);
692 return;
693 };
694
695 payloadOps.front()->walk(matchFun);
696 results.set(getResult().cast<OpResult>(), res);
697 return DiagnosedSilenceableFailure::success();
698}
699
700//===---------------------------------------------------------------------===//
701// MultiTileSizesOp
702//===---------------------------------------------------------------------===//
703
704DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne(
705 LinalgOp target, SmallVector<Operation *> &results, TransformState &state) {
706 OpBuilder builder(target.getContext());
707 builder.setInsertionPoint(target);
708 OpFoldResult targetSize = builder.getIndexAttr(getTargetSize());
709 OpFoldResult divisor = builder.getIndexAttr(getDivisor());
710 FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes(
711 builder, target, getDimension(), targetSize, divisor);
712 if (failed(spec)) {
713 return emitSilenceableError() << "could not generate tile size computation";
714 }
715
716 AffineExpr s0 = builder.getAffineSymbolExpr(0);
717 AffineExpr s1 = builder.getAffineSymbolExpr(1);
718 Operation *splitPoint =
719 makeComposedAffineApply(builder, target.getLoc(), s0 * s1,
720 {spec->lowTileSize, spec->lowTripCount});
721 Operation *lowTileSize = spec->lowTileSize.getDefiningOp();
722 Operation *highTileSize = spec->highTileSize.getDefiningOp();
723 assert(lowTileSize && highTileSize && splitPoint &&(static_cast <bool> (lowTileSize && highTileSize
&& splitPoint && "tile sizes are not produced by operations"
) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 724, __extension__ __PRETTY_FUNCTION__))
724 "tile sizes are not produced by operations")(static_cast <bool> (lowTileSize && highTileSize
&& splitPoint && "tile sizes are not produced by operations"
) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\""
, "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp"
, 724, __extension__ __PRETTY_FUNCTION__))
;
725 results.reserve(results.size() + 3);
726 results.push_back(lowTileSize);
727 results.push_back(highTileSize);
728 results.push_back(splitPoint);
729 return DiagnosedSilenceableFailure::success();
730}
731
732void transform::MultiTileSizesOp::getEffects(
733 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
734 onlyReadsHandle(getTarget(), effects);
735 producesHandle(getResults(), effects);
736 modifiesPayload(effects);
737}
738
739//===---------------------------------------------------------------------===//
740// PadOp
741//===---------------------------------------------------------------------===//
742
743DiagnosedSilenceableFailure
744transform::PadOp::applyToOne(linalg::LinalgOp target,
745 SmallVectorImpl<Operation *> &results,
746 transform::TransformState &state) {
747 // Convert the integer packing flags to booleans.
748 SmallVector<bool> packPaddings;
749 for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings()))
750 packPaddings.push_back(static_cast<bool>(packPadding));
751
752 // Convert the padding values to attributes.
753 SmallVector<Attribute> paddingValues;
754 for (auto const &it :
755 llvm::zip(getPaddingValues(), target->getOperandTypes())) {
756 auto attr = std::get<0>(it).dyn_cast<TypedAttr>();
757 if (!attr) {
758 emitOpError("expects padding values to be typed attributes");
759 return DiagnosedSilenceableFailure::definiteFailure();
760 }
761 Type elementType = getElementTypeOrSelf(std::get<1>(it));
762 // Try to parse string attributes to obtain an attribute of element type.
763 if (auto stringAttr = attr.dyn_cast<StringAttr>()) {
764 paddingValues.push_back(
765 parseAttribute(attr.cast<StringAttr>(), elementType));
766 if (!paddingValues.back()) {
767 auto diag = this->emitOpError("expects a padding that parses to ")
768 << elementType << ", got " << std::get<0>(it);
769 diag.attachNote(target.getLoc()) << "when applied to this op";
770 return DiagnosedSilenceableFailure::definiteFailure();
771 }
772 continue;
773 }
774 // Otherwise, add the attribute directly.
775 if (attr.getType() != elementType) {
776 auto diag = this->emitOpError("expects a padding value of type ")
777 << elementType << ", got " << attr;
778 diag.attachNote(target.getLoc()) << "when applied to this op";
779 return DiagnosedSilenceableFailure::definiteFailure();
780 }
781 paddingValues.push_back(attr);
782 }
783
784 // Extract the transpose vectors.
785 SmallVector<SmallVector<int64_t>> transposePaddings;
786 for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>())
787 transposePaddings.push_back(
788 extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>()));
789
790 LinalgPaddingOptions paddingOptions;
791 paddingOptions.setPaddingValues(paddingValues);
792 paddingOptions.setPaddingDimensions(
793 extractFromI64ArrayAttr(getPaddingDimensions()));
794 paddingOptions.setPackPaddings(packPaddings);
795 paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings()));
796 paddingOptions.setTransposePaddings(transposePaddings);
797
798 FailureOr<LinalgOp> result =
799 tryApply<LinalgPaddingPattern>(target, paddingOptions);
800 if (succeeded(result)) {
801 results.push_back(result->getOperation());
802 return DiagnosedSilenceableFailure::success();
803 }
804
805 results.assign(1, nullptr);
806 return emitDefaultSilenceableFailure(target);
807}
808
809LogicalResult transform::PadOp::verify() {
810 SmallVector<int64_t> packPaddings =
811 extractFromI64ArrayAttr(getPackPaddings());
812 if (any_of(packPaddings, [](int64_t packPadding) {
813 return packPadding != 0 && packPadding != 1;
814 })) {
815 return emitOpError()
816 << "expects pack_paddings to contain booleans (0/1), found "
817 << getPackPaddings();
818 }
819
820 SmallVector<int64_t> paddingDimensions =
821 extractFromI64ArrayAttr(getPaddingDimensions());
822 if (any_of(paddingDimensions,
823 [](int64_t paddingDimension) { return paddingDimension < 0; })) {
824 return emitOpError() << "expects padding_dimensions to contain positive "
825 "integers, found "
826 << getPaddingDimensions();
827 }
828
829 SmallVector<int64_t> hoistPaddings =
830 extractFromI64ArrayAttr(getHoistPaddings());
831 if (any_of(hoistPaddings,
832 [](int64_t hoistPadding) { return hoistPadding < 0; })) {
833 return emitOpError()
834 << "expects hoist_paddings to contain positive integers, found "
835 << getHoistPaddings();
836 }
837
838 ArrayAttr transposes = getTransposePaddings();
839 for (Attribute attr : transposes) {
840 SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr);
841 auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size()));
842 if (!std::is_permutation(sequence.begin(), sequence.end(),
843 transpose.begin(), transpose.end())) {
844 return emitOpError()
845 << "expects transpose_paddings to be a permutation, found "
846 << attr;
847 }
848 }
849 return success();
850}
851
852//===----------------------------------------------------------------------===//
853// PromoteOp
854//===----------------------------------------------------------------------===//
855
856DiagnosedSilenceableFailure
857transform::PromoteOp::applyToOne(linalg::LinalgOp target,
858 SmallVectorImpl<Operation *> &results,
859 transform::TransformState &state) {
860 LinalgPromotionOptions promotionOptions;
861 if (!getOperandsToPromote().empty())
862 promotionOptions = promotionOptions.setOperandsToPromote(
863 extractFromI64ArrayAttr(getOperandsToPromote()));
864 if (getUseFullTilesByDefault())
865 promotionOptions = promotionOptions.setUseFullTileBuffersByDefault(
866 getUseFullTilesByDefault());
867 if (getUseAlloca())
868 promotionOptions = promotionOptions.setUseAlloca(getUseAlloca());
869 if (!getUseFullTileBuffers().empty())
870 promotionOptions = promotionOptions.setUseFullTileBuffers(
871 llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>()));
872 if (getAlignment().has_value())
873 promotionOptions = promotionOptions.setAlignment(*getAlignment());
874
875 if (failed(promoteSubviewsPrecondition(target, promotionOptions)))
876 return emitDefaultDefiniteFailure(target);
877
878 TrivialPatternRewriter rewriter(target->getContext());
879 rewriter.setInsertionPoint(target);
880 FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions);
881 if (failed(res))
882 return emitDefaultDefiniteFailure(target);
883 results.push_back(target);
884 return DiagnosedSilenceableFailure::success();
885}
886
887//===----------------------------------------------------------------------===//
888// ReplaceOp
889//===----------------------------------------------------------------------===//
890
891DiagnosedSilenceableFailure
892transform::ReplaceOp::apply(TransformResults &transformResults,
893 TransformState &state) {
894 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
895
896 // Check for invalid targets.
897 for (Operation *target : payload) {
898 if (target->getNumOperands() > 0)
899 return emitDefiniteFailure() << "expected target without operands";
900 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
901 target->getNumRegions() > 0)
902 return emitDefiniteFailure()
903 << "expected target that is isloated from above";
904 }
905
906 // Clone and replace.
907 IRRewriter rewriter(getContext());
908 Operation *pattern = &getBodyRegion().front().front();
909 SmallVector<Operation *> replacements;
910 for (Operation *target : payload) {
911 if (getOperation()->isAncestor(target))
912 continue;
913 rewriter.setInsertionPoint(target);
914 Operation *replacement = rewriter.clone(*pattern);
915 rewriter.replaceOp(target, replacement->getResults());
916 replacements.push_back(replacement);
917 }
918 transformResults.set(getReplacement().cast<OpResult>(), replacements);
919 return DiagnosedSilenceableFailure::success();
920}
921
922void transform::ReplaceOp::getEffects(
923 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
924 consumesHandle(getTarget(), effects);
925 producesHandle(getReplacement(), effects);
926 modifiesPayload(effects);
927}
928
929LogicalResult transform::ReplaceOp::verify() {
930 if (!getBodyRegion().hasOneBlock())
931 return emitOpError() << "expected one block";
932 if (std::distance(getBodyRegion().front().begin(),
933 getBodyRegion().front().end()) != 1)
934 return emitOpError() << "expected one operation in block";
935 Operation *replacement = &getBodyRegion().front().front();
936 if (replacement->getNumOperands() > 0)
937 return replacement->emitOpError()
938 << "expected replacement without operands";
939 if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() &&
940 replacement->getNumRegions() > 0)
941 return replacement->emitOpError()
942 << "expect op that is isolated from above";
943 return success();
944}
945
946//===----------------------------------------------------------------------===//
947// ScalarizeOp
948//===----------------------------------------------------------------------===//
949
950DiagnosedSilenceableFailure
951transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
952 SmallVectorImpl<Operation *> &results,
953 transform::TransformState &state) {
954 scf::SCFTilingOptions tilingOptions;
955 tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) {
956 SmallVector<Value, 4> tileSizes;
957 Location loc = target.getLoc();
958 SmallVector<OpFoldResult> allShapeSizes =
959 target.createFlatListOfOperandDims(b, loc);
960 AffineMap map = target.getShapesToLoopsMap();
961 if (!map)
962 return tileSizes;
963 IRRewriter rewriter(b);
964 SmallVector<OpFoldResult> shapeSizes =
965 makeComposedFoldedMultiResultAffineApply(rewriter, loc, map,
966 allShapeSizes);
967 // If the shape size is dynamic, tile by 1.
968 // Otherwise, do not tile (i.e. tile size 0).
969 for (OpFoldResult shapeSize : shapeSizes) {
970 tileSizes.push_back(getConstantIntValue(shapeSize)
971 ? b.create<arith::ConstantIndexOp>(loc, 0)
972 : b.create<arith::ConstantIndexOp>(loc, 1));
973 }
974 return tileSizes;
975 });
976 SmallVector<int64_t> emptyTileSizes;
977 TrivialPatternRewriter rewriter(getContext());
978 rewriter.setInsertionPoint(target);
979 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
980 rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions);
981 if (failed(maybeTilingResult))
982 return emitDefaultDefiniteFailure(target);
983
984 results.append(maybeTilingResult->tiledOps);
985 return DiagnosedSilenceableFailure::success();
986}
987
988//===----------------------------------------------------------------------===//
989// SplitOp
990//===----------------------------------------------------------------------===//
991
992DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results,
993 TransformState &state) {
994 // Collect the dynamic split points if provided.
995 ArrayRef<Operation *> payload = state.getPayloadOps(getTarget());
996 TrivialPatternRewriter rewriter(getContext());
997 SmallVector<OpFoldResult> splitPoints;
998 splitPoints.reserve(payload.size());
999 if (getDynamicSplitPoint()) {
1000 auto diag = DiagnosedSilenceableFailure::success();
1001 splitPoints = llvm::to_vector(llvm::map_range(
1002 state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) {
1003 if (op->getNumResults() != 1 ||
1004 !op->getResult(0).getType().isIndex()) {
1005 diag = emitSilenceableError()
1006 << "expected dynamic split point handle to point to a "
1007 "single-result index-typed op";
1008 diag.attachNote(op->getLoc()) << "dynamic split point";
1009 }
1010 return OpFoldResult(op->getResult(0));
1011 }));
1012 if (diag.isSilenceableFailure()) {
1013 results.set(getFirst().cast<OpResult>(), {});
1014 results.set(getSecond().cast<OpResult>(), {});
1015 return diag;
1016 }
1017
1018 if (splitPoints.size() != payload.size()) {
1019 return emitDefiniteFailure()
1020 << "expected the dynamic split point handle to point to as "
1021 "many operations ("
1022 << splitPoints.size() << ") as the target handle ("
1023 << payload.size() << ")";
1024 }
1025 } else {
1026 splitPoints.resize(payload.size(),
1027 rewriter.getIndexAttr(getStaticSplitPoint()));
1028 }
1029
1030 // Split each target operation.
1031 SmallVector<Operation *> first, second;
1032 for (const auto &pair : llvm::zip(payload, splitPoints)) {
1033 Operation *target = std::get<0>(pair);
1034 auto linalgOp = dyn_cast<LinalgOp>(target);
1035 if (!linalgOp) {
1036 auto diag = emitSilenceableError() << "only applies to structured ops";
1037 diag.attachNote(target->getLoc()) << "target op";
1038 results.set(getFirst().cast<OpResult>(), {});
1039 results.set(getSecond().cast<OpResult>(), {});
1040 return diag;
1041 }
1042
1043 if (getDimension() >= linalgOp.getNumLoops()) {
1044 auto diag = emitSilenceableError() << "dimension " << getDimension()
1045 << " does not exist in target op";
1046 diag.attachNote(target->getLoc()) << "target op";
1047 results.set(getFirst().cast<OpResult>(), {});
1048 results.set(getSecond().cast<OpResult>(), {});
1049 return diag;
1050 }
1051
1052 rewriter.setInsertionPoint(linalgOp);
1053 std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp(
1054 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
1055 getDimension(), std::get<1>(pair));
1056 }
1057
1058 results.set(getFirst().cast<OpResult>(), first);
1059 results.set(getSecond().cast<OpResult>(), second);
1060 return DiagnosedSilenceableFailure::success();
1061}
1062
1063void SplitOp::getEffects(
1064 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1065 consumesHandle(getTarget(), effects);
1066 if (getDynamicSplitPoint())
1067 onlyReadsHandle(getDynamicSplitPoint(), effects);
1068 producesHandle(getResults(), effects);
1069 modifiesPayload(effects);
1070}
1071
1072ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) {
1073 OpAsmParser::UnresolvedOperand target, dynamicSplitPoint;
1074 IntegerAttr staticSplitPoint;
1075 auto pdlOperationType =
1076 pdl::OperationType::get(parser.getBuilder().getContext());
1077 if (parser.parseOperand(target) ||
1
Taking false branch
1078 parser.resolveOperand(target, pdlOperationType, result.operands) ||
1079 parser.parseKeyword("after"))
1080 return failure();
1081
1082 OptionalParseResult dynamicPointParseResult =
1083 parser.parseOptionalOperand(dynamicSplitPoint);
1084 if (!dynamicPointParseResult.has_value()) {
2
Assuming the condition is true
3
Taking true branch
1085 int64_t staticSplitPointValue;
4
'staticSplitPointValue' declared without an initial value
1086 if (failed(parser.parseInteger(staticSplitPointValue)))
5
Calling 'AsmParser::parseInteger'
13
Returning from 'AsmParser::parseInteger'
14
Taking false branch
1087 return failure();
1088
1089 staticSplitPoint =
1090 parser.getBuilder().getI64IntegerAttr(staticSplitPointValue);
15
1st function call argument is an uninitialized value
1091 } else {
1092 if (failed(*dynamicPointParseResult) ||
1093 parser.resolveOperand(dynamicSplitPoint, pdlOperationType,
1094 result.operands)) {
1095 return failure();
1096 }
1097
1098 staticSplitPoint =
1099 parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic);
1100 }
1101
1102 result.addAttribute(
1103 SplitOp::getStaticSplitPointAttrName(result.name).getValue(),
1104 staticSplitPoint);
1105 if (failed(parser.parseOptionalAttrDict(result.attributes)))
1106 return failure();
1107
1108 result.addTypes({pdlOperationType, pdlOperationType});
1109 return success();
1110}
1111
1112void SplitOp::print(OpAsmPrinter &printer) {
1113 printer << " " << getTarget() << " after ";
1114 int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint());
1115 if (staticSplitSize != ShapedType::kDynamic)
1116 printer << staticSplitSize;
1117 else
1118 printer << getDynamicSplitPoint();
1119 printer << " ";
1120 printer.printOptionalAttrDict(getOperation()->getAttrs(),
1121 {getStaticSplitPointAttrName()});
1122}
1123
1124LogicalResult SplitOp::verify() {
1125 if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^
1126 (getDynamicSplitPoint() == nullptr)) {
1127 return emitOpError() << "expects either a dynamic or a static split "
1128 "point to be provided";
1129 }
1130 return success();
1131}
1132
1133//===----------------------------------------------------------------------===//
1134// SplitReductionOp
1135//===----------------------------------------------------------------------===//
1136
1137void transform::SplitReductionOp::build(
1138 OpBuilder &builder, OperationState &result, Value target,
1139 int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel,
1140 bool useScalingAlgorithm, bool useAlloc) {
1141 MLIRContext *ctx = builder.getContext();
1142 result.addOperands(target);
1143 result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name),
1144 builder.getI64IntegerAttr(splitFactor));
1145 result.addAttribute(
1146 SplitReductionOp::getInsertSplitDimensionAttrName(result.name),
1147 builder.getI64IntegerAttr(insertSplitDimension));
1148 if (innerParallel) {
1149 result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name),
1150 builder.getUnitAttr());
1151 }
1152 if (useScalingAlgorithm) {
1153 result.addAttribute(
1154 SplitReductionOp::getUseScalingAlgorithmAttrName(result.name),
1155 builder.getUnitAttr());
1156 }
1157 if (useAlloc) {
1158 result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name),
1159 builder.getUnitAttr());
1160 }
1161 auto resultType = pdl::OperationType::get(ctx);
1162 result.addTypes({resultType, resultType, resultType, resultType});
1163}
1164
1165DiagnosedSilenceableFailure
1166transform::SplitReductionOp::applyToOne(linalg::LinalgOp target,
1167 SmallVectorImpl<Operation *> &results,
1168 transform::TransformState &state) {
1169 ControlSplitReductionFn splitFn = [&](LinalgOp) {
1170 return linalg::SplitReductionOptions{int64_t(getSplitFactor()),
1171 unsigned(getInsertSplitDimension()),
1172 bool(getInnerParallel())};
1173 };
1174 TrivialPatternRewriter rewriter(getContext());
1175 rewriter.setInsertionPoint(target);
1176 FailureOr<SplitReductionResult> splitResult =
1177 (getUseScalingAlgorithm())
1178 ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc())
1179 : splitReduction(rewriter, target, splitFn, getUseAlloc());
1180 if (failed(splitResult))
1181 return emitDefaultDefiniteFailure(target);
1182
1183 results.push_back(splitResult->initOrAlloc);
1184 results.push_back(splitResult->fillOp);
1185 results.push_back(splitResult->splitLinalgOp);
1186 results.push_back(splitResult->resultCombiningLinalgOp);
1187 return DiagnosedSilenceableFailure::success();
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// TileReductionUsingScfOp
1192//===----------------------------------------------------------------------===//
1193
1194void transform::TileReductionUsingScfOp::build(
1195 OpBuilder &builder, OperationState &result, Value target,
1196 ArrayRef<int64_t> staticTileSizes) {
1197 // Call the default builder.
1198 // This is future-proof re mixed static-dynamic and setting up the proper
1199 // operands segment sizes attributes for multiple variadic operands.
1200 // In the absence of this, horrible bugs ensue.
1201 // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
1202 MLIRContext *ctx = builder.getContext();
1203 auto opTy = pdl::OperationType::get(ctx);
1204 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
1205 build(builder, result,
1206 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
1207 /*target=*/target,
1208 /*tile_sizes=*/staticTileSizesAttr);
1209}
1210
1211DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne(
1212 linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
1213 transform::TransformState &state) {
1214 TrivialPatternRewriter rewriter(getContext());
1215 rewriter.setInsertionPoint(target);
1216 FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
1217 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
1218 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
1219
1220 if (failed(result))
1221 return emitDefaultSilenceableFailure(target);
1222 results.push_back(result->loops.front());
1223 results.push_back(result->initialOp);
1224 results.push_back(result->parallelTiledOp);
1225 results.push_back(result->mergeOp);
1226 return DiagnosedSilenceableFailure::success();
1227}
1228
1229//===----------------------------------------------------------------------===//
1230// TileReductionUsingForeachThreadOp
1231//===----------------------------------------------------------------------===//
1232
1233void transform::TileReductionUsingForeachThreadOp::build(
1234 OpBuilder &builder, OperationState &result, Value target,
1235 ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes,
1236 ArrayAttr mapping) {
1237 // Call the default builder.
1238 // This is future-proof re mixed static-dynamic and setting up the proper
1239 // operands segment sizes attributes for multiple variadic operands.
1240 // In the absence of this, horrible bugs ensue.
1241 // TODO: support mixed static-dynamic (see TileToForeachThreadOp).
1242 MLIRContext *ctx = builder.getContext();
1243 auto opTy = pdl::OperationType::get(ctx);
1244 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
1245 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
1246 build(builder, result,
1247 /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy},
1248 /*target=*/target,
1249 /*num_threads=*/staticNumThreadsAttr,
1250 /*tile_sizes=*/staticTileSizesAttr,
1251 /*mapping=*/mapping);
1252}
1253
1254DiagnosedSilenceableFailure
1255transform::TileReductionUsingForeachThreadOp::applyToOne(
1256 linalg::LinalgOp target, SmallVectorImpl<Operation *> &results,
1257 transform::TransformState &state) {
1258 TrivialPatternRewriter rewriter(getContext());
1259 rewriter.setInsertionPoint(target);
1260 SmallVector<OpFoldResult> numThreads =
1261 getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
1262 SmallVector<OpFoldResult> tileSizes =
1263 getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()));
1264 FailureOr<linalg::ForeachThreadReductionTilingResult> result =
1265 linalg::tileReductionUsingForeachThread(
1266 rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
1267 numThreads, tileSizes, getMapping());
1268
1269 if (failed(result)) {
1270 results.assign(3, nullptr);
1271 Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark);
1272 diag << "could not tile reduction in target.";
1273 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
1274 }
1275 results.push_back(result->loops);
1276 results.push_back(result->initialOp);
1277 results.push_back(result->parallelTiledOp);
1278 results.push_back(result->mergeOp);
1279 return DiagnosedSilenceableFailure::success();
1280}
1281
1282//===----------------------------------------------------------------------===//
1283// TileOp
1284//===----------------------------------------------------------------------===//
1285void transform::TileOp::build(OpBuilder &builder, OperationState &result,
1286 Value target, ArrayRef<int64_t> staticTileSizes,
1287 ArrayRef<int64_t> interchange) {
1288 return build(builder, result,
1289 /*target=*/target,
1290 /*mixedTileSizes=*/
1291 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
1292 interchange);
1293}
1294
1295void transform::TileOp::build(OpBuilder &builder, OperationState &result,
1296 Value target,
1297 ArrayRef<OpFoldResult> mixedTileSizes,
1298 ArrayRef<int64_t> interchange) {
1299 SmallVector<int64_t> staticTileSizes;
1300 SmallVector<Value> dynamicTileSizes;
1301 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
1302 // Call the default builder which sets up the proper operands segment sizes
1303 // attributes for multiple variadic operands. In the absence of this, horrible
1304 // bugs ensue.
1305 MLIRContext *ctx = builder.getContext();
1306 auto operationType = pdl::OperationType::get(ctx);
1307 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
1308 build(builder, result,
1309 /*resultTypes=*/TypeRange{operationType, operationType},
1310 /*target=*/target,
1311 /*dynamic_sizes=*/dynamicTileSizes,
1312 /*static_sizes=*/staticTileSizesAttr,
1313 /*interchange=*/builder.getDenseI64ArrayAttr(interchange));
1314}
1315
1316DiagnosedSilenceableFailure
1317transform::TileOp::apply(TransformResults &transformResults,
1318 TransformState &state) {
1319 ArrayRef<int64_t> tileSizes = getStaticSizes();
1320
1321 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1322 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
1323 dynamicSizeProducers.reserve(getDynamicSizes().size());
1324 for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
1325 dynamicSizeProducers.push_back(
1326 state.getPayloadOps(dynamicSizeProducerHandle));
1327
1328 if (dynamicSizeProducers.back().size() != targets.size()) {
1329 DiagnosedSilenceableFailure diag =
1330 emitSilenceableError()
1331 << "expected as many dynamic size-producing operations ("
1332 << dynamicSizeProducers.back().size() << ") as target ops ("
1333 << targets.size() << ")";
1334 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1335 return diag;
1336 }
1337
1338 for (Operation *op : dynamicSizeProducers.back()) {
1339 if (op->getNumResults() == 1 &&
1340 op->getResult(0).getType().isa<IndexType>())
1341 continue;
1342 DiagnosedSilenceableFailure diag =
1343 emitSilenceableError() << "expected sizes to be produced by ops "
1344 "with a single index-type result";
1345 diag.attachNote(op->getLoc()) << "size producer op";
1346 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1347 return diag;
1348 }
1349 }
1350
1351 SmallVector<Operation *> tiled;
1352 SmallVector<SmallVector<Operation *, 4>, 4> loops;
1353 loops.resize(getLoops().size());
1354 for (auto &en : llvm::enumerate(targets)) {
1355 auto linalgOp = dyn_cast<LinalgOp>(en.value());
1356 if (!linalgOp) {
1357 DiagnosedSilenceableFailure diag = emitSilenceableError()
1358 << "only linalg ops are supported";
1359 diag.attachNote(en.value()->getLoc()) << "target op";
1360 return diag;
1361 }
1362
1363 scf::SCFTilingOptions tilingOptions;
1364 unsigned index = en.index();
1365 if (!tileSizes.empty()) {
1366 tilingOptions.setTileSizeComputationFunction(
1367 [&, index](OpBuilder &b, Operation *) {
1368 SmallVector<Value, 4> sizes;
1369 sizes.reserve(tileSizes.size());
1370 unsigned dynamicIdx = 0;
1371 for (OpFoldResult ofr : getMixedSizes()) {
1372 if (auto attr = ofr.dyn_cast<Attribute>()) {
1373 sizes.push_back(b.create<arith::ConstantIndexOp>(
1374 getLoc(), attr.cast<IntegerAttr>().getInt()));
1375 } else {
1376 sizes.push_back(
1377 dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
1378 }
1379 }
1380 return sizes;
1381 });
1382 }
1383
1384 tilingOptions.setInterchange(getInterchange());
1385 TrivialPatternRewriter rewriter(linalgOp.getContext());
1386 FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp(
1387 rewriter, cast<TilingInterface>(linalgOp.getOperation()),
1388 tilingOptions);
1389 if (failed(maybeTilingResult))
1390 return DiagnosedSilenceableFailure::definiteFailure();
1391
1392 if (linalgOp.hasBufferSemantics())
1393 rewriter.eraseOp(linalgOp);
1394 else
1395 rewriter.replaceOp(linalgOp,
1396 maybeTilingResult->loops.front()->getResults());
1397
1398 tiled.append(maybeTilingResult->tiledOps);
1399 for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops))
1400 loops[en2.index()].push_back(en2.value());
1401 }
1402
1403 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
1404 for (const auto &en : llvm::enumerate(loops))
1405 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
1406
1407 return DiagnosedSilenceableFailure::success();
1408}
1409
1410SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() {
1411 ValueRange dynamic = getDynamicSizes();
1412 ArrayRef<int64_t> tileSizes = getStaticSizes();
1413 SmallVector<OpFoldResult> results;
1414 results.reserve(tileSizes.size());
1415 unsigned dynamicPos = 0;
1416 Builder builder(getContext());
1417 for (int64_t size : tileSizes) {
1418 if (size == ShapedType::kDynamic) {
1419 results.push_back(dynamic[dynamicPos++]);
1420 } else {
1421 results.push_back(builder.getIndexAttr(size));
1422 }
1423 }
1424 return results;
1425}
1426
1427// We want to parse `DenseI64ArrayAttr` using the short form without the
1428// `array` prefix to be consistent in the IR with `parseDynamicIndexList`.
1429ParseResult parseOptionalInterchange(OpAsmParser &parser,
1430 OperationState &result) {
1431 if (succeeded(parser.parseOptionalLBrace())) {
1432 if (failed(parser.parseKeyword("interchange")))
1433 return parser.emitError(parser.getNameLoc()) << "expect `interchange`";
1434 if (failed(parser.parseEqual()))
1435 return parser.emitError(parser.getNameLoc()) << "expect `=`";
1436 result.addAttribute("interchange",
1437 DenseI64ArrayAttr::parse(parser, Type{}));
1438 if (failed(parser.parseRBrace()))
1439 return parser.emitError(parser.getNameLoc()) << "expect `}`";
1440 }
1441 return success();
1442}
1443
1444void printOptionalInterchange(OpAsmPrinter &p,
1445 ArrayRef<int64_t> interchangeVals) {
1446 if (!interchangeVals.empty()) {
1447 p << " {interchange = [";
1448 llvm::interleaveComma(interchangeVals, p,
1449 [&](int64_t integer) { p << integer; });
1450 p << "]}";
1451 }
1452}
1453
1454ParseResult transform::TileOp::parse(OpAsmParser &parser,
1455 OperationState &result) {
1456 OpAsmParser::UnresolvedOperand target;
1457 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
1458 DenseI64ArrayAttr staticSizes;
1459 auto pdlOperationType = pdl::OperationType::get(parser.getContext());
1460 if (parser.parseOperand(target) ||
1461 parser.resolveOperand(target, pdlOperationType, result.operands) ||
1462 parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
1463 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
1464 return ParseResult::failure();
1465
1466 // Parse optional interchange.
1467 if (failed(parseOptionalInterchange(parser, result)))
1468 return ParseResult::failure();
1469 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
1470 size_t numExpectedLoops =
1471 staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
1472 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1473 return success();
1474}
1475
1476void TileOp::print(OpAsmPrinter &p) {
1477 p << ' ' << getTarget();
1478 printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
1479 printOptionalInterchange(p, getInterchange());
1480}
1481
1482void transform::TileOp::getEffects(
1483 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1484 consumesHandle(getTarget(), effects);
1485 onlyReadsHandle(getDynamicSizes(), effects);
1486 producesHandle(getTiledLinalgOp(), effects);
1487 producesHandle(getLoops(), effects);
1488 modifiesPayload(effects);
1489}
1490
1491//===----------------------------------------------------------------------===//
1492// TileToForeachThreadOp
1493//===----------------------------------------------------------------------===//
1494
1495void transform::TileToForeachThreadOp::build(OpBuilder &builder,
1496 OperationState &result,
1497 Value target,
1498 ArrayRef<int64_t> staticTileSizes,
1499 transform::TileSizesSpec,
1500 ArrayAttr mapping) {
1501 return build(builder, result,
1502 /*target=*/target,
1503 /*mixedTileSizes=*/
1504 getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)),
1505 /*_=*/TileSizesSpec(),
1506 /*mapping=*/mapping);
1507}
1508
1509void transform::TileToForeachThreadOp::build(
1510 OpBuilder &builder, OperationState &result, Value target,
1511 ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec,
1512 ArrayAttr mapping) {
1513 SmallVector<int64_t> staticTileSizes;
1514 SmallVector<Value> dynamicTileSizes;
1515 dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes);
1516 // Call the default builder which sets up the proper operands segment sizes
1517 // attributes for multiple variadic operands. In the absence of this, horrible
1518 // bugs ensue.
1519 MLIRContext *ctx = builder.getContext();
1520 auto operationType = pdl::OperationType::get(ctx);
1521 auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes);
1522 build(builder, result,
1523 /*resultTypes=*/TypeRange{operationType, operationType},
1524 /*target=*/target,
1525 /*num_threads=*/ValueRange{},
1526 /*tile_sizes=*/dynamicTileSizes,
1527 /*packed_num_threads=*/Value(),
1528 /*packed_tile_sizes=*/Value(),
1529 /*static_num_threads=*/builder.getDenseI64ArrayAttr({}),
1530 /*static_tile_sizes=*/staticTileSizesAttr,
1531 /*mapping=*/mapping);
1532}
1533
1534void transform::TileToForeachThreadOp::build(OpBuilder &builder,
1535 OperationState &result,
1536 Value target,
1537 ArrayRef<int64_t> staticNumThreads,
1538 transform::NumThreadsSpec,
1539 ArrayAttr mapping) {
1540 return build(builder, result, target,
1541 getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)),
1542 NumThreadsSpec(), mapping);
1543}
1544
1545void transform::TileToForeachThreadOp::build(
1546 OpBuilder &builder, OperationState &result, Value target,
1547 ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec,
1548 ArrayAttr mapping) {
1549 SmallVector<int64_t> staticNumThreads;
1550 SmallVector<Value> dynamicNumThreads;
1551 dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads,
1552 staticNumThreads);
1553 // Call the default builder which sets up the proper operands segment sizes
1554 // attributes for multiple variadic operands. In the absence of this, horrible
1555 // bugs ensue.
1556 MLIRContext *ctx = builder.getContext();
1557 auto operationType = pdl::OperationType::get(ctx);
1558 auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads);
1559 build(builder, result,
1560 /*resultTypes=*/TypeRange{operationType, operationType},
1561 /*target=*/target,
1562 /*num_threads=*/dynamicNumThreads,
1563 /*tile_sizes=*/ValueRange{},
1564 /*packed_num_threads=*/Value(),
1565 /*packed_tile_sizes=*/Value(),
1566 /*static_num_threads=*/staticNumThreadsAttr,
1567 /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}),
1568 /*mapping=*/mapping);
1569}
1570
1571/// Assuming that `ofr` is an index attr or a transform dialect handle mapped
1572/// to exactly one op with one index result, return that value.
1573static DiagnosedSilenceableFailure unpackPDLOperations(
1574 transform::TransformState &state, TransformOpInterface transformOp,
1575 SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) {
1576 for (OpFoldResult ofr : ofrs) {
1577 if (ofr.is<Attribute>()) {
1578 if (!ofr.get<Attribute>().isa<IntegerAttr>())
1579 return transformOp.emitDefiniteFailure() << "expected IntegerAttr";
1580 result.push_back(ofr);
1581 continue;
1582 }
1583 ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>());
1584 if (payloadOps.size() != 1) {
1585 DiagnosedSilenceableFailure diag =
1586 transformOp.emitSilenceableError()
1587 << "handle must be mapped to exactly one payload op";
1588 diag.attachNote(ofr.get<Value>().getLoc())
1589 << "mapped to " << payloadOps.size() << " payload ops";
1590 return diag;
1591 }
1592
1593 Operation *op = payloadOps[0];
1594 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
1595 DiagnosedSilenceableFailure diag =
1596 transformOp.emitSilenceableError()
1597 << "payload op must have exactly 1 index result";
1598 diag.attachNote(op->getLoc())
1599 << "has " << op->getNumResults() << " results";
1600 return diag;
1601 }
1602 result.push_back(op->getResult(0));
1603 }
1604
1605 return DiagnosedSilenceableFailure::success();
1606}
1607
1608// Given a list of OpFoldResults that are either index attrs or op
1609// handles, return a list of OpFoldResults where all op handles are
1610// replaced with the first (and only) OpResult of that payload op. (There
1611// must be exactly one mapped payload op and it must have exactly one
1612// index result.)
1613static DiagnosedSilenceableFailure
1614unpackPDLOperations(transform::TransformState &state,
1615 TransformOpInterface transformOp,
1616 SmallVector<OpFoldResult> &result, Value packedHandle) {
1617 ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle);
1618 for (Operation *op : payloadOps) {
1619 if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) {
1620 DiagnosedSilenceableFailure diag =
1621 transformOp.emitSilenceableError()
1622 << "payload op must have exactly 1 index result";
1623 diag.attachNote(op->getLoc())
1624 << "has " << op->getNumResults() << " results";
1625 return diag;
1626 }
1627 result.push_back(op->getResult(0));
1628 }
1629
1630 return DiagnosedSilenceableFailure::success();
1631}
1632
1633DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl(
1634 RewriterBase &rewriter, transform::TransformState &state,
1635 TransformOpInterface transformOp, ArrayRef<Operation *> targets,
1636 ArrayRef<OpFoldResult> mixedNumThreads,
1637 ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping,
1638 SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) {
1639 if (targets.empty())
1640 return DiagnosedSilenceableFailure::success();
1641
1642 // Transform all targets one by one.
1643 for (Operation *target : targets) {
1644 auto tilableOp = dyn_cast<TilingInterface>(target);
1645 if (!tilableOp) {
1646 DiagnosedSilenceableFailure diag =
1647 transformOp.emitSilenceableError()
1648 << "only TilingInterface ops are supported";
1649 diag.attachNote(target->getLoc()) << "target op";
1650 return diag;
1651 }
1652 rewriter.setInsertionPoint(tilableOp);
1653 FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure();
1654 if (!mixedNumThreads.empty()) {
1655 tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp,
1656 mixedNumThreads, mapping);
1657 } else {
1658 tilingResult = linalg::tileToForeachThreadOpUsingTileSizes(
1659 rewriter, tilableOp, mixedTileSizes, mapping);
1660 }
1661
1662 if (failed(tilingResult))
1663 return transformOp.emitDefaultSilenceableFailure(tilableOp);
1664 rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults());
1665
1666 tileOps.push_back(tilingResult->tileOp);
1667 tiledOps.push_back(tilingResult->tiledOp);
1668 }
1669 return DiagnosedSilenceableFailure::success();
1670}
1671
1672DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply(
1673 transform::TransformResults &transformResults,
1674 transform::TransformState &state) {
1675 IRRewriter rewriter(getContext());
1676 auto transformOp = cast<TransformOpInterface>(getOperation());
1677 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1678
1679 // Result payload ops.
1680 SmallVector<Operation *> tileOps;
1681 SmallVector<Operation *> tiledOps;
1682
1683 // Unpack handles.
1684 SmallVector<OpFoldResult> mixedNumThreads;
1685 DiagnosedSilenceableFailure status =
1686 getPackedNumThreads()
1687 ? unpackPDLOperations(state, transformOp, mixedNumThreads,
1688 getPackedNumThreads())
1689 : unpackPDLOperations(state, transformOp, mixedNumThreads,
1690 getMixedNumThreads());
1691 if (!status.succeeded())
1692 return status;
1693 SmallVector<OpFoldResult> mixedTileSizes;
1694 status = getPackedTileSizes()
1695 ? unpackPDLOperations(state, transformOp, mixedTileSizes,
1696 getPackedTileSizes())
1697 : unpackPDLOperations(state, transformOp, mixedTileSizes,
1698 getMixedTileSizes());
1699 if (!status.succeeded())
1700 return status;
1701
1702 DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl(
1703 rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes,
1704 getMapping(), tileOps, tiledOps);
1705
1706 if (!diag.succeeded()) {
1707 transformResults.set(getForeachThreadOp().cast<OpResult>(), {});
1708 transformResults.set(getTiledOp().cast<OpResult>(), {});
1709 return diag;
1710 }
1711
1712 transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps);
1713 transformResults.set(getTiledOp().cast<OpResult>(), tiledOps);
1714
1715 return DiagnosedSilenceableFailure::success();
1716}
1717
1718void transform::TileToForeachThreadOp::getEffects(
1719 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1720 consumesHandle(getTarget(), effects);
1721 onlyReadsHandle(getTileSizes(), effects);
1722 onlyReadsHandle(getNumThreads(), effects);
1723 producesHandle(getResults(), effects);
1724}
1725
1726SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() {
1727 Builder b(getContext());
1728 return getMixedValues(getStaticNumThreads(), getNumThreads(), b);
1729}
1730
1731SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() {
1732 Builder b(getContext());
1733 return getMixedValues(getStaticTileSizes(), getTileSizes(), b);
1734}
1735
1736LogicalResult TileToForeachThreadOp::verify() {
1737 int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) +
1738 static_cast<int>(getPackedNumThreads() != Value());
1739 if (numThreadsSpec > 1)
1740 return emitOpError(
1741 "num_threads and packed_num_threads are mutually exclusive");
1742 int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) +
1743 static_cast<int>(getPackedTileSizes() != Value());
1744 if (tileSizesSpec > 1)
1745 return emitOpError(
1746 "tile_sizes and packed_tile_sizes are mutually exclusive");
1747 if (numThreadsSpec == 0 && tileSizesSpec == 0)
1748 return emitOpError(
1749 "either (packed_)num_threads or (packed_)tile_sizes must be specified");
1750 return success();
1751}
1752
1753//===----------------------------------------------------------------------===//
1754// TileToScfForOp
1755//===----------------------------------------------------------------------===//
1756
1757DiagnosedSilenceableFailure
1758transform::TileToScfForOp::apply(TransformResults &transformResults,
1759 TransformState &state) {
1760 ArrayRef<int64_t> tileSizes = getStaticSizes();
1761
1762 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1763 SmallVector<ArrayRef<Operation *>> dynamicSizeProducers;
1764 dynamicSizeProducers.reserve(getDynamicSizes().size());
1765 for (Value dynamicSizeProducerHandle : getDynamicSizes()) {
1766 dynamicSizeProducers.push_back(
1767 state.getPayloadOps(dynamicSizeProducerHandle));
1768
1769 if (dynamicSizeProducers.back().size() != targets.size()) {
1770 DiagnosedSilenceableFailure diag =
1771 emitSilenceableError()
1772 << "expected as many dynamic size-producing operations ("
1773 << dynamicSizeProducers.back().size() << ") as target ops ("
1774 << targets.size() << ")";
1775 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1776 return diag;
1777 }
1778
1779 for (Operation *op : dynamicSizeProducers.back()) {
1780 if (op->getNumResults() == 1 &&
1781 op->getResult(0).getType().isa<IndexType>())
1782 continue;
1783 DiagnosedSilenceableFailure diag =
1784 emitSilenceableError() << "expected sizes to be produced by ops "
1785 "with a single index-type result";
1786 diag.attachNote(op->getLoc()) << "size producer op";
1787 diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle";
1788 return diag;
1789 }
1790 }
1791
1792 SmallVector<Operation *> tiled;
1793 SmallVector<SmallVector<Operation *, 4>, 4> loops;
1794 loops.resize(getLoops().size());
1795 for (auto &en : llvm::enumerate(targets)) {
1796 auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value());
1797 if (!tilingInterfaceOp) {
1798 DiagnosedSilenceableFailure diag =
1799 emitSilenceableError() << "only TilingInterface ops are supported";
1800 diag.attachNote(en.value()->getLoc()) << "target op";
1801 return diag;
1802 }
1803
1804 scf::SCFTilingOptions tilingOptions;
1805 unsigned index = en.index();
1806 if (!tileSizes.empty()) {
1807 tilingOptions.setTileSizeComputationFunction(
1808 [&, index](OpBuilder &b, Operation *) {
1809 SmallVector<Value, 4> sizes;
1810 sizes.reserve(tileSizes.size());
1811 unsigned dynamicIdx = 0;
1812 for (OpFoldResult ofr : getMixedSizes()) {
1813 if (auto attr = ofr.dyn_cast<Attribute>()) {
1814 sizes.push_back(b.create<arith::ConstantIndexOp>(
1815 getLoc(), attr.cast<IntegerAttr>().getInt()));
1816 } else {
1817 sizes.push_back(
1818 dynamicSizeProducers[dynamicIdx++][index]->getResult(0));
1819 }
1820 }
1821 return sizes;
1822 });
1823 }
1824
1825 tilingOptions.setInterchange(getInterchange());
1826 TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext());
1827 FailureOr<scf::SCFTilingResult> tilingResult =
1828 tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions);
1829 if (failed(tilingResult))
1830 return DiagnosedSilenceableFailure::definiteFailure();
1831
1832 rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements);
1833
1834 tiled.append(tilingResult->tiledOps);
1835 for (const auto &en2 : llvm::enumerate(tilingResult->loops))
1836 loops[en2.index()].push_back(en2.value());
1837 }
1838
1839 transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled);
1840 for (const auto &en : llvm::enumerate(loops))
1841 transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value());
1842
1843 return DiagnosedSilenceableFailure::success();
1844}
1845
1846SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() {
1847 ValueRange dynamic = getDynamicSizes();
1848 ArrayRef<int64_t> tileSizes = getStaticSizes();
1849 SmallVector<OpFoldResult> results;
1850 results.reserve(tileSizes.size());
1851 unsigned dynamicPos = 0;
1852 Builder builder(getContext());
1853 for (int64_t size : tileSizes) {
1854 if (size == ShapedType::kDynamic) {
1855 results.push_back(dynamic[dynamicPos++]);
1856 } else {
1857 results.push_back(builder.getIndexAttr(size));
1858 }
1859 }
1860 return results;
1861}
1862
1863ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser,
1864 OperationState &result) {
1865 OpAsmParser::UnresolvedOperand target;
1866 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes;
1867 DenseI64ArrayAttr staticSizes;
1868 auto pdlOperationType = pdl::OperationType::get(parser.getContext());
1869 if (parser.parseOperand(target) ||
1870 parser.resolveOperand(target, pdlOperationType, result.operands) ||
1871 parseDynamicIndexList(parser, dynamicSizes, staticSizes) ||
1872 parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands))
1873 return ParseResult::failure();
1874
1875 // Parse optional interchange.
1876 if (failed(parseOptionalInterchange(parser, result)))
1877 return ParseResult::failure();
1878 result.addAttribute(getStaticSizesAttrName(result.name), staticSizes);
1879 size_t numExpectedLoops =
1880 staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0);
1881 result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType));
1882 return success();
1883}
1884
1885void TileToScfForOp::print(OpAsmPrinter &p) {
1886 p << ' ' << getTarget();
1887 printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes());
1888 printOptionalInterchange(p, getInterchange());
1889}
1890
1891void transform::TileToScfForOp::getEffects(
1892 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
1893 consumesHandle(getTarget(), effects);
1894 onlyReadsHandle(getDynamicSizes(), effects);
1895 producesHandle(getTiledLinalgOp(), effects);
1896 producesHandle(getLoops(), effects);
1897 modifiesPayload(effects);
1898}
1899
1900//===----------------------------------------------------------------------===//
1901// VectorizeOp
1902//===----------------------------------------------------------------------===//
1903
1904void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result,
1905 Value target, bool vectorizePadding,
1906 bool vectorizeExtract) {
1907 result.addOperands(target);
1908 if (vectorizePadding) {
1909 result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name),
1910 builder.getUnitAttr());
1911 }
1912 if (vectorizeExtract) {
1913 result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name),
1914 builder.getUnitAttr());
1915 }
1916 result.addTypes(pdl::OperationType::get(builder.getContext()));
1917}
1918
1919namespace {
1920/// This is an helper only to call vectorize via a pattern inside of
1921/// VectorizeOp::applyToOne.
1922struct VectorizationPattern : public RewritePattern {
1923 explicit VectorizationPattern(MLIRContext *context,
1924 bool vectorizeExtract = false)
1925 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context),
1926 vectorizeNDExtract(vectorizeExtract) {}
1927 LogicalResult matchAndRewrite(Operation *op,
1928 PatternRewriter &rewriter) const override {
1929 LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
1930 if (!linalgOp)
1931 return rewriter.notifyMatchFailure(op, "expected Linalg Op");
1932 return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
1933 vectorizeNDExtract);
1934 }
1935
1936private:
1937 /// Controls whether to vectorize `tensor.extract` when the input tensor is
1938 /// rank >= 2.
1939 bool vectorizeNDExtract = false;
1940};
1941} // namespace
1942
1943DiagnosedSilenceableFailure
1944transform::VectorizeOp::applyToOne(Operation *target,
1945 SmallVectorImpl<Operation *> &results,
1946 transform::TransformState &state) {
1947 if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
1948 auto diag = this->emitOpError("requires isolated-from-above targets");
1949 diag.attachNote(target->getLoc()) << "non-isolated target";
1950 return DiagnosedSilenceableFailure::definiteFailure();
1951 }
1952
1953 MLIRContext *ctx = getContext();
1954 RewritePatternSet patterns(ctx);
1955 patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract());
1956
1957 if (!getDisableTransferPermutationMapLoweringPatterns())
1958 vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
1959
1960 if (!getDisableMultiReductionToContractPatterns())
1961 vector::populateVectorReductionToContractPatterns(patterns);
1962
1963 patterns.add<linalg::LinalgCopyVTRForwardingPattern,
1964 linalg::LinalgCopyVTWForwardingPattern>(ctx,
1965 /*benefit=*/2);
1966 vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
1967 vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
1968
1969 patterns.add<CopyVectorizationPattern>(ctx);
1970
1971 if (getVectorizePadding())
1972 linalg::populatePadOpVectorizationPatterns(patterns);
1973
1974 if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns))))
1975 return emitDefaultDefiniteFailure(target);
1976
1977 results.push_back(target);
1978 return DiagnosedSilenceableFailure::success();
1979}
1980
1981//===----------------------------------------------------------------------===//
1982// MaskedVectorizeOp
1983//===----------------------------------------------------------------------===//
1984
1985DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
1986 mlir::transform::TransformResults &transformResults,
1987 mlir::transform::TransformState &state) {
1988 IRRewriter rewriter(getContext());
1989 ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
1990 if (targets.empty())
1991 return DiagnosedSilenceableFailure::success();
1992
1993 SmallVector<int64_t> vectorSizes;
1994 for (OpFoldResult sz : getMixedVectorSizes()) {
1995 if (sz.is<Attribute>()) {
1996 auto attr = sz.get<Attribute>();
1997 vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
1998 continue;
1999 }
2000
2001 ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
2002 if (szPayloads.size() != 1) {
2003 auto diag = this->emitOpError(
2004 "requires vector size handle that is mapped to 1 payload op");
2005 diag.attachNote(sz.get<Value>().getLoc())
2006 << "mapped to " << szPayloads.size() << " payload ops";
2007 return DiagnosedSilenceableFailure::definiteFailure();
2008 }
2009
2010 Operation *szPayloadOp = szPayloads[0];
2011 if (szPayloadOp->getNumResults() != 1 ||
2012 !szPayloadOp->getResult(0).getType().isIndex()) {
2013 auto diag = this->emitOpError(
2014 "requires vector size payload op with 1 index result");
2015 diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
2016 return DiagnosedSilenceableFailure::definiteFailure();
2017 }
2018
2019 IntegerAttr attr;
2020 if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
2021 auto diag = this->emitOpError("requires constant vector size");
2022 diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
2023 return DiagnosedSilenceableFailure::definiteFailure();
2024 }
2025
2026 vectorSizes.push_back(attr.getInt());
2027 }
2028
2029 // TODO: Check that the correct number of vectorSizes was provided.
2030
2031 for (Operation *target : targets) {
2032 auto linalgOp = dyn_cast<LinalgOp>(target);
2033 if (!linalgOp) {
2034 Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
2035 diag << "cannot vectorize non-Linalg op";
2036 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
2037 }
2038
2039 if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
2040 Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
2041 diag << "failed to vectorize op";
2042 return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
2043 }
2044 }
2045
2046 return DiagnosedSilenceableFailure::success();
2047}
2048
2049void transform::MaskedVectorizeOp::getEffects(
2050 SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
2051 consumesHandle(getTarget(), effects);
2052 onlyReadsHandle(getVectorSizes(), effects);
2053}
2054
2055SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
2056 OpBuilder b(getContext());
2057 return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
2058}
2059
2060//===----------------------------------------------------------------------===//
2061// Transform op registration
2062//===----------------------------------------------------------------------===//
2063
2064namespace {
2065/// Registers new ops and declares PDL as dependent dialect since the
2066/// additional ops are using PDL types for operands and results.
2067class LinalgTransformDialectExtension
2068 : public transform::TransformDialectExtension<
2069 LinalgTransformDialectExtension> {
2070public:
2071 using Base::Base;
2072
2073 void init() {
2074 declareDependentDialect<pdl::PDLDialect>();
2075 declareDependentDialect<LinalgDialect>();
2076 declareGeneratedDialect<AffineDialect>();
2077 declareGeneratedDialect<arith::ArithDialect>();
2078 declareGeneratedDialect<scf::SCFDialect>();
2079 declareGeneratedDialect<vector::VectorDialect>();
2080 declareGeneratedDialect<gpu::GPUDialect>();
2081
2082 registerTransformOps<
2083#define GET_OP_LIST
2084#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
2085 >();
2086 }
2087};
2088} // namespace
2089
2090#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
2091
2092#define GET_OP_CLASSES
2093#include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc"
2094
2095void mlir::linalg::registerTransformDialectExtension(
2096 DialectRegistry &registry) {
2097 registry.addExtensions<LinalgTransformDialectExtension>();
2098}

/build/source/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
// Forward declarations of classes referenced by this header.
class AsmParsedResourceEntry;
class AsmResourceBuilder;
class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 virtual void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
218 AsmPrinter() = default;
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a loc(...) specifier if printing debug info is enabled.
331 virtual void printOptionalLocationSpecifier(Location loc) = 0;
332
333 /// Print a newline and indent the printer to the start of the current
334 /// operation.
335 virtual void printNewline() = 0;
336
337 /// Increase indentation.
338 virtual void increaseIndent() = 0;
339
340 /// Decrease indentation.
341 virtual void decreaseIndent() = 0;
342
343 /// Print a block argument in the usual format of:
344 /// %ssaName : type {attr1=42} loc("here")
345 /// where location printing is controlled by the standard internal option.
346 /// You may pass omitType=true to not print a type, and pass an empty
347 /// attribute list if you don't care for attributes.
348 virtual void printRegionArgument(BlockArgument arg,
349 ArrayRef<NamedAttribute> argAttrs = {},
350 bool omitType = false) = 0;
351
352 /// Print implementations for various things an operation contains.
353 virtual void printOperand(Value value) = 0;
354 virtual void printOperand(Value value, raw_ostream &os) = 0;
355
356 /// Print a comma separated list of operands.
357 template <typename ContainerType>
358 void printOperands(const ContainerType &container) {
359 printOperands(container.begin(), container.end());
360 }
361
362 /// Print a comma separated list of operands.
363 template <typename IteratorType>
364 void printOperands(IteratorType it, IteratorType end) {
365 llvm::interleaveComma(llvm::make_range(it, end), getStream(),
366 [this](Value value) { printOperand(value); });
367 }
368
369 /// Print the given successor.
370 virtual void printSuccessor(Block *successor) = 0;
371
372 /// Print the successor and its operands.
373 virtual void printSuccessorAndUseList(Block *successor,
374 ValueRange succOperands) = 0;
375
376 /// If the specified operation has attributes, print out an attribute
377 /// dictionary with their values. elidedAttrs allows the client to ignore
378 /// specific well known attributes, commonly used if the attribute value is
379 /// printed some other way (like as a fixed operand).
380 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
381 ArrayRef<StringRef> elidedAttrs = {}) = 0;
382
383 /// If the specified operation has attributes, print out an attribute
384 /// dictionary prefixed with 'attributes'.
385 virtual void
386 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
387 ArrayRef<StringRef> elidedAttrs = {}) = 0;
388
389 /// Prints the entire operation with the custom assembly form, if available,
390 /// or the generic assembly form, otherwise.
391 virtual void printCustomOrGenericOp(Operation *op) = 0;
392
393 /// Print the entire operation with the default generic assembly form.
394 /// If `printOpName` is true, then the operation name is printed (the default)
395 /// otherwise it is omitted and the print will start with the operand list.
396 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
397
398 /// Prints a region.
399 /// If 'printEntryBlockArgs' is false, the arguments of the
400 /// block are not printed. If 'printBlockTerminator' is false, the terminator
401 /// operation of the block is not printed. If printEmptyBlock is true, then
402 /// the block header is printed even if the block is empty.
403 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
404 bool printBlockTerminators = true,
405 bool printEmptyBlock = false) = 0;
406
407 /// Renumber the arguments for the specified region to the same names as the
408 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
409 /// operations. If any entry in namesToUse is null, the corresponding
410 /// argument name is left alone.
411 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
412
413 /// Prints an affine map of SSA ids, where SSA id names are used in place
414 /// of dims/symbols.
415 /// Operand values must come from single-result sources, and be valid
416 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
417 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
418 ValueRange operands) = 0;
419
420 /// Prints an affine expression of SSA ids with SSA id names used instead of
421 /// dims and symbols.
422 /// Operand values must come from single-result sources, and be valid
423 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
424 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
425 ValueRange symOperands) = 0;
426
427 /// Print the complete type of an operation in functional form.
428 void printFunctionalType(Operation *op);
429 using AsmPrinter::printFunctionalType;
430};
431
432// Make the implementations convenient to use.
433inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
434 p.printOperand(value);
435 return p;
436}
437
438template <typename T,
439 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
440 !std::is_convertible<T &, Value &>::value,
441 T> * = nullptr>
442inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
443 p.printOperands(values);
444 return p;
445}
446
447inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
448 p.printSuccessor(value);
449 return p;
450}
451
452//===----------------------------------------------------------------------===//
453// AsmParser
454//===----------------------------------------------------------------------===//
455
456/// This base class exposes generic asm parser hooks, usable across the various
457/// derived parsers.
458class AsmParser {
459public:
460 AsmParser() = default;
461 virtual ~AsmParser();
462
463 MLIRContext *getContext() const;
464
465 /// Return the location of the original name token.
466 virtual SMLoc getNameLoc() const = 0;
467
468 //===--------------------------------------------------------------------===//
469 // Utilities
470 //===--------------------------------------------------------------------===//
471
472 /// Emit a diagnostic at the specified location and return failure.
473 virtual InFlightDiagnostic emitError(SMLoc loc,
474 const Twine &message = {}) = 0;
475
476 /// Return a builder which provides useful access to MLIRContext, global
477 /// objects like types and attributes.
478 virtual Builder &getBuilder() const = 0;
479
480 /// Get the location of the next token and store it into the argument. This
481 /// always succeeds.
482 virtual SMLoc getCurrentLocation() = 0;
483 ParseResult getCurrentLocation(SMLoc *loc) {
484 *loc = getCurrentLocation();
485 return success();
486 }
487
488 /// Re-encode the given source location as an MLIR location and return it.
489 /// Note: This method should only be used when a `Location` is necessary, as
490 /// the encoding process is not efficient.
491 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
492
493 //===--------------------------------------------------------------------===//
494 // Token Parsing
495 //===--------------------------------------------------------------------===//
496
497 /// Parse a '->' token.
498 virtual ParseResult parseArrow() = 0;
499
500 /// Parse a '->' token if present
501 virtual ParseResult parseOptionalArrow() = 0;
502
503 /// Parse a `{` token.
504 virtual ParseResult parseLBrace() = 0;
505
506 /// Parse a `{` token if present.
507 virtual ParseResult parseOptionalLBrace() = 0;
508
509 /// Parse a `}` token.
510 virtual ParseResult parseRBrace() = 0;
511
512 /// Parse a `}` token if present.
513 virtual ParseResult parseOptionalRBrace() = 0;
514
515 /// Parse a `:` token.
516 virtual ParseResult parseColon() = 0;
517
518 /// Parse a `:` token if present.
519 virtual ParseResult parseOptionalColon() = 0;
520
521 /// Parse a `,` token.
522 virtual ParseResult parseComma() = 0;
523
524 /// Parse a `,` token if present.
525 virtual ParseResult parseOptionalComma() = 0;
526
527 /// Parse a `=` token.
528 virtual ParseResult parseEqual() = 0;
529
530 /// Parse a `=` token if present.
531 virtual ParseResult parseOptionalEqual() = 0;
532
533 /// Parse a '<' token.
534 virtual ParseResult parseLess() = 0;
535
536 /// Parse a '<' token if present.
537 virtual ParseResult parseOptionalLess() = 0;
538
539 /// Parse a '>' token.
540 virtual ParseResult parseGreater() = 0;
541
542 /// Parse a '>' token if present.
543 virtual ParseResult parseOptionalGreater() = 0;
544
545 /// Parse a '?' token.
546 virtual ParseResult parseQuestion() = 0;
547
548 /// Parse a '?' token if present.
549 virtual ParseResult parseOptionalQuestion() = 0;
550
551 /// Parse a '+' token.
552 virtual ParseResult parsePlus() = 0;
553
554 /// Parse a '+' token if present.
555 virtual ParseResult parseOptionalPlus() = 0;
556
557 /// Parse a '*' token.
558 virtual ParseResult parseStar() = 0;
559
560 /// Parse a '*' token if present.
561 virtual ParseResult parseOptionalStar() = 0;
562
563 /// Parse a '|' token.
564 virtual ParseResult parseVerticalBar() = 0;
565
566 /// Parse a '|' token if present.
567 virtual ParseResult parseOptionalVerticalBar() = 0;
568
569 /// Parse a quoted string token.
570 ParseResult parseString(std::string *string) {
571 auto loc = getCurrentLocation();
572 if (parseOptionalString(string))
573 return emitError(loc, "expected string");
574 return success();
575 }
576
577 /// Parse a quoted string token if present.
578 virtual ParseResult parseOptionalString(std::string *string) = 0;
579
580 /// Parses a Base64 encoded string of bytes.
581 virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0;
582
583 /// Parse a `(` token.
584 virtual ParseResult parseLParen() = 0;
585
586 /// Parse a `(` token if present.
587 virtual ParseResult parseOptionalLParen() = 0;
588
589 /// Parse a `)` token.
590 virtual ParseResult parseRParen() = 0;
591
592 /// Parse a `)` token if present.
593 virtual ParseResult parseOptionalRParen() = 0;
594
595 /// Parse a `[` token.
596 virtual ParseResult parseLSquare() = 0;
597
598 /// Parse a `[` token if present.
599 virtual ParseResult parseOptionalLSquare() = 0;
600
601 /// Parse a `]` token.
602 virtual ParseResult parseRSquare() = 0;
603
604 /// Parse a `]` token if present.
605 virtual ParseResult parseOptionalRSquare() = 0;
606
607 /// Parse a `...` token.
608 virtual ParseResult parseEllipsis() = 0;
609
610 /// Parse a `...` token if present;
611 virtual ParseResult parseOptionalEllipsis() = 0;
612
613 /// Parse a floating point value from the stream.
614 virtual ParseResult parseFloat(double &result) = 0;
615
616 /// Parse an integer value from the stream.
617 template <typename IntT>
618 ParseResult parseInteger(IntT &result) {
619 auto loc = getCurrentLocation();
620 OptionalParseResult parseResult = parseOptionalInteger(result);
6
Calling 'AsmParser::parseOptionalInteger'
10
Returning from 'AsmParser::parseOptionalInteger'
621 if (!parseResult.has_value())
11
Taking true branch
622 return emitError(loc, "expected integer value");
12
Returning without writing to 'result'
623 return *parseResult;
624 }
625
626 /// Parse an optional integer value from the stream.
627 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
628
629 template <typename IntT>
630 OptionalParseResult parseOptionalInteger(IntT &result) {
631 auto loc = getCurrentLocation();
632
633 // Parse the unsigned variant.
634 APInt uintResult;
635 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
636 if (!parseResult.has_value() || failed(*parseResult))
7
Assuming the condition is true
8
Taking true branch
637 return parseResult;
9
Returning without writing to 'result'
638
639 // Try to convert to the provided integer type. sextOrTrunc is correct even
640 // for unsigned types because parseOptionalInteger ensures the sign bit is
641 // zero for non-negated integers.
642 result =
643 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
644 if (APInt(uintResult.getBitWidth(), result) != uintResult)
645 return emitError(loc, "integer value too large");
646 return success();
647 }
648
/// Delimiters that may surround operand lists and region argument lists, as
/// accepted by parseOperandList.
enum class Delimiter {
  /// No delimiters; zero or more operands.
  None,
  /// `( ... )` around zero or more operands.
  Paren,
  /// `[ ... ]` around zero or more operands.
  Square,
  /// `< ... >` around zero or more operands.
  LessGreater,
  /// `{ ... }` around zero or more operands.
  Braces,
  /// `( ... )` around zero or more operands, or nothing at all.
  OptionalParen,
  /// `[ ... ]` around zero or more operands, or nothing at all.
  OptionalSquare,
  /// `< ... >` around zero or more operands, or nothing at all.
  OptionalLessGreater,
  /// `{ ... }` around zero or more operands, or nothing at all.
  OptionalBraces,
};
671
672 /// Parse a list of comma-separated items with an optional delimiter. If a
673 /// delimiter is provided, then an empty list is allowed. If not, then at
674 /// least one element will be parsed.
675 ///
676 /// contextMessage is an optional message appended to "expected '('" sorts of
677 /// diagnostics when parsing the delimeters.
678 virtual ParseResult
679 parseCommaSeparatedList(Delimiter delimiter,
680 function_ref<ParseResult()> parseElementFn,
681 StringRef contextMessage = StringRef()) = 0;
682
683 /// Parse a comma separated list of elements that must have at least one entry
684 /// in it.
685 ParseResult
686 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
687 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
688 }
689
690 //===--------------------------------------------------------------------===//
691 // Keyword Parsing
692 //===--------------------------------------------------------------------===//
693
694 /// This class represents a StringSwitch like class that is useful for parsing
695 /// expected keywords. On construction, it invokes `parseKeyword` and
696 /// processes each of the provided cases statements until a match is hit. The
697 /// provided `ResultT` must be assignable from `failure()`.
698 template <typename ResultT = ParseResult>
699 class KeywordSwitch {
700 public:
701 KeywordSwitch(AsmParser &parser)
702 : parser(parser), loc(parser.getCurrentLocation()) {
703 if (failed(parser.parseKeywordOrCompletion(&keyword)))
704 result = failure();
705 }
706
707 /// Case that uses the provided value when true.
708 KeywordSwitch &Case(StringLiteral str, ResultT value) {
709 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
710 }
711 KeywordSwitch &Default(ResultT value) {
712 return Default([&](StringRef, SMLoc) { return std::move(value); });
713 }
714 /// Case that invokes the provided functor when true. The parameters passed
715 /// to the functor are the keyword, and the location of the keyword (in case
716 /// any errors need to be emitted).
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Case(StringLiteral str, FnT &&fn) {
720 if (result)
721 return *this;
722
723 // If the word was empty, record this as a completion.
724 if (keyword.empty())
725 parser.codeCompleteExpectedTokens(str);
726 else if (keyword == str)
727 result.emplace(std::move(fn(keyword, loc)));
728 return *this;
729 }
730 template <typename FnT>
731 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
732 Default(FnT &&fn) {
733 if (!result)
734 result.emplace(fn(keyword, loc));
735 return *this;
736 }
737
738 /// Returns true if this switch has a value yet.
739 bool hasValue() const { return result.has_value(); }
740
741 /// Return the result of the switch.
742 [[nodiscard]] operator ResultT() {
743 if (!result)
744 return parser.emitError(loc, "unexpected keyword: ") << keyword;
745 return std::move(*result);
746 }
747
748 private:
749 /// The parser used to construct this switch.
750 AsmParser &parser;
751
752 /// The location of the keyword, used to emit errors as necessary.
753 SMLoc loc;
754
755 /// The parsed keyword itself.
756 StringRef keyword;
757
758 /// The result of the switch statement or none if currently unknown.
759 Optional<ResultT> result;
760 };
761
762 /// Parse a given keyword.
763 ParseResult parseKeyword(StringRef keyword) {
764 return parseKeyword(keyword, "");
765 }
766 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
767
768 /// Parse a keyword into 'keyword'.
769 ParseResult parseKeyword(StringRef *keyword) {
770 auto loc = getCurrentLocation();
771 if (parseOptionalKeyword(keyword))
772 return emitError(loc, "expected valid keyword");
773 return success();
774 }
775
776 /// Parse the given keyword if present.
777 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
778
779 /// Parse a keyword, if present, into 'keyword'.
780 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
781
782 /// Parse a keyword, if present, and if one of the 'allowedValues',
783 /// into 'keyword'
784 virtual ParseResult
785 parseOptionalKeyword(StringRef *keyword,
786 ArrayRef<StringRef> allowedValues) = 0;
787
788 /// Parse a keyword or a quoted string.
789 ParseResult parseKeywordOrString(std::string *result) {
790 if (failed(parseOptionalKeywordOrString(result)))
791 return emitError(getCurrentLocation())
792 << "expected valid keyword or string";
793 return success();
794 }
795
796 /// Parse an optional keyword or string.
797 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
798
799 //===--------------------------------------------------------------------===//
800 // Attribute/Type Parsing
801 //===--------------------------------------------------------------------===//
802
803 /// Invoke the `getChecked` method of the given Attribute or Type class, using
804 /// the provided location to emit errors in the case of failure. Note that
805 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
806 /// context parameter.
807 template <typename T, typename... ParamsT>
808 auto getChecked(SMLoc loc, ParamsT &&...params) {
809 return T::getChecked([&] { return emitError(loc); },
810 std::forward<ParamsT>(params)...);
811 }
812 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
813 /// errors.
814 template <typename T, typename... ParamsT>
815 auto getChecked(ParamsT &&...params) {
816 return T::getChecked([&] { return emitError(getNameLoc()); },
817 std::forward<ParamsT>(params)...);
818 }
819
820 //===--------------------------------------------------------------------===//
821 // Attribute Parsing
822 //===--------------------------------------------------------------------===//
823
824 /// Parse an arbitrary attribute of a given type and return it in result.
825 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
826
827 /// Parse a custom attribute with the provided callback, unless the next
828 /// token is `#`, in which case the generic parser is invoked.
829 virtual ParseResult parseCustomAttributeWithFallback(
830 Attribute &result, Type type,
831 function_ref<ParseResult(Attribute &result, Type type)>
832 parseAttribute) = 0;
833
834 /// Parse an attribute of a specific kind and type.
835 template <typename AttrType>
836 ParseResult parseAttribute(AttrType &result, Type type = {}) {
837 SMLoc loc = getCurrentLocation();
838
839 // Parse any kind of attribute.
840 Attribute attr;
841 if (parseAttribute(attr, type))
842 return failure();
843
844 // Check for the right kind of attribute.
845 if (!(result = attr.dyn_cast<AttrType>()))
846 return emitError(loc, "invalid kind of attribute specified");
847
848 return success();
849 }
850
851 /// Parse an arbitrary attribute and return it in result. This also adds the
852 /// attribute to the specified attribute list with the specified name.
853 ParseResult parseAttribute(Attribute &result, StringRef attrName,
854 NamedAttrList &attrs) {
855 return parseAttribute(result, Type(), attrName, attrs);
856 }
857
858 /// Parse an attribute of a specific kind and type.
859 template <typename AttrType>
860 ParseResult parseAttribute(AttrType &result, StringRef attrName,
861 NamedAttrList &attrs) {
862 return parseAttribute(result, Type(), attrName, attrs);
863 }
864
865 /// Parse an arbitrary attribute of a given type and populate it in `result`.
866 /// This also adds the attribute to the specified attribute list with the
867 /// specified name.
868 template <typename AttrType>
869 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
870 NamedAttrList &attrs) {
871 SMLoc loc = getCurrentLocation();
872
873 // Parse any kind of attribute.
874 Attribute attr;
875 if (parseAttribute(attr, type))
876 return failure();
877
878 // Check for the right kind of attribute.
879 result = attr.dyn_cast<AttrType>();
880 if (!result)
881 return emitError(loc, "invalid kind of attribute specified");
882
883 attrs.append(attrName, result);
884 return success();
885 }
886
887 /// Trait to check if `AttrType` provides a `parse` method.
888 template <typename AttrType>
889 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
890 std::declval<Type>()));
891 template <typename AttrType>
892 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
893
894 /// Parse a custom attribute of a given type unless the next token is `#`, in
895 /// which case the generic parser is invoked. The parsed attribute is
896 /// populated in `result` and also added to the specified attribute list with
897 /// the specified name.
898 template <typename AttrType>
899 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
900 parseCustomAttributeWithFallback(AttrType &result, Type type,
901 StringRef attrName, NamedAttrList &attrs) {
902 SMLoc loc = getCurrentLocation();
903
904 // Parse any kind of attribute.
905 Attribute attr;
906 if (parseCustomAttributeWithFallback(
907 attr, type, [&](Attribute &result, Type type) -> ParseResult {
908 result = AttrType::parse(*this, type);
909 if (!result)
910 return failure();
911 return success();
912 }))
913 return failure();
914
915 // Check for the right kind of attribute.
916 result = attr.dyn_cast<AttrType>();
917 if (!result)
918 return emitError(loc, "invalid kind of attribute specified");
919
920 attrs.append(attrName, result);
921 return success();
922 }
923
924 /// SFINAE parsing method for Attribute that don't implement a parse method.
925 template <typename AttrType>
926 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
927 parseCustomAttributeWithFallback(AttrType &result, Type type,
928 StringRef attrName, NamedAttrList &attrs) {
929 return parseAttribute(result, type, attrName, attrs);
930 }
931
932 /// Parse a custom attribute of a given type unless the next token is `#`, in
933 /// which case the generic parser is invoked. The parsed attribute is
934 /// populated in `result`.
935 template <typename AttrType>
936 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
937 parseCustomAttributeWithFallback(AttrType &result) {
938 SMLoc loc = getCurrentLocation();
939
940 // Parse any kind of attribute.
941 Attribute attr;
942 if (parseCustomAttributeWithFallback(
943 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
944 result = AttrType::parse(*this, type);
945 return success(!!result);
946 }))
947 return failure();
948
949 // Check for the right kind of attribute.
950 result = attr.dyn_cast<AttrType>();
951 if (!result)
952 return emitError(loc, "invalid kind of attribute specified");
953 return success();
954 }
955
956 /// SFINAE parsing method for Attribute that don't implement a parse method.
957 template <typename AttrType>
958 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
959 parseCustomAttributeWithFallback(AttrType &result) {
960 return parseAttribute(result);
961 }
962
963 /// Parse an arbitrary optional attribute of a given type and return it in
964 /// result.
965 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
966 Type type = {}) = 0;
967
968 /// Parse an optional array attribute and return it in result.
969 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
970 Type type = {}) = 0;
971
972 /// Parse an optional string attribute and return it in result.
973 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
974 Type type = {}) = 0;
975
976 /// Parse an optional symbol ref attribute and return it in result.
977 virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
978 Type type = {}) = 0;
979
980 /// Parse an optional attribute of a specific type and add it to the list with
981 /// the specified name.
982 template <typename AttrType>
983 OptionalParseResult parseOptionalAttribute(AttrType &result,
984 StringRef attrName,
985 NamedAttrList &attrs) {
986 return parseOptionalAttribute(result, Type(), attrName, attrs);
987 }
988
989 /// Parse an optional attribute of a specific type and add it to the list with
990 /// the specified name.
991 template <typename AttrType>
992 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
993 StringRef attrName,
994 NamedAttrList &attrs) {
995 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
996 if (parseResult.has_value() && succeeded(*parseResult))
997 attrs.append(attrName, result);
998 return parseResult;
999 }
1000
1001 /// Parse a named dictionary into 'result' if it is present.
1002 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
1003
1004 /// Parse a named dictionary into 'result' if the `attributes` keyword is
1005 /// present.
1006 virtual ParseResult
1007 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
1008
1009 /// Parse an affine map instance into 'map'.
1010 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
1011
1012 /// Parse an integer set instance into 'set'.
1013 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
1014
1015 //===--------------------------------------------------------------------===//
1016 // Identifier Parsing
1017 //===--------------------------------------------------------------------===//
1018
1019 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1020 /// attribute.
1021 ParseResult parseSymbolName(StringAttr &result) {
1022 if (failed(parseOptionalSymbolName(result)))
1023 return emitError(getCurrentLocation())
1024 << "expected valid '@'-identifier for symbol name";
1025 return success();
1026 }
1027
1028 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1029 /// attribute named 'attrName'.
1030 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1031 NamedAttrList &attrs) {
1032 if (parseSymbolName(result))
1033 return failure();
1034 attrs.append(attrName, result);
1035 return success();
1036 }
1037
1038 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1039 /// string attribute.
1040 virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0;
1041
1042 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1043 /// string attribute named 'attrName'.
1044 ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
1045 NamedAttrList &attrs) {
1046 if (succeeded(parseOptionalSymbolName(result))) {
1047 attrs.append(attrName, result);
1048 return success();
1049 }
1050 return failure();
1051 }
1052
1053 //===--------------------------------------------------------------------===//
1054 // Resource Parsing
1055 //===--------------------------------------------------------------------===//
1056
1057 /// Parse a handle to a resource within the assembly format.
1058 template <typename ResourceT>
1059 FailureOr<ResourceT> parseResourceHandle() {
1060 SMLoc handleLoc = getCurrentLocation();
1061
1062 // Try to load the dialect that owns the handle.
1063 auto *dialect =
1064 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1065 if (!dialect) {
1066 return emitError(handleLoc)
1067 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1068 << "' is unknown";
1069 }
1070
1071 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1072 if (failed(handle))
1073 return failure();
1074 if (auto *result = dyn_cast<ResourceT>(&*handle))
1075 return std::move(*result);
1076 return emitError(handleLoc) << "provided resource handle differs from the "
1077 "expected resource type";
1078 }
1079
1080 //===--------------------------------------------------------------------===//
1081 // Type Parsing
1082 //===--------------------------------------------------------------------===//
1083
1084 /// Parse a type.
1085 virtual ParseResult parseType(Type &result) = 0;
1086
1087 /// Parse a custom type with the provided callback, unless the next
1088 /// token is `#`, in which case the generic parser is invoked.
1089 virtual ParseResult parseCustomTypeWithFallback(
1090 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1091
1092 /// Parse an optional type.
1093 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1094
1095 /// Parse a type of a specific type.
1096 template <typename TypeT>
1097 ParseResult parseType(TypeT &result) {
1098 SMLoc loc = getCurrentLocation();
1099
1100 // Parse any kind of type.
1101 Type type;
1102 if (parseType(type))
1103 return failure();
1104
1105 // Check for the right kind of type.
1106 result = type.dyn_cast<TypeT>();
1107 if (!result)
1108 return emitError(loc, "invalid kind of type specified");
1109
1110 return success();
1111 }
1112
1113 /// Trait to check if `TypeT` provides a `parse` method.
1114 template <typename TypeT>
1115 using type_has_parse_method =
1116 decltype(TypeT::parse(std::declval<AsmParser &>()));
1117 template <typename TypeT>
1118 using detect_type_has_parse_method =
1119 llvm::is_detected<type_has_parse_method, TypeT>;
1120
1121 /// Parse a custom Type of a given type unless the next token is `#`, in
1122 /// which case the generic parser is invoked. The parsed Type is
1123 /// populated in `result`.
1124 template <typename TypeT>
1125 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1126 parseCustomTypeWithFallback(TypeT &result) {
1127 SMLoc loc = getCurrentLocation();
1128
1129 // Parse any kind of Type.
1130 Type type;
1131 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1132 result = TypeT::parse(*this);
1133 return success(!!result);
1134 }))
1135 return failure();
1136
1137 // Check for the right kind of Type.
1138 result = type.dyn_cast<TypeT>();
1139 if (!result)
1140 return emitError(loc, "invalid kind of Type specified");
1141 return success();
1142 }
1143
1144 /// SFINAE parsing method for Type that don't implement a parse method.
1145 template <typename TypeT>
1146 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1147 parseCustomTypeWithFallback(TypeT &result) {
1148 return parseType(result);
1149 }
1150
1151 /// Parse a type list.
1152 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1153 return parseCommaSeparatedList(
1154 [&]() { return parseType(result.emplace_back()); });
1155 }
1156
1157 /// Parse an arrow followed by a type list.
1158 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1159
1160 /// Parse an optional arrow followed by a type list.
1161 virtual ParseResult
1162 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1163
1164 /// Parse a colon followed by a type.
1165 virtual ParseResult parseColonType(Type &result) = 0;
1166
1167 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1168 template <typename TypeType>
1169 ParseResult parseColonType(TypeType &result) {
1170 SMLoc loc = getCurrentLocation();
1171
1172 // Parse any kind of type.
1173 Type type;
1174 if (parseColonType(type))
1175 return failure();
1176
1177 // Check for the right kind of type.
1178 result = type.dyn_cast<TypeType>();
1179 if (!result)
1180 return emitError(loc, "invalid kind of type specified");
1181
1182 return success();
1183 }
1184
1185 /// Parse a colon followed by a type list, which must have at least one type.
1186 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1187
1188 /// Parse an optional colon followed by a type list, which if present must
1189 /// have at least one type.
1190 virtual ParseResult
1191 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1192
1193 /// Parse a keyword followed by a type.
1194 ParseResult parseKeywordType(const char *keyword, Type &result) {
1195 return failure(parseKeyword(keyword) || parseType(result));
1196 }
1197
1198 /// Add the specified type to the end of the specified type list and return
1199 /// success. This is a helper designed to allow parse methods to be simple
1200 /// and chain through || operators.
1201 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1202 result.push_back(type);
1203 return success();
1204 }
1205
1206 /// Add the specified types to the end of the specified type list and return
1207 /// success. This is a helper designed to allow parse methods to be simple
1208 /// and chain through || operators.
1209 ParseResult addTypesToList(ArrayRef<Type> types,
1210 SmallVectorImpl<Type> &result) {
1211 result.append(types.begin(), types.end());
1212 return success();
1213 }
1214
1215 /// Parse a dimension list of a tensor or memref type. This populates the
1216 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1217 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1218 ///
1219 /// dimension-list ::= eps | dimension (`x` dimension)*
1220 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1221 /// dimension ::= `?` | decimal-literal
1222 ///
1223 /// When `allowDynamic` is not set, this is used to parse:
1224 ///
1225 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1226 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
1227 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1228 bool allowDynamic = true,
1229 bool withTrailingX = true) = 0;
1230
1231 /// Parse an 'x' token in a dimension list, handling the case where the x is
1232 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1233 /// next token.
1234 virtual ParseResult parseXInDimensionList() = 0;
1235
1236protected:
1237 /// Parse a handle to a resource within the assembly format for the given
1238 /// dialect.
1239 virtual FailureOr<AsmDialectResourceHandle>
1240 parseResourceHandle(Dialect *dialect) = 0;
1241
1242 //===--------------------------------------------------------------------===//
1243 // Code Completion
1244 //===--------------------------------------------------------------------===//
1245
1246 /// Parse a keyword, or an empty string if the current location signals a code
1247 /// completion.
1248 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1249
1250 /// Signal the code completion of a set of expected tokens.
1251 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1252
1253private:
1254 AsmParser(const AsmParser &) = delete;
1255 void operator=(const AsmParser &) = delete;
1256};
1257
1258//===----------------------------------------------------------------------===//
1259// OpAsmParser
1260//===----------------------------------------------------------------------===//
1261
1262/// The OpAsmParser has methods for interacting with the asm parser: parsing
1263/// things from it, emitting errors etc. It has an intentionally high-level API
1264/// that is designed to reduce/constrain syntax innovation in individual
1265/// operations.
1266///
1267/// For example, consider an op like this:
1268///
1269/// %x = load %p[%1, %2] : memref<...>
1270///
1271/// The "%x = load" tokens are already parsed and therefore invisible to the
1272/// custom op parser. This can be supported by calling `parseOperandList` to
1273/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1274/// parse the indices, then calling `parseColonTypeList` to parse the result
1275/// type.
1276///
1277class OpAsmParser : public AsmParser {
1278public:
1279 using AsmParser::AsmParser;
1280 ~OpAsmParser() override;
1281
1282 /// Parse a loc(...) specifier if present, filling in result if so.
1283 /// Location for BlockArgument and Operation may be deferred with an alias, in
1284 /// which case an OpaqueLoc is set and will be resolved when parsing
1285 /// completes.
1286 virtual ParseResult
1287 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1288
1289 /// Return the name of the specified result in the specified syntax, as well
1290 /// as the sub-element in the name. It returns an empty string and ~0U for
1291 /// invalid result numbers. For example, in this operation:
1292 ///
1293 /// %x, %y:2, %z = foo.op
1294 ///
1295 /// getResultName(0) == {"x", 0 }
1296 /// getResultName(1) == {"y", 0 }
1297 /// getResultName(2) == {"y", 1 }
1298 /// getResultName(3) == {"z", 0 }
1299 /// getResultName(4) == {"", ~0U }
1300 virtual std::pair<StringRef, unsigned>
1301 getResultName(unsigned resultNo) const = 0;
1302
1303 /// Return the number of declared SSA results. This returns 4 for the foo.op
1304 /// example in the comment for `getResultName`.
1305 virtual size_t getNumResults() const = 0;
1306
1307 // These methods emit an error and return failure or success. This allows
1308 // these to be chained together into a linear sequence of || expressions in
1309 // many cases.
1310
1311 /// Parse an operation in its generic form.
1312 /// The parsed operation is parsed in the current context and inserted in the
1313 /// provided block and insertion point. The results produced by this operation
1314 /// aren't mapped to any named value in the parser. Returns nullptr on
1315 /// failure.
1316 virtual Operation *parseGenericOperation(Block *insertBlock,
1317 Block::iterator insertPt) = 0;
1318
1319 /// Parse the name of an operation, in the custom form. On success, return a
1320 /// an object of type 'OperationName'. Otherwise, failure is returned.
1321 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1322
1323 //===--------------------------------------------------------------------===//
1324 // Operand Parsing
1325 //===--------------------------------------------------------------------===//
1326
1327 /// This is the representation of an operand reference.
1328 struct UnresolvedOperand {
1329 SMLoc location; // Location of the token.
1330 StringRef name; // Value name, e.g. %42 or %abc
1331 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1332 };
1333
1334 /// Parse different components, viz., use-info of operand(s), successor(s),
1335 /// region(s), attribute(s) and function-type, of the generic form of an
1336 /// operation instance and populate the input operation-state 'result' with
1337 /// those components. If any of the components is explicitly provided, then
1338 /// skip parsing that component.
1339 virtual ParseResult parseGenericOperationAfterOpName(
1340 OperationState &result,
1341 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt,
1342 Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt,
1343 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1344 std::nullopt,
1345 Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
1346 Optional<FunctionType> parsedFnType = std::nullopt) = 0;
1347
1348 /// Parse a single SSA value operand name along with a result number if
1349 /// `allowResultNumber` is true.
1350 virtual ParseResult parseOperand(UnresolvedOperand &result,
1351 bool allowResultNumber = true) = 0;
1352
1353 /// Parse a single operand if present.
1354 virtual OptionalParseResult
1355 parseOptionalOperand(UnresolvedOperand &result,
1356 bool allowResultNumber = true) = 0;
1357
1358 /// Parse zero or more SSA comma-separated operand references with a specified
1359 /// surrounding delimiter, and an optional required operand count.
1360 virtual ParseResult
1361 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1362 Delimiter delimiter = Delimiter::None,
1363 bool allowResultNumber = true,
1364 int requiredOperandCount = -1) = 0;
1365
1366 /// Parse a specified number of comma separated operands.
1367 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1368 int requiredOperandCount,
1369 Delimiter delimiter = Delimiter::None) {
1370 return parseOperandList(result, delimiter,
1371 /*allowResultNumber=*/true, requiredOperandCount);
1372 }
1373
1374 /// Parse zero or more trailing SSA comma-separated trailing operand
1375 /// references with a specified surrounding delimiter, and an optional
1376 /// required operand count. A leading comma is expected before the
1377 /// operands.
1378 ParseResult
1379 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1380 Delimiter delimiter = Delimiter::None) {
1381 if (failed(parseOptionalComma()))
1382 return success(); // The comma is optional.
1383 return parseOperandList(result, delimiter);
1384 }
1385
1386 /// Resolve an operand to an SSA value, emitting an error on failure.
1387 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1388 Type type,
1389 SmallVectorImpl<Value> &result) = 0;
1390
1391 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1392 /// appending the results to the list on success. This method should be used
1393 /// when all operands have the same type.
1394 template <typename Operands = ArrayRef<UnresolvedOperand>>
1395 ParseResult resolveOperands(Operands &&operands, Type type,
1396 SmallVectorImpl<Value> &result) {
1397 for (const UnresolvedOperand &operand : operands)
1398 if (resolveOperand(operand, type, result))
1399 return failure();
1400 return success();
1401 }
1402 template <typename Operands = ArrayRef<UnresolvedOperand>>
1403 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1404 SmallVectorImpl<Value> &result) {
1405 return resolveOperands(std::forward<Operands>(operands), type, result);
1406 }
1407
1408 /// Resolve a list of operands and a list of operand types to SSA values,
1409 /// emitting an error and returning failure, or appending the results
1410 /// to the list on success.
1411 template <typename Operands = ArrayRef<UnresolvedOperand>,
1412 typename Types = ArrayRef<Type>>
1413 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1414 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1415 SmallVectorImpl<Value> &result) {
1416 size_t operandSize = std::distance(operands.begin(), operands.end());
1417 size_t typeSize = std::distance(types.begin(), types.end());
1418 if (operandSize != typeSize)
1419 return emitError(loc)
1420 << operandSize << " operands present, but expected " << typeSize;
1421
1422 for (auto [operand, type] : llvm::zip(operands, types))
1423 if (resolveOperand(operand, type, result))
1424 return failure();
1425 return success();
1426 }
1427
1428 /// Parses an affine map attribute where dims and symbols are SSA operands.
1429 /// Operand values must come from single-result sources, and be valid
1430 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1431 virtual ParseResult
1432 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1433 Attribute &map, StringRef attrName,
1434 NamedAttrList &attrs,
1435 Delimiter delimiter = Delimiter::Square) = 0;
1436
1437 /// Parses an affine expression where dims and symbols are SSA operands.
1438 /// Operand values must come from single-result sources, and be valid
1439 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1440 virtual ParseResult
1441 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1442 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1443 AffineExpr &expr) = 0;
1444
1445 //===--------------------------------------------------------------------===//
1446 // Argument Parsing
1447 //===--------------------------------------------------------------------===//
1448
1449 struct Argument {
1450 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1451 Type type; // Type.
1452 DictionaryAttr attrs; // Attributes if present.
1453 Optional<Location> sourceLoc; // Source location specifier if present.
1454 };
1455
1456 /// Parse a single argument with the following syntax:
1457 ///
1458 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1459 ///
1460 /// If `allowType` is false or `allowAttrs` are false then the respective
1461 /// parts of the grammar are not parsed.
1462 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1463 bool allowAttrs = false) = 0;
1464
1465 /// Parse a single argument if present.
1466 virtual OptionalParseResult
1467 parseOptionalArgument(Argument &result, bool allowType = false,
1468 bool allowAttrs = false) = 0;
1469
1470 /// Parse zero or more arguments with a specified surrounding delimiter.
1471 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1472 Delimiter delimiter = Delimiter::None,
1473 bool allowType = false,
1474 bool allowAttrs = false) = 0;
1475
1476 //===--------------------------------------------------------------------===//
1477 // Region Parsing
1478 //===--------------------------------------------------------------------===//
1479
1480 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1481 /// moved to the op regions after the op is created. The first block of the
1482 /// region takes 'arguments'.
1483 ///
1484 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1485 /// shadow the names of other existing SSA values defined above the region
1486 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1487 /// to operations that are 'IsolatedFromAbove'.
1488 virtual ParseResult parseRegion(Region &region,
1489 ArrayRef<Argument> arguments = {},
1490 bool enableNameShadowing = false) = 0;
1491
1492 /// Parses a region if present.
1493 virtual OptionalParseResult
1494 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1495 bool enableNameShadowing = false) = 0;
1496
1497 /// Parses a region if present. If the region is present, a new region is
1498 /// allocated and placed in `region`. If no region is present or on failure,
1499 /// `region` remains untouched.
1500 virtual OptionalParseResult
1501 parseOptionalRegion(std::unique_ptr<Region> &region,
1502 ArrayRef<Argument> arguments = {},
1503 bool enableNameShadowing = false) = 0;
1504
1505 //===--------------------------------------------------------------------===//
1506 // Successor Parsing
1507 //===--------------------------------------------------------------------===//
1508
1509 /// Parse a single operation successor.
1510 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1511
1512 /// Parse an optional operation successor.
1513 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1514
1515 /// Parse a single operation successor and its operand list.
1516 virtual ParseResult
1517 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1518
1519 //===--------------------------------------------------------------------===//
1520 // Type Parsing
1521 //===--------------------------------------------------------------------===//
1522
1523 /// Parse a list of assignments of the form
1524 /// (%x1 = %y1, %x2 = %y2, ...)
1525 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1526 SmallVectorImpl<UnresolvedOperand> &rhs) {
1527 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1528 if (!result.has_value())
1529 return emitError(getCurrentLocation(), "expected '('");
1530 return result.value();
1531 }
1532
1533 virtual OptionalParseResult
1534 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1535 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1536};
1537
1538//===--------------------------------------------------------------------===//
1539// Dialect OpAsm interface.
1540//===--------------------------------------------------------------------===//
1541
1542/// A functor used to set the name of the start of a result group of an
1543/// operation. See 'getAsmResultNames' below for more details.
1544using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1545
1546/// A functor used to set the name of blocks in regions directly nested under
1547/// an operation.
1548using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1549
1550class OpAsmDialectInterface
1551 : public DialectInterface::Base<OpAsmDialectInterface> {
1552public:
1553 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1554
1555 //===------------------------------------------------------------------===//
1556 // Aliases
1557 //===------------------------------------------------------------------===//
1558
1559 /// Holds the result of `getAlias` hook call.
1560 enum class AliasResult {
1561 /// The object (type or attribute) is not supported by the hook
1562 /// and an alias was not provided.
1563 NoAlias,
1564 /// An alias was provided, but it might be overriden by other hook.
1565 OverridableAlias,
1566 /// An alias was provided and it should be used
1567 /// (no other hooks will be checked).
1568 FinalAlias
1569 };
1570
1571 /// Hooks for getting an alias identifier alias for a given symbol, that is
1572 /// not necessarily a part of this dialect. The identifier is used in place of
1573 /// the symbol when printing textual IR. These aliases must not contain `.` or
1574 /// end with a numeric digit([0-9]+).
1575 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1576 return AliasResult::NoAlias;
1577 }
1578 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1579 return AliasResult::NoAlias;
1580 }
1581
1582 //===--------------------------------------------------------------------===//
1583 // Resources
1584 //===--------------------------------------------------------------------===//
1585
1586 /// Declare a resource with the given key, returning a handle to use for any
1587 /// references of this resource key within the IR during parsing. The result
1588 /// of `getResourceKey` on the returned handle is permitted to be different
1589 /// than `key`.
1590 virtual FailureOr<AsmDialectResourceHandle>
1591 declareResource(StringRef key) const {
1592 return failure();
1593 }
1594
1595 /// Return a key to use for the given resource. This key should uniquely
1596 /// identify this resource within the dialect.
1597 virtual std::string
1598 getResourceKey(const AsmDialectResourceHandle &handle) const {
1599 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1600)
1600 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1600)
;
1601 }
1602
1603 /// Hook for parsing resource entries. Returns failure if the entry was not
1604 /// valid, or could otherwise not be processed correctly. Any necessary errors
1605 /// can be emitted via the provided entry.
1606 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1607
1608 /// Hook for building resources to use during printing. The given `op` may be
1609 /// inspected to help determine what information to include.
1610 /// `referencedResources` contains all of the resources detected when printing
1611 /// 'op'.
1612 virtual void
1613 buildResources(Operation *op,
1614 const SetVector<AsmDialectResourceHandle> &referencedResources,
1615 AsmResourceBuilder &builder) const {}
1616};
1617} // namespace mlir
1618
1619//===--------------------------------------------------------------------===//
1620// Operation OpAsm interface.
1621//===--------------------------------------------------------------------===//
1622
1623/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1624#include "mlir/IR/OpAsmInterface.h.inc"
1625
1626namespace llvm {
1627template <>
1628struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1629 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1630 return {DenseMapInfo<void *>::getEmptyKey(),
1631 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1632 }
1633 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1634 return {DenseMapInfo<void *>::getTombstoneKey(),
1635 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1636 }
1637 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1638 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1639 }
1640 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1641 const mlir::AsmDialectResourceHandle &rhs) {
1642 return lhs.getResource() == rhs.getResource();
1643 }
1644};
1645} // namespace llvm
1646
1647#endif