File: | build/source/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp |
Warning: | line 1090, column 9 1st function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- LinalgTransformOps.cpp - Implementation of Linalg transform ops ----===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.h" | |||
10 | ||||
11 | #include "mlir/AsmParser/AsmParser.h" | |||
12 | #include "mlir/Dialect/Affine/IR/AffineOps.h" | |||
13 | #include "mlir/Dialect/Arith/IR/Arith.h" | |||
14 | #include "mlir/Dialect/GPU/IR/GPUDialect.h" | |||
15 | #include "mlir/Dialect/Linalg/IR/Linalg.h" | |||
16 | #include "mlir/Dialect/Linalg/Transforms/Transforms.h" | |||
17 | #include "mlir/Dialect/PDL/IR/PDL.h" | |||
18 | #include "mlir/Dialect/PDL/IR/PDLTypes.h" | |||
19 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" | |||
20 | #include "mlir/Dialect/Transform/IR/TransformDialect.h" | |||
21 | #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" | |||
22 | #include "mlir/Dialect/Transform/IR/TransformUtils.h" | |||
23 | #include "mlir/Dialect/Transform/Utils/Utils.h" | |||
24 | #include "mlir/IR/BuiltinTypes.h" | |||
25 | #include "mlir/IR/Matchers.h" | |||
26 | #include "mlir/IR/OpDefinition.h" | |||
27 | #include "mlir/Interfaces/TilingInterface.h" | |||
28 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h" | |||
29 | #include "llvm/ADT/StringSet.h" | |||
30 | #include "llvm/Support/Debug.h" | |||
31 | ||||
32 | using namespace mlir; | |||
33 | using namespace mlir::linalg; | |||
34 | using namespace mlir::transform; | |||
35 | ||||
// Debug category used by the LLVM_DEBUG statements in this file.
#define DEBUG_TYPE "linalg-transforms"
37 | ||||
38 | /// Attempts to apply the pattern specified as template argument to the given | |||
39 | /// operation. The pattern is expected to have a `returningMatchAndRewrite` | |||
40 | /// function that returns the "main" result or failure. Returns failure if the | |||
41 | /// pattern failed to apply. Extra arguments are forwarded to the pattern | |||
42 | /// constructor. | |||
43 | template <typename PatternTy, typename... Args> | |||
44 | static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) { | |||
45 | // Check if the given operation has the type expected by the pattern. | |||
46 | using OpTy = typename llvm::function_traits< | |||
47 | decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>; | |||
48 | auto op = dyn_cast<OpTy>(operation); | |||
49 | if (!op) | |||
50 | return failure(); | |||
51 | ||||
52 | // Apply the pattern directly to the op. | |||
53 | PatternTy pattern(operation->getContext(), std::forward<Args>(args)...); | |||
54 | TrivialPatternRewriter rewriter(operation->getContext()); | |||
55 | rewriter.setInsertionPoint(operation); | |||
56 | auto result = pattern.returningMatchAndRewrite(op, rewriter); | |||
57 | if (failed(result)) | |||
58 | return failure(); | |||
59 | return cast<LinalgOp>(result->getOperation()); | |||
60 | } | |||
61 | ||||
62 | //===----------------------------------------------------------------------===// | |||
63 | // DecomposeOp | |||
64 | //===----------------------------------------------------------------------===// | |||
65 | ||||
66 | DiagnosedSilenceableFailure | |||
67 | transform::DecomposeOp::applyToOne(linalg::LinalgOp target, | |||
68 | SmallVectorImpl<Operation *> &results, | |||
69 | transform::TransformState &state) { | |||
70 | FailureOr<LinalgOp> windowedNhwc = | |||
71 | tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNhwcHwcfOp, | |||
72 | Conv1DNwcWcfOp>>(target); | |||
73 | if (succeeded(windowedNhwc)) { | |||
74 | results.push_back(*windowedNhwc); | |||
75 | return DiagnosedSilenceableFailure::success(); | |||
76 | } | |||
77 | FailureOr<LinalgOp> windowedNchw = | |||
78 | tryApply<DownscaleSizeOneWindowed2DConvolution<linalg::Conv2DNchwFchwOp, | |||
79 | Conv1DNcwFcwOp>>(target); | |||
80 | if (succeeded(windowedNchw)) { | |||
81 | results.push_back(*windowedNchw); | |||
82 | return DiagnosedSilenceableFailure::success(); | |||
83 | } | |||
84 | FailureOr<LinalgOp> depthwise = | |||
85 | tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target); | |||
86 | if (succeeded(depthwise)) { | |||
87 | results.push_back(*depthwise); | |||
88 | return DiagnosedSilenceableFailure::success(); | |||
89 | } | |||
90 | results.assign(1, nullptr); | |||
91 | return emitDefaultSilenceableFailure(target); | |||
92 | } | |||
93 | //===----------------------------------------------------------------------===// | |||
94 | // FuseOp | |||
95 | //===----------------------------------------------------------------------===// | |||
96 | ||||
97 | /// Apply a tiling transformation to all payload ops and store both the | |||
98 | /// tiled operation as well as the created tile loops. | |||
99 | static LogicalResult applyTilingToAll( | |||
100 | Operation *transformOp, ArrayRef<Operation *> payloadOps, unsigned numLoops, | |||
101 | transform::TransformResults &transformResults, | |||
102 | function_ref<FailureOr<scf::SCFTileAndFuseResult>(TilingInterface)> | |||
103 | applyFn) { | |||
104 | SmallVector<Operation *> tiledLinalgOps; | |||
105 | SmallVector<SmallVector<Operation *>> loopOps(numLoops); | |||
106 | for (unsigned int i = 0; i < numLoops; ++i) | |||
107 | loopOps[i].reserve(payloadOps.size()); | |||
108 | ||||
109 | for (Operation *target : payloadOps) { | |||
110 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(target); | |||
111 | if (!tilingInterfaceOp) | |||
112 | return transformOp->emitError("only TilingInterface ops are supported"); | |||
113 | ||||
114 | TrivialPatternRewriter rewriter(target->getContext()); | |||
115 | rewriter.setInsertionPoint(target); | |||
116 | FailureOr<scf::SCFTileAndFuseResult> tiledResults = | |||
117 | applyFn(tilingInterfaceOp); | |||
118 | if (failed(tiledResults)) | |||
119 | return failure(); | |||
120 | ||||
121 | // Perform the replacement of tiled and fused values. | |||
122 | SmallVector<Operation *> opsToReplace{target}; | |||
123 | llvm::append_range(opsToReplace, tiledResults->fusedProducers); | |||
124 | for (Operation *toReplace : opsToReplace) { | |||
125 | SmallVector<Value> replacements; | |||
126 | replacements.reserve(toReplace->getNumResults()); | |||
127 | for (OpResult res : toReplace->getResults()) { | |||
128 | auto it = tiledResults->replacements.find(res); | |||
129 | if (it == tiledResults->replacements.end()) | |||
130 | replacements.push_back(res); | |||
131 | else | |||
132 | replacements.push_back(it->getSecond()); | |||
133 | } | |||
134 | rewriter.replaceOp(toReplace, replacements); | |||
135 | } | |||
136 | ||||
137 | // Report back the relevant handles to the transform op. | |||
138 | tiledLinalgOps.push_back(tiledResults->tiledAndFusedOps.front()); | |||
139 | assert(tiledResults->loops.size() == numLoops &&(static_cast <bool> (tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 141, __extension__ __PRETTY_FUNCTION__)) | |||
140 | "Mismatched number of loops, tile and fuse transform should have "(static_cast <bool> (tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 141, __extension__ __PRETTY_FUNCTION__)) | |||
141 | "failed")(static_cast <bool> (tiledResults->loops.size() == numLoops && "Mismatched number of loops, tile and fuse transform should have " "failed") ? void (0) : __assert_fail ("tiledResults->loops.size() == numLoops && \"Mismatched number of loops, tile and fuse transform should have \" \"failed\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 141, __extension__ __PRETTY_FUNCTION__)); | |||
142 | for (unsigned int i = 0; i < numLoops; ++i) | |||
143 | loopOps[i].push_back(tiledResults->loops[i]); | |||
144 | } | |||
145 | ||||
146 | transformResults.set(transformOp->getOpResult(0), tiledLinalgOps); | |||
147 | for (unsigned int i = 0; i < numLoops; ++i) | |||
148 | transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]); | |||
149 | ||||
150 | return success(); | |||
151 | } | |||
152 | ||||
153 | /// Parse a tiling-like operation that returns the tiled op as well as the | |||
154 | /// created tile loops. The function counts the non-zero tile sizes to compute | |||
155 | /// the number of results. | |||
156 | static ParseResult parseTileLikeOp(OpAsmParser &parser, OperationState &result, | |||
157 | StringRef sizesAttrName) { | |||
158 | OpAsmParser::UnresolvedOperand targetOperand; | |||
159 | SMLoc opLoc = parser.getCurrentLocation(); | |||
160 | if (parser.parseOperand(targetOperand) || | |||
161 | parser.parseOptionalAttrDict(result.attributes)) | |||
162 | return failure(); | |||
163 | Attribute sizesAttr = result.attributes.get(sizesAttrName); | |||
164 | if (!sizesAttr) | |||
165 | return parser.emitError(opLoc) | |||
166 | << "expected '" << sizesAttrName << "' attribute"; | |||
167 | auto sizesArrayAttr = sizesAttr.dyn_cast<ArrayAttr>(); | |||
168 | if (!sizesArrayAttr) | |||
169 | return parser.emitError(opLoc) | |||
170 | << "'" << sizesAttrName << "' attribute must be an array"; | |||
171 | Type pdlOpType = parser.getBuilder().getType<pdl::OperationType>(); | |||
172 | size_t numExpectedLoops = | |||
173 | sizesArrayAttr.size() - | |||
174 | llvm::count(extractFromI64ArrayAttr(sizesArrayAttr), 0); | |||
175 | result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOpType)); | |||
176 | if (parser.resolveOperand(targetOperand, pdlOpType, result.operands)) | |||
177 | return failure(); | |||
178 | return success(); | |||
179 | } | |||
180 | ||||
181 | DiagnosedSilenceableFailure | |||
182 | transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, | |||
183 | mlir::transform::TransformState &state) { | |||
184 | SmallVector<int64_t> tileSizes = extractFromI64ArrayAttr(getTileSizes()); | |||
185 | SmallVector<int64_t> tileInterchange = | |||
186 | extractFromI64ArrayAttr(getTileInterchange()); | |||
187 | ||||
188 | scf::SCFTilingOptions tilingOptions; | |||
189 | tilingOptions.interchangeVector = tileInterchange; | |||
190 | tilingOptions = tilingOptions.setTileSizes(tileSizes); | |||
191 | scf::SCFTileAndFuseOptions tileAndFuseOptions; | |||
192 | tileAndFuseOptions.tilingOptions = tilingOptions; | |||
193 | LogicalResult result = applyTilingToAll( | |||
194 | getOperation(), state.getPayloadOps(getTarget()), | |||
195 | tileSizes.size() - llvm::count(tileSizes, 0), transformResults, | |||
196 | [&](TilingInterface tilingInterfaceOp) | |||
197 | -> FailureOr<scf::SCFTileAndFuseResult> { | |||
198 | TrivialPatternRewriter rewriter(getContext()); | |||
199 | return tileConsumerAndFuseProducerGreedilyUsingSCFForOp( | |||
200 | rewriter, tilingInterfaceOp, tileAndFuseOptions); | |||
201 | }); | |||
202 | return failed(result) ? DiagnosedSilenceableFailure::definiteFailure() | |||
203 | : DiagnosedSilenceableFailure::success(); | |||
204 | } | |||
205 | ||||
206 | ParseResult transform::FuseOp::parse(OpAsmParser &parser, | |||
207 | OperationState &result) { | |||
208 | return parseTileLikeOp( | |||
209 | parser, result, | |||
210 | transform::FuseOp::getTileSizesAttrName(result.name).getValue()); | |||
211 | } | |||
212 | ||||
213 | void transform::FuseOp::print(OpAsmPrinter &p) { | |||
214 | p << ' '; | |||
215 | p << getTarget(); | |||
216 | p.printOptionalAttrDict((*this)->getAttrs()); | |||
217 | } | |||
218 | ||||
219 | LogicalResult transform::FuseOp::verify() { | |||
220 | SmallVector<int64_t> permutation = | |||
221 | extractFromI64ArrayAttr(getTileInterchange()); | |||
222 | auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); | |||
223 | if (!std::is_permutation(sequence.begin(), sequence.end(), | |||
224 | permutation.begin(), permutation.end())) { | |||
225 | return emitOpError() << "expects interchange to be a permutation, found " | |||
226 | << getTileInterchange(); | |||
227 | } | |||
228 | return success(); | |||
229 | } | |||
230 | ||||
231 | //===----------------------------------------------------------------------===// | |||
232 | // FuseIntoContainingOp | |||
233 | //===----------------------------------------------------------------------===// | |||
234 | ||||
235 | void transform::FuseIntoContainingOp::build(OpBuilder &builder, | |||
236 | OperationState &result, | |||
237 | Value producerOp, | |||
238 | Value containingOp) { | |||
239 | result.addOperands({producerOp, containingOp}); | |||
240 | result.addTypes(pdl::OperationType::get(builder.getContext())); | |||
241 | } | |||
242 | ||||
243 | /// Find the first "extract" user of `producerOp` and tile it right before its | |||
244 | /// use. The tiled op is fused under the `containingOp`. | |||
245 | /// Return this fused op on success or nullptr if anything fails. | |||
246 | static Operation *tileAndFuseFirstExtractUse(RewriterBase &rewriter, | |||
247 | Diagnostic &diag, | |||
248 | Operation *producerOp, | |||
249 | Operation *containingOp) { | |||
250 | LLVM_DEBUG(llvm::dbgs() << "Try to fuse a direct extract use\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "Try to fuse a direct extract use\n" ; } } while (false); | |||
251 | auto tileableProducer = dyn_cast<TilingInterface>(producerOp); | |||
252 | if (!tileableProducer) { | |||
253 | diag.attachNote(producerOp->getLoc()) | |||
254 | << "producer is not a TileableInterface: " << *producerOp; | |||
255 | return nullptr; | |||
256 | } | |||
257 | ||||
258 | // Search the producer slices accessed within the containing operation. | |||
259 | // TODO: Generalize to more extract/insert/parallel_insert triples, maybe | |||
260 | // evolve into an interface. | |||
261 | auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) { | |||
262 | auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); | |||
263 | return sliceOp && containingOp->isProperAncestor(sliceOp); | |||
264 | }); | |||
265 | ||||
266 | // Find a fusion opportunity. | |||
267 | if (it == tileableProducer->getUsers().end()) { | |||
268 | diag.attachNote(tileableProducer->getLoc()) | |||
269 | << "could not find fusion opportunity for: " << *tileableProducer; | |||
270 | return nullptr; | |||
271 | } | |||
272 | auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it); | |||
273 | ||||
274 | // Try to fuse the producer in-place. | |||
275 | OpBuilder::InsertionGuard guard(rewriter); | |||
276 | rewriter.setInsertionPoint(sliceOpToTile); | |||
277 | ||||
278 | // Tile the producer. | |||
279 | int64_t resultNumber = | |||
280 | sliceOpToTile.getSource().cast<OpResult>().getResultNumber(); | |||
281 | LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "resultNumber: " << resultNumber << "\n"; } } while (false); | |||
282 | ||||
283 | FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue( | |||
284 | rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), | |||
285 | sliceOpToTile.getMixedSizes()); | |||
286 | if (failed(tiledProducer)) { | |||
287 | diag.attachNote(tileableProducer->getLoc()) | |||
288 | << "failed to tile producer op: " << *tileableProducer; | |||
289 | return nullptr; | |||
290 | } | |||
291 | LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"; } } while (false); | |||
292 | ||||
293 | // Replace the extract op. | |||
294 | Operation *fusedOp = tiledProducer->getDefiningOp(); | |||
295 | auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( | |||
296 | rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), | |||
297 | sliceOpToTile->getResult(0) | |||
298 | .getType() | |||
299 | .cast<RankedTensorType>() | |||
300 | .getShape()); | |||
301 | assert(succeeded(maybeRankReduced) && "unexpected shape")(static_cast <bool> (succeeded(maybeRankReduced) && "unexpected shape") ? void (0) : __assert_fail ("succeeded(maybeRankReduced) && \"unexpected shape\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 301, __extension__ __PRETTY_FUNCTION__)); | |||
302 | rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); | |||
303 | return fusedOp; | |||
304 | } | |||
305 | ||||
306 | /// First, find the first "scf::ForeachThreadOp" user of `producerOp` and ensure | |||
307 | /// it is exactly the `containingOp`, otherwise bail. | |||
308 | /// Then, find the first "extract" user of the tied block argument and tile it | |||
309 | /// right before its "extract" use. The tiled op is fused under the | |||
310 | /// `containingOp`. | |||
311 | /// Return this fused op on success or nullptr if anything fails. | |||
312 | static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( | |||
313 | RewriterBase &rewriter, Diagnostic &diag, Operation *producerOp, | |||
314 | Operation *containingOp) { | |||
315 | LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "Try to fuse an extract use through block argument\n" ; } } while (false) | |||
316 | llvm::dbgs() << "Try to fuse an extract use through block argument\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "Try to fuse an extract use through block argument\n" ; } } while (false); | |||
317 | ||||
318 | auto tileableProducer = dyn_cast<TilingInterface>(producerOp); | |||
319 | if (!tileableProducer) { | |||
320 | diag.attachNote(producerOp->getLoc()) | |||
321 | << "producer is not a TileableInterface: " << *producerOp; | |||
322 | return nullptr; | |||
323 | } | |||
324 | ||||
325 | // Search the first use by a "scf::ForeachThreadOp" user. | |||
326 | scf::ForeachThreadOp foreachThreadOp; | |||
327 | auto itProducerUses = | |||
328 | llvm::find_if(tileableProducer->getUses(), [&](OpOperand &use) { | |||
329 | foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(use.getOwner()); | |||
330 | return foreachThreadOp; | |||
331 | }); | |||
332 | // If it's not from the containing op, return. | |||
333 | if (!foreachThreadOp || foreachThreadOp != containingOp) { | |||
334 | diag.attachNote(tileableProducer->getLoc()) | |||
335 | << "could not find a use by the containing op: " << *tileableProducer; | |||
336 | return nullptr; | |||
337 | } | |||
338 | ||||
339 | // Search the producer slices accessed within the containing | |||
340 | // operation. | |||
341 | // TODO: Generalize to more extract/insert/parallel_insert triples. | |||
342 | // Maybe evolve into an interface. | |||
343 | OpOperand *pUse = &(*itProducerUses); | |||
344 | BlockArgument bbArg = foreachThreadOp.getTiedBlockArgument(pUse); | |||
345 | ||||
346 | // Search the producer slices accessed within the containing operation. | |||
347 | // TODO: Generalize to more extract/insert/parallel_insert triples, maybe | |||
348 | // evolve into an interface. | |||
349 | auto itBBArgUsers = llvm::find_if(bbArg.getUsers(), [&](Operation *user) { | |||
350 | auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user); | |||
351 | return sliceOp && containingOp->isProperAncestor(sliceOp); | |||
352 | }); | |||
353 | ||||
354 | // Find a fusion opportunity. | |||
355 | if (itBBArgUsers == bbArg.getUsers().end()) { | |||
356 | diag.attachNote(containingOp->getLoc()) | |||
357 | << "could not find fusion opportunity for bbArg: " << bbArg; | |||
358 | return nullptr; | |||
359 | } | |||
360 | auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*itBBArgUsers); | |||
361 | ||||
362 | // Try to fuse the producer in-place. | |||
363 | OpBuilder::InsertionGuard guard(rewriter); | |||
364 | rewriter.setInsertionPoint(sliceOpToTile); | |||
365 | ||||
366 | // Replace the use in the tileableProducer before tiling: clone, replace and | |||
367 | // then tile. | |||
368 | int64_t resultNumber = pUse->get().cast<OpResult>().getResultNumber(); | |||
369 | LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "resultNumber: " << resultNumber << "\n"; } } while (false); | |||
370 | ||||
371 | // Gather destination tensors. | |||
372 | SmallVector<Value> destinationTensors; | |||
373 | if (failed(tensor::getOrCreateDestinations( | |||
374 | rewriter, tileableProducer->getLoc(), tileableProducer, | |||
375 | destinationTensors))) { | |||
376 | diag.attachNote(tileableProducer->getLoc()) | |||
377 | << "failed to get destination tensors for: " << *tileableProducer; | |||
378 | return nullptr; | |||
379 | } | |||
380 | ||||
381 | BlockAndValueMapping bvm; | |||
382 | bvm.map(destinationTensors[resultNumber], bbArg); | |||
383 | auto tileableProducerClone = | |||
384 | cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm)); | |||
385 | auto scopeGuard = | |||
386 | llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); }); | |||
387 | ||||
388 | // Tile the producer. | |||
389 | FailureOr<Value> tiledProducer = | |||
390 | tileableProducerClone.generateResultTileValue( | |||
391 | rewriter, resultNumber, sliceOpToTile.getMixedOffsets(), | |||
392 | sliceOpToTile.getMixedSizes()); | |||
393 | if (failed(tiledProducer)) { | |||
394 | diag.attachNote(tileableProducer->getLoc()) | |||
395 | << "failed to tile producer op: " << *tileableProducer; | |||
396 | return nullptr; | |||
397 | } | |||
398 | LLVM_DEBUG(llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "tiledProducer: " << *tiledProducer << "\n"; } } while (false); | |||
399 | ||||
400 | // Replace the extract op. | |||
401 | Operation *fusedOp = tiledProducer->getDefiningOp(); | |||
402 | auto maybeRankReduced = tensor::ExtractSliceOp::rankReduceIfNeeded( | |||
403 | rewriter, sliceOpToTile->getLoc(), fusedOp->getResult(resultNumber), | |||
404 | sliceOpToTile->getResult(0) | |||
405 | .getType() | |||
406 | .cast<RankedTensorType>() | |||
407 | .getShape()); | |||
408 | assert(succeeded(maybeRankReduced) && "unexpected shape")(static_cast <bool> (succeeded(maybeRankReduced) && "unexpected shape") ? void (0) : __assert_fail ("succeeded(maybeRankReduced) && \"unexpected shape\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 408, __extension__ __PRETTY_FUNCTION__)); | |||
409 | rewriter.replaceOp(sliceOpToTile, *maybeRankReduced); | |||
410 | ||||
411 | // Replace the use in containingOp. | |||
412 | rewriter.updateRootInPlace(containingOp, [&]() { | |||
413 | containingOp->setOperand(pUse->getOperandNumber(), | |||
414 | destinationTensors.front()); | |||
415 | }); | |||
416 | ||||
417 | return fusedOp; | |||
418 | } | |||
419 | ||||
420 | static Operation *cloneAndFuseFirstUse(RewriterBase &rewriter, Diagnostic &diag, | |||
421 | Operation *producerOp, | |||
422 | Operation *containingOp) { | |||
423 | LLVM_DEBUG(llvm::dbgs() << "Try to fuse an use by cloning\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "Try to fuse an use by cloning\n" ; } } while (false); | |||
424 | ||||
425 | // Gather all uses inside the containing op. | |||
426 | SmallVector<OpOperand *> uses; | |||
427 | for (OpResult result : producerOp->getOpResults()) { | |||
428 | for (OpOperand &use : result.getUses()) { | |||
429 | if (containingOp->isProperAncestor(use.getOwner())) { | |||
430 | uses.push_back(&use); | |||
431 | continue; | |||
432 | } | |||
433 | // Cannot clone and fuse if the use is by the containing op itself: fail | |||
434 | // immediately. | |||
435 | if (containingOp == use.getOwner()) { | |||
436 | diag.attachNote(producerOp->getLoc()) | |||
437 | << "producer op use by containing op cannot be fused by cloning"; | |||
438 | return nullptr; | |||
439 | } | |||
440 | } | |||
441 | } | |||
442 | ||||
443 | // Check for a non-empty list of fusion opportunities. | |||
444 | if (uses.empty()) { | |||
445 | diag.attachNote(producerOp->getLoc()) << "no fusion opportunity by cloning"; | |||
446 | return nullptr; | |||
447 | } | |||
448 | ||||
449 | // Clone and fuse inside the containing op. | |||
450 | Operation *fusedOp = nullptr; | |||
451 | OpOperand *use = uses.front(); | |||
452 | // Parallel insert slice is not a valid clone destination. | |||
453 | // TODO: Generalize to other type of ops. | |||
454 | assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp >(use->getOwner()) && "Parallel insert slice is not a valid clone destination" ) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 455, __extension__ __PRETTY_FUNCTION__)) | |||
455 | "Parallel insert slice is not a valid clone destination")(static_cast <bool> (!isa<tensor::ParallelInsertSliceOp >(use->getOwner()) && "Parallel insert slice is not a valid clone destination" ) ? void (0) : __assert_fail ("!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) && \"Parallel insert slice is not a valid clone destination\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 455, __extension__ __PRETTY_FUNCTION__)); | |||
456 | unsigned resultNumber = use->get().cast<OpResult>().getResultNumber(); | |||
457 | LLVM_DEBUG(llvm::dbgs() << "resultNumber: " << resultNumber << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "resultNumber: " << resultNumber << "\n"; } } while (false); | |||
458 | ||||
459 | OpBuilder::InsertionGuard guard(rewriter); | |||
460 | rewriter.setInsertionPoint(use->getOwner()); | |||
461 | fusedOp = rewriter.clone(*producerOp); | |||
462 | rewriter.updateRootInPlace( | |||
463 | use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); }); | |||
464 | ||||
465 | return fusedOp; | |||
466 | } | |||
467 | ||||
468 | DiagnosedSilenceableFailure | |||
469 | transform::FuseIntoContainingOp::apply(transform::TransformResults &results, | |||
470 | transform::TransformState &state) { | |||
471 | SmallVector<Operation *> fusedOps; | |||
472 | ArrayRef<Operation *> producerOps = state.getPayloadOps(getProducerOp()); | |||
473 | // If nothing to fuse, propagate success. | |||
474 | if (producerOps.empty()) { | |||
475 | results.set(getFusedOp().cast<OpResult>(), | |||
476 | SmallVector<mlir::Operation *>{}); | |||
477 | return DiagnosedSilenceableFailure::success(); | |||
478 | } | |||
479 | ArrayRef<Operation *> containingOps = state.getPayloadOps(getContainingOp()); | |||
480 | if (containingOps.size() != 1) { | |||
481 | return emitDefiniteFailure() | |||
482 | << "requires exactly one containing_op handle (got " | |||
483 | << containingOps.size() << ")"; | |||
484 | } | |||
485 | Operation *containingOp = containingOps.front(); | |||
486 | ||||
487 | // Helper function to find the next producer that should be fused. Take any | |||
488 | // producer that has a use inside the containing op. | |||
489 | SmallVector<Operation *> remainingProducers(producerOps.begin(), | |||
490 | producerOps.end()); | |||
491 | auto getNextProducer = [&]() -> FailureOr<Operation *> { | |||
492 | for (const auto &it : enumerate(remainingProducers)) { | |||
493 | Operation *producerOp = it.value(); | |||
494 | // The containing op may be a user of producerOp: use isAncestor. | |||
495 | int64_t numUsesInContainingOp = | |||
496 | llvm::count_if(producerOp->getUsers(), [&](Operation *op) { | |||
497 | return containingOp->isAncestor(op); | |||
498 | }); | |||
499 | // TODO: When resolving the TODO below (no duplicate ops), take an op | |||
500 | // that has no use among the remaining producers. This is a topological | |||
501 | // sorting. | |||
502 | if (numUsesInContainingOp > 0) { | |||
503 | if (numUsesInContainingOp == 1) | |||
504 | remainingProducers.erase(remainingProducers.begin() + it.index()); | |||
505 | return producerOp; | |||
506 | } | |||
507 | } | |||
508 | return failure(); | |||
509 | }; | |||
510 | ||||
511 | IRRewriter rewriter(getContext()); | |||
512 | while (!remainingProducers.empty()) { | |||
513 | auto nextProducer = getNextProducer(); | |||
514 | if (failed(nextProducer)) { | |||
515 | results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>()); | |||
516 | Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark); | |||
517 | diag << "could not find next producer to fuse into container"; | |||
518 | return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); | |||
519 | } | |||
520 | ||||
521 | Operation *producerOp = *nextProducer; | |||
522 | ||||
523 | // Default diagnostic, to be complemented with more failure information. | |||
524 | Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark); | |||
525 | diag << "could not fuse " << *producerOp << " into " << *containingOp; | |||
526 | ||||
527 | // TODO: If there are multiple uses of the producer in the containing op, | |||
528 | // we currently tile/clone the op multiple times (once per use). In some | |||
529 | // cases, we can tile/clone once and reuse the value for each use. | |||
530 | // Futhermore, producers should then be traversed according to a | |||
531 | // topological sorting. | |||
532 | Operation *tiled = | |||
533 | tileAndFuseFirstExtractUse(rewriter, diag, producerOp, containingOp); | |||
534 | if (tiled) { | |||
535 | LLVM_DEBUG(llvm::dbgs() << "\nFused a direct extract use\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused a direct extract use\n" << *containingOp; } } while (false) | |||
536 | << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused a direct extract use\n" << *containingOp; } } while (false); | |||
537 | fusedOps.push_back(tiled); | |||
538 | continue; | |||
539 | } | |||
540 | ||||
541 | Operation *tiledContainingOpOperand = | |||
542 | tileAndFuseFirstExtractUseThroughContainingOpBlockArgument( | |||
543 | rewriter, diag, producerOp, containingOp); | |||
544 | if (tiledContainingOpOperand) { | |||
545 | LLVM_DEBUG(llvm::dbgs()do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n" << *containingOp; } } while (false) | |||
546 | << "\nFused an extract use through block argument\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n" << *containingOp; } } while (false) | |||
547 | << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused an extract use through block argument\n" << *containingOp; } } while (false); | |||
548 | fusedOps.push_back(tiledContainingOpOperand); | |||
549 | continue; | |||
550 | } | |||
551 | ||||
552 | Operation *cloned = | |||
553 | cloneAndFuseFirstUse(rewriter, diag, producerOp, containingOp); | |||
554 | if (cloned) { | |||
555 | LLVM_DEBUG(llvm::dbgs() << "\nFused an use by cloning\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused an use by cloning\n" << *containingOp; } } while (false) | |||
556 | << *containingOp)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("linalg-transforms")) { llvm::dbgs() << "\nFused an use by cloning\n" << *containingOp; } } while (false); | |||
557 | fusedOps.push_back(cloned); | |||
558 | continue; | |||
559 | } | |||
560 | results.set(getFusedOp().cast<OpResult>(), ArrayRef<Operation *>()); | |||
561 | return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); | |||
562 | } | |||
563 | ||||
564 | results.set(getFusedOp().cast<OpResult>(), fusedOps); | |||
565 | return DiagnosedSilenceableFailure::success(); | |||
566 | } | |||
567 | ||||
568 | //===----------------------------------------------------------------------===// | |||
569 | // GeneralizeOp | |||
570 | //===----------------------------------------------------------------------===// | |||
571 | ||||
572 | DiagnosedSilenceableFailure | |||
573 | transform::GeneralizeOp::applyToOne(linalg::LinalgOp target, | |||
574 | SmallVectorImpl<Operation *> &results, | |||
575 | transform::TransformState &state) { | |||
576 | // Exit early if no transformation is needed. | |||
577 | if (isa<GenericOp>(target)) { | |||
578 | results.push_back(target); | |||
579 | return DiagnosedSilenceableFailure::success(); | |||
580 | } | |||
581 | FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target); | |||
582 | if (succeeded(generic)) { | |||
583 | results.push_back(generic->getOperation()); | |||
584 | return DiagnosedSilenceableFailure::success(); | |||
585 | } | |||
586 | results.assign(1, nullptr); | |||
587 | return emitDefaultSilenceableFailure(target); | |||
588 | } | |||
589 | ||||
590 | //===----------------------------------------------------------------------===// | |||
591 | // InterchangeOp | |||
592 | //===----------------------------------------------------------------------===// | |||
593 | ||||
594 | DiagnosedSilenceableFailure | |||
595 | transform::InterchangeOp::applyToOne(linalg::GenericOp target, | |||
596 | SmallVectorImpl<Operation *> &results, | |||
597 | transform::TransformState &state) { | |||
598 | ArrayRef<int64_t> interchangeVector = getIteratorInterchange(); | |||
599 | // Exit early if no transformation is needed. | |||
600 | if (interchangeVector.empty()) { | |||
601 | results.push_back(target); | |||
602 | return DiagnosedSilenceableFailure::success(); | |||
603 | } | |||
604 | TrivialPatternRewriter rewriter(target->getContext()); | |||
605 | FailureOr<GenericOp> res = | |||
606 | interchangeGenericOp(rewriter, target, | |||
607 | SmallVector<unsigned>(interchangeVector.begin(), | |||
608 | interchangeVector.end())); | |||
609 | if (failed(res)) | |||
610 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
611 | results.push_back(res->getOperation()); | |||
612 | return DiagnosedSilenceableFailure::success(); | |||
613 | } | |||
614 | ||||
615 | LogicalResult transform::InterchangeOp::verify() { | |||
616 | ArrayRef<int64_t> permutation = getIteratorInterchange(); | |||
617 | auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, permutation.size())); | |||
618 | if (!std::is_permutation(sequence.begin(), sequence.end(), | |||
619 | permutation.begin(), permutation.end())) { | |||
620 | return emitOpError() | |||
621 | << "expects iterator_interchange to be a permutation, found " | |||
622 | << getIteratorInterchange(); | |||
623 | } | |||
624 | return success(); | |||
625 | } | |||
626 | ||||
627 | //===---------------------------------------------------------------------===// | |||
628 | // MatchOp | |||
629 | //===---------------------------------------------------------------------===// | |||
630 | ||||
631 | void transform::MatchOp::build(OpBuilder &builder, OperationState &result, | |||
632 | Value target, ArrayRef<StringRef> opNames) { | |||
633 | result.addOperands(target); | |||
634 | result.addAttribute(MatchOp::getOpsAttrName(result.name), | |||
635 | builder.getStrArrayAttr(opNames)); | |||
636 | result.addTypes(pdl::OperationType::get(builder.getContext())); | |||
637 | } | |||
638 | ||||
639 | DiagnosedSilenceableFailure | |||
640 | transform::MatchOp::apply(transform::TransformResults &results, | |||
641 | transform::TransformState &state) { | |||
642 | llvm::StringSet<> strs; | |||
643 | if (getOps().has_value()) | |||
644 | strs.insert(getOps()->getAsValueRange<StringAttr>().begin(), | |||
645 | getOps()->getAsValueRange<StringAttr>().end()); | |||
646 | ||||
647 | ArrayRef<Operation *> payloadOps = state.getPayloadOps(getTarget()); | |||
648 | if (payloadOps.size() != 1) { | |||
649 | results.set(getResult().cast<OpResult>(), {}); | |||
650 | return emitDefiniteFailure("requires exactly one target handle"); | |||
651 | } | |||
652 | ||||
653 | SmallVector<Operation *> res; | |||
654 | auto matchFun = [&](Operation *op) { | |||
655 | if (getOps().has_value() && !strs.contains(op->getName().getStringRef())) | |||
656 | return; | |||
657 | ||||
658 | // Interfaces cannot be matched by name, just by ID. | |||
659 | // So we specifically encode the interfaces we care about for this op. | |||
660 | if (getInterface().has_value()) { | |||
661 | auto iface = getInterface().value(); | |||
662 | if (iface == transform::MatchInterfaceEnum::LinalgOp && | |||
663 | !isa<linalg::LinalgOp>(op)) | |||
664 | return; | |||
665 | if (iface == transform::MatchInterfaceEnum::TilingInterface && | |||
666 | isa<TilingInterface>(op)) | |||
667 | return; | |||
668 | } | |||
669 | ||||
670 | // Check if all specified attributes match. | |||
671 | if (getOpAttrs().has_value()) { | |||
672 | DictionaryAttr opAttrs = getOpAttrs().value(); | |||
673 | for (NamedAttribute attr : opAttrs) { | |||
674 | if (attr.getName() == getInterfaceAttrName() || | |||
675 | attr.getName() == getOpsAttrName()) | |||
676 | continue; | |||
677 | if (!op->hasAttr(attr.getName())) | |||
678 | return; | |||
679 | if (op->getAttr(attr.getName()) != attr.getValue()) | |||
680 | return; | |||
681 | } | |||
682 | } | |||
683 | ||||
684 | if (getFilterResultType().has_value()) { | |||
685 | Type t = getFilterResultType().value(); | |||
686 | if (op->getNumResults() != 1 || op->getResultTypes().front() != t) | |||
687 | return; | |||
688 | } | |||
689 | ||||
690 | // All constraints are satisfied. | |||
691 | res.push_back(op); | |||
692 | return; | |||
693 | }; | |||
694 | ||||
695 | payloadOps.front()->walk(matchFun); | |||
696 | results.set(getResult().cast<OpResult>(), res); | |||
697 | return DiagnosedSilenceableFailure::success(); | |||
698 | } | |||
699 | ||||
700 | //===---------------------------------------------------------------------===// | |||
701 | // MultiTileSizesOp | |||
702 | //===---------------------------------------------------------------------===// | |||
703 | ||||
704 | DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( | |||
705 | LinalgOp target, SmallVector<Operation *> &results, TransformState &state) { | |||
706 | OpBuilder builder(target.getContext()); | |||
707 | builder.setInsertionPoint(target); | |||
708 | OpFoldResult targetSize = builder.getIndexAttr(getTargetSize()); | |||
709 | OpFoldResult divisor = builder.getIndexAttr(getDivisor()); | |||
710 | FailureOr<MultiSizeSpecification> spec = computeMultiTileSizes( | |||
711 | builder, target, getDimension(), targetSize, divisor); | |||
712 | if (failed(spec)) { | |||
713 | return emitSilenceableError() << "could not generate tile size computation"; | |||
714 | } | |||
715 | ||||
716 | AffineExpr s0 = builder.getAffineSymbolExpr(0); | |||
717 | AffineExpr s1 = builder.getAffineSymbolExpr(1); | |||
718 | Operation *splitPoint = | |||
719 | makeComposedAffineApply(builder, target.getLoc(), s0 * s1, | |||
720 | {spec->lowTileSize, spec->lowTripCount}); | |||
721 | Operation *lowTileSize = spec->lowTileSize.getDefiningOp(); | |||
722 | Operation *highTileSize = spec->highTileSize.getDefiningOp(); | |||
723 | assert(lowTileSize && highTileSize && splitPoint &&(static_cast <bool> (lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations" ) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 724, __extension__ __PRETTY_FUNCTION__)) | |||
724 | "tile sizes are not produced by operations")(static_cast <bool> (lowTileSize && highTileSize && splitPoint && "tile sizes are not produced by operations" ) ? void (0) : __assert_fail ("lowTileSize && highTileSize && splitPoint && \"tile sizes are not produced by operations\"" , "mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp" , 724, __extension__ __PRETTY_FUNCTION__)); | |||
725 | results.reserve(results.size() + 3); | |||
726 | results.push_back(lowTileSize); | |||
727 | results.push_back(highTileSize); | |||
728 | results.push_back(splitPoint); | |||
729 | return DiagnosedSilenceableFailure::success(); | |||
730 | } | |||
731 | ||||
732 | void transform::MultiTileSizesOp::getEffects( | |||
733 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
734 | onlyReadsHandle(getTarget(), effects); | |||
735 | producesHandle(getResults(), effects); | |||
736 | modifiesPayload(effects); | |||
737 | } | |||
738 | ||||
739 | //===---------------------------------------------------------------------===// | |||
740 | // PadOp | |||
741 | //===---------------------------------------------------------------------===// | |||
742 | ||||
743 | DiagnosedSilenceableFailure | |||
744 | transform::PadOp::applyToOne(linalg::LinalgOp target, | |||
745 | SmallVectorImpl<Operation *> &results, | |||
746 | transform::TransformState &state) { | |||
747 | // Convert the integer packing flags to booleans. | |||
748 | SmallVector<bool> packPaddings; | |||
749 | for (int64_t packPadding : extractFromI64ArrayAttr(getPackPaddings())) | |||
750 | packPaddings.push_back(static_cast<bool>(packPadding)); | |||
751 | ||||
752 | // Convert the padding values to attributes. | |||
753 | SmallVector<Attribute> paddingValues; | |||
754 | for (auto const &it : | |||
755 | llvm::zip(getPaddingValues(), target->getOperandTypes())) { | |||
756 | auto attr = std::get<0>(it).dyn_cast<TypedAttr>(); | |||
757 | if (!attr) { | |||
758 | emitOpError("expects padding values to be typed attributes"); | |||
759 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
760 | } | |||
761 | Type elementType = getElementTypeOrSelf(std::get<1>(it)); | |||
762 | // Try to parse string attributes to obtain an attribute of element type. | |||
763 | if (auto stringAttr = attr.dyn_cast<StringAttr>()) { | |||
764 | paddingValues.push_back( | |||
765 | parseAttribute(attr.cast<StringAttr>(), elementType)); | |||
766 | if (!paddingValues.back()) { | |||
767 | auto diag = this->emitOpError("expects a padding that parses to ") | |||
768 | << elementType << ", got " << std::get<0>(it); | |||
769 | diag.attachNote(target.getLoc()) << "when applied to this op"; | |||
770 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
771 | } | |||
772 | continue; | |||
773 | } | |||
774 | // Otherwise, add the attribute directly. | |||
775 | if (attr.getType() != elementType) { | |||
776 | auto diag = this->emitOpError("expects a padding value of type ") | |||
777 | << elementType << ", got " << attr; | |||
778 | diag.attachNote(target.getLoc()) << "when applied to this op"; | |||
779 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
780 | } | |||
781 | paddingValues.push_back(attr); | |||
782 | } | |||
783 | ||||
784 | // Extract the transpose vectors. | |||
785 | SmallVector<SmallVector<int64_t>> transposePaddings; | |||
786 | for (Attribute transposeVector : getTransposePaddings().cast<ArrayAttr>()) | |||
787 | transposePaddings.push_back( | |||
788 | extractFromI64ArrayAttr(transposeVector.cast<ArrayAttr>())); | |||
789 | ||||
790 | LinalgPaddingOptions paddingOptions; | |||
791 | paddingOptions.setPaddingValues(paddingValues); | |||
792 | paddingOptions.setPaddingDimensions( | |||
793 | extractFromI64ArrayAttr(getPaddingDimensions())); | |||
794 | paddingOptions.setPackPaddings(packPaddings); | |||
795 | paddingOptions.setHoistPaddings(extractFromI64ArrayAttr(getHoistPaddings())); | |||
796 | paddingOptions.setTransposePaddings(transposePaddings); | |||
797 | ||||
798 | FailureOr<LinalgOp> result = | |||
799 | tryApply<LinalgPaddingPattern>(target, paddingOptions); | |||
800 | if (succeeded(result)) { | |||
801 | results.push_back(result->getOperation()); | |||
802 | return DiagnosedSilenceableFailure::success(); | |||
803 | } | |||
804 | ||||
805 | results.assign(1, nullptr); | |||
806 | return emitDefaultSilenceableFailure(target); | |||
807 | } | |||
808 | ||||
809 | LogicalResult transform::PadOp::verify() { | |||
810 | SmallVector<int64_t> packPaddings = | |||
811 | extractFromI64ArrayAttr(getPackPaddings()); | |||
812 | if (any_of(packPaddings, [](int64_t packPadding) { | |||
813 | return packPadding != 0 && packPadding != 1; | |||
814 | })) { | |||
815 | return emitOpError() | |||
816 | << "expects pack_paddings to contain booleans (0/1), found " | |||
817 | << getPackPaddings(); | |||
818 | } | |||
819 | ||||
820 | SmallVector<int64_t> paddingDimensions = | |||
821 | extractFromI64ArrayAttr(getPaddingDimensions()); | |||
822 | if (any_of(paddingDimensions, | |||
823 | [](int64_t paddingDimension) { return paddingDimension < 0; })) { | |||
824 | return emitOpError() << "expects padding_dimensions to contain positive " | |||
825 | "integers, found " | |||
826 | << getPaddingDimensions(); | |||
827 | } | |||
828 | ||||
829 | SmallVector<int64_t> hoistPaddings = | |||
830 | extractFromI64ArrayAttr(getHoistPaddings()); | |||
831 | if (any_of(hoistPaddings, | |||
832 | [](int64_t hoistPadding) { return hoistPadding < 0; })) { | |||
833 | return emitOpError() | |||
834 | << "expects hoist_paddings to contain positive integers, found " | |||
835 | << getHoistPaddings(); | |||
836 | } | |||
837 | ||||
838 | ArrayAttr transposes = getTransposePaddings(); | |||
839 | for (Attribute attr : transposes) { | |||
840 | SmallVector<int64_t> transpose = extractFromI64ArrayAttr(attr); | |||
841 | auto sequence = llvm::to_vector(llvm::seq<int64_t>(0, transpose.size())); | |||
842 | if (!std::is_permutation(sequence.begin(), sequence.end(), | |||
843 | transpose.begin(), transpose.end())) { | |||
844 | return emitOpError() | |||
845 | << "expects transpose_paddings to be a permutation, found " | |||
846 | << attr; | |||
847 | } | |||
848 | } | |||
849 | return success(); | |||
850 | } | |||
851 | ||||
852 | //===----------------------------------------------------------------------===// | |||
853 | // PromoteOp | |||
854 | //===----------------------------------------------------------------------===// | |||
855 | ||||
856 | DiagnosedSilenceableFailure | |||
857 | transform::PromoteOp::applyToOne(linalg::LinalgOp target, | |||
858 | SmallVectorImpl<Operation *> &results, | |||
859 | transform::TransformState &state) { | |||
860 | LinalgPromotionOptions promotionOptions; | |||
861 | if (!getOperandsToPromote().empty()) | |||
862 | promotionOptions = promotionOptions.setOperandsToPromote( | |||
863 | extractFromI64ArrayAttr(getOperandsToPromote())); | |||
864 | if (getUseFullTilesByDefault()) | |||
865 | promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( | |||
866 | getUseFullTilesByDefault()); | |||
867 | if (getUseAlloca()) | |||
868 | promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); | |||
869 | if (!getUseFullTileBuffers().empty()) | |||
870 | promotionOptions = promotionOptions.setUseFullTileBuffers( | |||
871 | llvm::to_vector(getUseFullTileBuffers().getAsValueRange<BoolAttr>())); | |||
872 | if (getAlignment().has_value()) | |||
873 | promotionOptions = promotionOptions.setAlignment(*getAlignment()); | |||
874 | ||||
875 | if (failed(promoteSubviewsPrecondition(target, promotionOptions))) | |||
876 | return emitDefaultDefiniteFailure(target); | |||
877 | ||||
878 | TrivialPatternRewriter rewriter(target->getContext()); | |||
879 | rewriter.setInsertionPoint(target); | |||
880 | FailureOr<LinalgOp> res = promoteSubViews(rewriter, target, promotionOptions); | |||
881 | if (failed(res)) | |||
882 | return emitDefaultDefiniteFailure(target); | |||
883 | results.push_back(target); | |||
884 | return DiagnosedSilenceableFailure::success(); | |||
885 | } | |||
886 | ||||
887 | //===----------------------------------------------------------------------===// | |||
888 | // ReplaceOp | |||
889 | //===----------------------------------------------------------------------===// | |||
890 | ||||
891 | DiagnosedSilenceableFailure | |||
892 | transform::ReplaceOp::apply(TransformResults &transformResults, | |||
893 | TransformState &state) { | |||
894 | ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); | |||
895 | ||||
896 | // Check for invalid targets. | |||
897 | for (Operation *target : payload) { | |||
898 | if (target->getNumOperands() > 0) | |||
899 | return emitDefiniteFailure() << "expected target without operands"; | |||
900 | if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>() && | |||
901 | target->getNumRegions() > 0) | |||
902 | return emitDefiniteFailure() | |||
903 | << "expected target that is isloated from above"; | |||
904 | } | |||
905 | ||||
906 | // Clone and replace. | |||
907 | IRRewriter rewriter(getContext()); | |||
908 | Operation *pattern = &getBodyRegion().front().front(); | |||
909 | SmallVector<Operation *> replacements; | |||
910 | for (Operation *target : payload) { | |||
911 | if (getOperation()->isAncestor(target)) | |||
912 | continue; | |||
913 | rewriter.setInsertionPoint(target); | |||
914 | Operation *replacement = rewriter.clone(*pattern); | |||
915 | rewriter.replaceOp(target, replacement->getResults()); | |||
916 | replacements.push_back(replacement); | |||
917 | } | |||
918 | transformResults.set(getReplacement().cast<OpResult>(), replacements); | |||
919 | return DiagnosedSilenceableFailure::success(); | |||
920 | } | |||
921 | ||||
922 | void transform::ReplaceOp::getEffects( | |||
923 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
924 | consumesHandle(getTarget(), effects); | |||
925 | producesHandle(getReplacement(), effects); | |||
926 | modifiesPayload(effects); | |||
927 | } | |||
928 | ||||
929 | LogicalResult transform::ReplaceOp::verify() { | |||
930 | if (!getBodyRegion().hasOneBlock()) | |||
931 | return emitOpError() << "expected one block"; | |||
932 | if (std::distance(getBodyRegion().front().begin(), | |||
933 | getBodyRegion().front().end()) != 1) | |||
934 | return emitOpError() << "expected one operation in block"; | |||
935 | Operation *replacement = &getBodyRegion().front().front(); | |||
936 | if (replacement->getNumOperands() > 0) | |||
937 | return replacement->emitOpError() | |||
938 | << "expected replacement without operands"; | |||
939 | if (!replacement->hasTrait<OpTrait::IsIsolatedFromAbove>() && | |||
940 | replacement->getNumRegions() > 0) | |||
941 | return replacement->emitOpError() | |||
942 | << "expect op that is isolated from above"; | |||
943 | return success(); | |||
944 | } | |||
945 | ||||
946 | //===----------------------------------------------------------------------===// | |||
947 | // ScalarizeOp | |||
948 | //===----------------------------------------------------------------------===// | |||
949 | ||||
950 | DiagnosedSilenceableFailure | |||
951 | transform::ScalarizeOp::applyToOne(linalg::LinalgOp target, | |||
952 | SmallVectorImpl<Operation *> &results, | |||
953 | transform::TransformState &state) { | |||
954 | scf::SCFTilingOptions tilingOptions; | |||
955 | tilingOptions.setTileSizeComputationFunction([&](OpBuilder &b, Operation *) { | |||
956 | SmallVector<Value, 4> tileSizes; | |||
957 | Location loc = target.getLoc(); | |||
958 | SmallVector<OpFoldResult> allShapeSizes = | |||
959 | target.createFlatListOfOperandDims(b, loc); | |||
960 | AffineMap map = target.getShapesToLoopsMap(); | |||
961 | if (!map) | |||
962 | return tileSizes; | |||
963 | IRRewriter rewriter(b); | |||
964 | SmallVector<OpFoldResult> shapeSizes = | |||
965 | makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, | |||
966 | allShapeSizes); | |||
967 | // If the shape size is dynamic, tile by 1. | |||
968 | // Otherwise, do not tile (i.e. tile size 0). | |||
969 | for (OpFoldResult shapeSize : shapeSizes) { | |||
970 | tileSizes.push_back(getConstantIntValue(shapeSize) | |||
971 | ? b.create<arith::ConstantIndexOp>(loc, 0) | |||
972 | : b.create<arith::ConstantIndexOp>(loc, 1)); | |||
973 | } | |||
974 | return tileSizes; | |||
975 | }); | |||
976 | SmallVector<int64_t> emptyTileSizes; | |||
977 | TrivialPatternRewriter rewriter(getContext()); | |||
978 | rewriter.setInsertionPoint(target); | |||
979 | FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp( | |||
980 | rewriter, cast<TilingInterface>(target.getOperation()), tilingOptions); | |||
981 | if (failed(maybeTilingResult)) | |||
982 | return emitDefaultDefiniteFailure(target); | |||
983 | ||||
984 | results.append(maybeTilingResult->tiledOps); | |||
985 | return DiagnosedSilenceableFailure::success(); | |||
986 | } | |||
987 | ||||
988 | //===----------------------------------------------------------------------===// | |||
989 | // SplitOp | |||
990 | //===----------------------------------------------------------------------===// | |||
991 | ||||
992 | DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, | |||
993 | TransformState &state) { | |||
994 | // Collect the dynamic split points if provided. | |||
995 | ArrayRef<Operation *> payload = state.getPayloadOps(getTarget()); | |||
996 | TrivialPatternRewriter rewriter(getContext()); | |||
997 | SmallVector<OpFoldResult> splitPoints; | |||
998 | splitPoints.reserve(payload.size()); | |||
999 | if (getDynamicSplitPoint()) { | |||
1000 | auto diag = DiagnosedSilenceableFailure::success(); | |||
1001 | splitPoints = llvm::to_vector(llvm::map_range( | |||
1002 | state.getPayloadOps(getDynamicSplitPoint()), [&](Operation *op) { | |||
1003 | if (op->getNumResults() != 1 || | |||
1004 | !op->getResult(0).getType().isIndex()) { | |||
1005 | diag = emitSilenceableError() | |||
1006 | << "expected dynamic split point handle to point to a " | |||
1007 | "single-result index-typed op"; | |||
1008 | diag.attachNote(op->getLoc()) << "dynamic split point"; | |||
1009 | } | |||
1010 | return OpFoldResult(op->getResult(0)); | |||
1011 | })); | |||
1012 | if (diag.isSilenceableFailure()) { | |||
1013 | results.set(getFirst().cast<OpResult>(), {}); | |||
1014 | results.set(getSecond().cast<OpResult>(), {}); | |||
1015 | return diag; | |||
1016 | } | |||
1017 | ||||
1018 | if (splitPoints.size() != payload.size()) { | |||
1019 | return emitDefiniteFailure() | |||
1020 | << "expected the dynamic split point handle to point to as " | |||
1021 | "many operations (" | |||
1022 | << splitPoints.size() << ") as the target handle (" | |||
1023 | << payload.size() << ")"; | |||
1024 | } | |||
1025 | } else { | |||
1026 | splitPoints.resize(payload.size(), | |||
1027 | rewriter.getIndexAttr(getStaticSplitPoint())); | |||
1028 | } | |||
1029 | ||||
1030 | // Split each target operation. | |||
1031 | SmallVector<Operation *> first, second; | |||
1032 | for (const auto &pair : llvm::zip(payload, splitPoints)) { | |||
1033 | Operation *target = std::get<0>(pair); | |||
1034 | auto linalgOp = dyn_cast<LinalgOp>(target); | |||
1035 | if (!linalgOp) { | |||
1036 | auto diag = emitSilenceableError() << "only applies to structured ops"; | |||
1037 | diag.attachNote(target->getLoc()) << "target op"; | |||
1038 | results.set(getFirst().cast<OpResult>(), {}); | |||
1039 | results.set(getSecond().cast<OpResult>(), {}); | |||
1040 | return diag; | |||
1041 | } | |||
1042 | ||||
1043 | if (getDimension() >= linalgOp.getNumLoops()) { | |||
1044 | auto diag = emitSilenceableError() << "dimension " << getDimension() | |||
1045 | << " does not exist in target op"; | |||
1046 | diag.attachNote(target->getLoc()) << "target op"; | |||
1047 | results.set(getFirst().cast<OpResult>(), {}); | |||
1048 | results.set(getSecond().cast<OpResult>(), {}); | |||
1049 | return diag; | |||
1050 | } | |||
1051 | ||||
1052 | rewriter.setInsertionPoint(linalgOp); | |||
1053 | std::tie(first.emplace_back(), second.emplace_back()) = linalg::splitOp( | |||
1054 | rewriter, cast<TilingInterface>(linalgOp.getOperation()), | |||
1055 | getDimension(), std::get<1>(pair)); | |||
1056 | } | |||
1057 | ||||
1058 | results.set(getFirst().cast<OpResult>(), first); | |||
1059 | results.set(getSecond().cast<OpResult>(), second); | |||
1060 | return DiagnosedSilenceableFailure::success(); | |||
1061 | } | |||
1062 | ||||
1063 | void SplitOp::getEffects( | |||
1064 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
1065 | consumesHandle(getTarget(), effects); | |||
1066 | if (getDynamicSplitPoint()) | |||
1067 | onlyReadsHandle(getDynamicSplitPoint(), effects); | |||
1068 | producesHandle(getResults(), effects); | |||
1069 | modifiesPayload(effects); | |||
1070 | } | |||
1071 | ||||
1072 | ParseResult SplitOp::parse(OpAsmParser &parser, OperationState &result) { | |||
1073 | OpAsmParser::UnresolvedOperand target, dynamicSplitPoint; | |||
1074 | IntegerAttr staticSplitPoint; | |||
1075 | auto pdlOperationType = | |||
1076 | pdl::OperationType::get(parser.getBuilder().getContext()); | |||
1077 | if (parser.parseOperand(target) || | |||
| ||||
1078 | parser.resolveOperand(target, pdlOperationType, result.operands) || | |||
1079 | parser.parseKeyword("after")) | |||
1080 | return failure(); | |||
1081 | ||||
1082 | OptionalParseResult dynamicPointParseResult = | |||
1083 | parser.parseOptionalOperand(dynamicSplitPoint); | |||
1084 | if (!dynamicPointParseResult.has_value()) { | |||
1085 | int64_t staticSplitPointValue; | |||
1086 | if (failed(parser.parseInteger(staticSplitPointValue))) | |||
1087 | return failure(); | |||
1088 | ||||
1089 | staticSplitPoint = | |||
1090 | parser.getBuilder().getI64IntegerAttr(staticSplitPointValue); | |||
| ||||
1091 | } else { | |||
1092 | if (failed(*dynamicPointParseResult) || | |||
1093 | parser.resolveOperand(dynamicSplitPoint, pdlOperationType, | |||
1094 | result.operands)) { | |||
1095 | return failure(); | |||
1096 | } | |||
1097 | ||||
1098 | staticSplitPoint = | |||
1099 | parser.getBuilder().getI64IntegerAttr(ShapedType::kDynamic); | |||
1100 | } | |||
1101 | ||||
1102 | result.addAttribute( | |||
1103 | SplitOp::getStaticSplitPointAttrName(result.name).getValue(), | |||
1104 | staticSplitPoint); | |||
1105 | if (failed(parser.parseOptionalAttrDict(result.attributes))) | |||
1106 | return failure(); | |||
1107 | ||||
1108 | result.addTypes({pdlOperationType, pdlOperationType}); | |||
1109 | return success(); | |||
1110 | } | |||
1111 | ||||
1112 | void SplitOp::print(OpAsmPrinter &printer) { | |||
1113 | printer << " " << getTarget() << " after "; | |||
1114 | int64_t staticSplitSize = static_cast<int64_t>(getStaticSplitPoint()); | |||
1115 | if (staticSplitSize != ShapedType::kDynamic) | |||
1116 | printer << staticSplitSize; | |||
1117 | else | |||
1118 | printer << getDynamicSplitPoint(); | |||
1119 | printer << " "; | |||
1120 | printer.printOptionalAttrDict(getOperation()->getAttrs(), | |||
1121 | {getStaticSplitPointAttrName()}); | |||
1122 | } | |||
1123 | ||||
1124 | LogicalResult SplitOp::verify() { | |||
1125 | if ((static_cast<int64_t>(getStaticSplitPoint()) != ShapedType::kDynamic) ^ | |||
1126 | (getDynamicSplitPoint() == nullptr)) { | |||
1127 | return emitOpError() << "expects either a dynamic or a static split " | |||
1128 | "point to be provided"; | |||
1129 | } | |||
1130 | return success(); | |||
1131 | } | |||
1132 | ||||
1133 | //===----------------------------------------------------------------------===// | |||
1134 | // SplitReductionOp | |||
1135 | //===----------------------------------------------------------------------===// | |||
1136 | ||||
1137 | void transform::SplitReductionOp::build( | |||
1138 | OpBuilder &builder, OperationState &result, Value target, | |||
1139 | int64_t splitFactor, int64_t insertSplitDimension, bool innerParallel, | |||
1140 | bool useScalingAlgorithm, bool useAlloc) { | |||
1141 | MLIRContext *ctx = builder.getContext(); | |||
1142 | result.addOperands(target); | |||
1143 | result.addAttribute(SplitReductionOp::getSplitFactorAttrName(result.name), | |||
1144 | builder.getI64IntegerAttr(splitFactor)); | |||
1145 | result.addAttribute( | |||
1146 | SplitReductionOp::getInsertSplitDimensionAttrName(result.name), | |||
1147 | builder.getI64IntegerAttr(insertSplitDimension)); | |||
1148 | if (innerParallel) { | |||
1149 | result.addAttribute(SplitReductionOp::getInnerParallelAttrName(result.name), | |||
1150 | builder.getUnitAttr()); | |||
1151 | } | |||
1152 | if (useScalingAlgorithm) { | |||
1153 | result.addAttribute( | |||
1154 | SplitReductionOp::getUseScalingAlgorithmAttrName(result.name), | |||
1155 | builder.getUnitAttr()); | |||
1156 | } | |||
1157 | if (useAlloc) { | |||
1158 | result.addAttribute(SplitReductionOp::getUseAllocAttrName(result.name), | |||
1159 | builder.getUnitAttr()); | |||
1160 | } | |||
1161 | auto resultType = pdl::OperationType::get(ctx); | |||
1162 | result.addTypes({resultType, resultType, resultType, resultType}); | |||
1163 | } | |||
1164 | ||||
1165 | DiagnosedSilenceableFailure | |||
1166 | transform::SplitReductionOp::applyToOne(linalg::LinalgOp target, | |||
1167 | SmallVectorImpl<Operation *> &results, | |||
1168 | transform::TransformState &state) { | |||
1169 | ControlSplitReductionFn splitFn = [&](LinalgOp) { | |||
1170 | return linalg::SplitReductionOptions{int64_t(getSplitFactor()), | |||
1171 | unsigned(getInsertSplitDimension()), | |||
1172 | bool(getInnerParallel())}; | |||
1173 | }; | |||
1174 | TrivialPatternRewriter rewriter(getContext()); | |||
1175 | rewriter.setInsertionPoint(target); | |||
1176 | FailureOr<SplitReductionResult> splitResult = | |||
1177 | (getUseScalingAlgorithm()) | |||
1178 | ? splitReductionByScaling(rewriter, target, splitFn, getUseAlloc()) | |||
1179 | : splitReduction(rewriter, target, splitFn, getUseAlloc()); | |||
1180 | if (failed(splitResult)) | |||
1181 | return emitDefaultDefiniteFailure(target); | |||
1182 | ||||
1183 | results.push_back(splitResult->initOrAlloc); | |||
1184 | results.push_back(splitResult->fillOp); | |||
1185 | results.push_back(splitResult->splitLinalgOp); | |||
1186 | results.push_back(splitResult->resultCombiningLinalgOp); | |||
1187 | return DiagnosedSilenceableFailure::success(); | |||
1188 | } | |||
1189 | ||||
1190 | //===----------------------------------------------------------------------===// | |||
1191 | // TileReductionUsingScfOp | |||
1192 | //===----------------------------------------------------------------------===// | |||
1193 | ||||
1194 | void transform::TileReductionUsingScfOp::build( | |||
1195 | OpBuilder &builder, OperationState &result, Value target, | |||
1196 | ArrayRef<int64_t> staticTileSizes) { | |||
1197 | // Call the default builder. | |||
1198 | // This is future-proof re mixed static-dynamic and setting up the proper | |||
1199 | // operands segment sizes attributes for multiple variadic operands. | |||
1200 | // In the absence of this, horrible bugs ensue. | |||
1201 | // TODO: support mixed static-dynamic (see TileToForeachThreadOp). | |||
1202 | MLIRContext *ctx = builder.getContext(); | |||
1203 | auto opTy = pdl::OperationType::get(ctx); | |||
1204 | auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); | |||
1205 | build(builder, result, | |||
1206 | /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, | |||
1207 | /*target=*/target, | |||
1208 | /*tile_sizes=*/staticTileSizesAttr); | |||
1209 | } | |||
1210 | ||||
1211 | DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( | |||
1212 | linalg::LinalgOp target, SmallVectorImpl<Operation *> &results, | |||
1213 | transform::TransformState &state) { | |||
1214 | TrivialPatternRewriter rewriter(getContext()); | |||
1215 | rewriter.setInsertionPoint(target); | |||
1216 | FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf( | |||
1217 | rewriter, cast<PartialReductionOpInterface>(target.getOperation()), | |||
1218 | getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); | |||
1219 | ||||
1220 | if (failed(result)) | |||
1221 | return emitDefaultSilenceableFailure(target); | |||
1222 | results.push_back(result->loops.front()); | |||
1223 | results.push_back(result->initialOp); | |||
1224 | results.push_back(result->parallelTiledOp); | |||
1225 | results.push_back(result->mergeOp); | |||
1226 | return DiagnosedSilenceableFailure::success(); | |||
1227 | } | |||
1228 | ||||
1229 | //===----------------------------------------------------------------------===// | |||
1230 | // TileReductionUsingForeachThreadOp | |||
1231 | //===----------------------------------------------------------------------===// | |||
1232 | ||||
1233 | void transform::TileReductionUsingForeachThreadOp::build( | |||
1234 | OpBuilder &builder, OperationState &result, Value target, | |||
1235 | ArrayRef<int64_t> staticNumThreads, ArrayRef<int64_t> staticTileSizes, | |||
1236 | ArrayAttr mapping) { | |||
1237 | // Call the default builder. | |||
1238 | // This is future-proof re mixed static-dynamic and setting up the proper | |||
1239 | // operands segment sizes attributes for multiple variadic operands. | |||
1240 | // In the absence of this, horrible bugs ensue. | |||
1241 | // TODO: support mixed static-dynamic (see TileToForeachThreadOp). | |||
1242 | MLIRContext *ctx = builder.getContext(); | |||
1243 | auto opTy = pdl::OperationType::get(ctx); | |||
1244 | auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); | |||
1245 | auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); | |||
1246 | build(builder, result, | |||
1247 | /*resultTypes=*/TypeRange{opTy, opTy, opTy, opTy}, | |||
1248 | /*target=*/target, | |||
1249 | /*num_threads=*/staticNumThreadsAttr, | |||
1250 | /*tile_sizes=*/staticTileSizesAttr, | |||
1251 | /*mapping=*/mapping); | |||
1252 | } | |||
1253 | ||||
1254 | DiagnosedSilenceableFailure | |||
1255 | transform::TileReductionUsingForeachThreadOp::applyToOne( | |||
1256 | linalg::LinalgOp target, SmallVectorImpl<Operation *> &results, | |||
1257 | transform::TransformState &state) { | |||
1258 | TrivialPatternRewriter rewriter(getContext()); | |||
1259 | rewriter.setInsertionPoint(target); | |||
1260 | SmallVector<OpFoldResult> numThreads = | |||
1261 | getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); | |||
1262 | SmallVector<OpFoldResult> tileSizes = | |||
1263 | getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())); | |||
1264 | FailureOr<linalg::ForeachThreadReductionTilingResult> result = | |||
1265 | linalg::tileReductionUsingForeachThread( | |||
1266 | rewriter, cast<PartialReductionOpInterface>(target.getOperation()), | |||
1267 | numThreads, tileSizes, getMapping()); | |||
1268 | ||||
1269 | if (failed(result)) { | |||
1270 | results.assign(3, nullptr); | |||
1271 | Diagnostic diag(target->getLoc(), DiagnosticSeverity::Remark); | |||
1272 | diag << "could not tile reduction in target."; | |||
1273 | return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); | |||
1274 | } | |||
1275 | results.push_back(result->loops); | |||
1276 | results.push_back(result->initialOp); | |||
1277 | results.push_back(result->parallelTiledOp); | |||
1278 | results.push_back(result->mergeOp); | |||
1279 | return DiagnosedSilenceableFailure::success(); | |||
1280 | } | |||
1281 | ||||
1282 | //===----------------------------------------------------------------------===// | |||
1283 | // TileOp | |||
1284 | //===----------------------------------------------------------------------===// | |||
1285 | void transform::TileOp::build(OpBuilder &builder, OperationState &result, | |||
1286 | Value target, ArrayRef<int64_t> staticTileSizes, | |||
1287 | ArrayRef<int64_t> interchange) { | |||
1288 | return build(builder, result, | |||
1289 | /*target=*/target, | |||
1290 | /*mixedTileSizes=*/ | |||
1291 | getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), | |||
1292 | interchange); | |||
1293 | } | |||
1294 | ||||
1295 | void transform::TileOp::build(OpBuilder &builder, OperationState &result, | |||
1296 | Value target, | |||
1297 | ArrayRef<OpFoldResult> mixedTileSizes, | |||
1298 | ArrayRef<int64_t> interchange) { | |||
1299 | SmallVector<int64_t> staticTileSizes; | |||
1300 | SmallVector<Value> dynamicTileSizes; | |||
1301 | dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); | |||
1302 | // Call the default builder which sets up the proper operands segment sizes | |||
1303 | // attributes for multiple variadic operands. In the absence of this, horrible | |||
1304 | // bugs ensue. | |||
1305 | MLIRContext *ctx = builder.getContext(); | |||
1306 | auto operationType = pdl::OperationType::get(ctx); | |||
1307 | auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); | |||
1308 | build(builder, result, | |||
1309 | /*resultTypes=*/TypeRange{operationType, operationType}, | |||
1310 | /*target=*/target, | |||
1311 | /*dynamic_sizes=*/dynamicTileSizes, | |||
1312 | /*static_sizes=*/staticTileSizesAttr, | |||
1313 | /*interchange=*/builder.getDenseI64ArrayAttr(interchange)); | |||
1314 | } | |||
1315 | ||||
1316 | DiagnosedSilenceableFailure | |||
1317 | transform::TileOp::apply(TransformResults &transformResults, | |||
1318 | TransformState &state) { | |||
1319 | ArrayRef<int64_t> tileSizes = getStaticSizes(); | |||
1320 | ||||
1321 | ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); | |||
1322 | SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; | |||
1323 | dynamicSizeProducers.reserve(getDynamicSizes().size()); | |||
1324 | for (Value dynamicSizeProducerHandle : getDynamicSizes()) { | |||
1325 | dynamicSizeProducers.push_back( | |||
1326 | state.getPayloadOps(dynamicSizeProducerHandle)); | |||
1327 | ||||
1328 | if (dynamicSizeProducers.back().size() != targets.size()) { | |||
1329 | DiagnosedSilenceableFailure diag = | |||
1330 | emitSilenceableError() | |||
1331 | << "expected as many dynamic size-producing operations (" | |||
1332 | << dynamicSizeProducers.back().size() << ") as target ops (" | |||
1333 | << targets.size() << ")"; | |||
1334 | diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; | |||
1335 | return diag; | |||
1336 | } | |||
1337 | ||||
1338 | for (Operation *op : dynamicSizeProducers.back()) { | |||
1339 | if (op->getNumResults() == 1 && | |||
1340 | op->getResult(0).getType().isa<IndexType>()) | |||
1341 | continue; | |||
1342 | DiagnosedSilenceableFailure diag = | |||
1343 | emitSilenceableError() << "expected sizes to be produced by ops " | |||
1344 | "with a single index-type result"; | |||
1345 | diag.attachNote(op->getLoc()) << "size producer op"; | |||
1346 | diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; | |||
1347 | return diag; | |||
1348 | } | |||
1349 | } | |||
1350 | ||||
1351 | SmallVector<Operation *> tiled; | |||
1352 | SmallVector<SmallVector<Operation *, 4>, 4> loops; | |||
1353 | loops.resize(getLoops().size()); | |||
1354 | for (auto &en : llvm::enumerate(targets)) { | |||
1355 | auto linalgOp = dyn_cast<LinalgOp>(en.value()); | |||
1356 | if (!linalgOp) { | |||
1357 | DiagnosedSilenceableFailure diag = emitSilenceableError() | |||
1358 | << "only linalg ops are supported"; | |||
1359 | diag.attachNote(en.value()->getLoc()) << "target op"; | |||
1360 | return diag; | |||
1361 | } | |||
1362 | ||||
1363 | scf::SCFTilingOptions tilingOptions; | |||
1364 | unsigned index = en.index(); | |||
1365 | if (!tileSizes.empty()) { | |||
1366 | tilingOptions.setTileSizeComputationFunction( | |||
1367 | [&, index](OpBuilder &b, Operation *) { | |||
1368 | SmallVector<Value, 4> sizes; | |||
1369 | sizes.reserve(tileSizes.size()); | |||
1370 | unsigned dynamicIdx = 0; | |||
1371 | for (OpFoldResult ofr : getMixedSizes()) { | |||
1372 | if (auto attr = ofr.dyn_cast<Attribute>()) { | |||
1373 | sizes.push_back(b.create<arith::ConstantIndexOp>( | |||
1374 | getLoc(), attr.cast<IntegerAttr>().getInt())); | |||
1375 | } else { | |||
1376 | sizes.push_back( | |||
1377 | dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); | |||
1378 | } | |||
1379 | } | |||
1380 | return sizes; | |||
1381 | }); | |||
1382 | } | |||
1383 | ||||
1384 | tilingOptions.setInterchange(getInterchange()); | |||
1385 | TrivialPatternRewriter rewriter(linalgOp.getContext()); | |||
1386 | FailureOr<scf::SCFTilingResult> maybeTilingResult = tileUsingSCFForOp( | |||
1387 | rewriter, cast<TilingInterface>(linalgOp.getOperation()), | |||
1388 | tilingOptions); | |||
1389 | if (failed(maybeTilingResult)) | |||
1390 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
1391 | ||||
1392 | if (linalgOp.hasBufferSemantics()) | |||
1393 | rewriter.eraseOp(linalgOp); | |||
1394 | else | |||
1395 | rewriter.replaceOp(linalgOp, | |||
1396 | maybeTilingResult->loops.front()->getResults()); | |||
1397 | ||||
1398 | tiled.append(maybeTilingResult->tiledOps); | |||
1399 | for (const auto &en2 : llvm::enumerate(maybeTilingResult->loops)) | |||
1400 | loops[en2.index()].push_back(en2.value()); | |||
1401 | } | |||
1402 | ||||
1403 | transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); | |||
1404 | for (const auto &en : llvm::enumerate(loops)) | |||
1405 | transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); | |||
1406 | ||||
1407 | return DiagnosedSilenceableFailure::success(); | |||
1408 | } | |||
1409 | ||||
1410 | SmallVector<OpFoldResult> transform::TileOp::getMixedSizes() { | |||
1411 | ValueRange dynamic = getDynamicSizes(); | |||
1412 | ArrayRef<int64_t> tileSizes = getStaticSizes(); | |||
1413 | SmallVector<OpFoldResult> results; | |||
1414 | results.reserve(tileSizes.size()); | |||
1415 | unsigned dynamicPos = 0; | |||
1416 | Builder builder(getContext()); | |||
1417 | for (int64_t size : tileSizes) { | |||
1418 | if (size == ShapedType::kDynamic) { | |||
1419 | results.push_back(dynamic[dynamicPos++]); | |||
1420 | } else { | |||
1421 | results.push_back(builder.getIndexAttr(size)); | |||
1422 | } | |||
1423 | } | |||
1424 | return results; | |||
1425 | } | |||
1426 | ||||
1427 | // We want to parse `DenseI64ArrayAttr` using the short form without the | |||
1428 | // `array` prefix to be consistent in the IR with `parseDynamicIndexList`. | |||
1429 | ParseResult parseOptionalInterchange(OpAsmParser &parser, | |||
1430 | OperationState &result) { | |||
1431 | if (succeeded(parser.parseOptionalLBrace())) { | |||
1432 | if (failed(parser.parseKeyword("interchange"))) | |||
1433 | return parser.emitError(parser.getNameLoc()) << "expect `interchange`"; | |||
1434 | if (failed(parser.parseEqual())) | |||
1435 | return parser.emitError(parser.getNameLoc()) << "expect `=`"; | |||
1436 | result.addAttribute("interchange", | |||
1437 | DenseI64ArrayAttr::parse(parser, Type{})); | |||
1438 | if (failed(parser.parseRBrace())) | |||
1439 | return parser.emitError(parser.getNameLoc()) << "expect `}`"; | |||
1440 | } | |||
1441 | return success(); | |||
1442 | } | |||
1443 | ||||
1444 | void printOptionalInterchange(OpAsmPrinter &p, | |||
1445 | ArrayRef<int64_t> interchangeVals) { | |||
1446 | if (!interchangeVals.empty()) { | |||
1447 | p << " {interchange = ["; | |||
1448 | llvm::interleaveComma(interchangeVals, p, | |||
1449 | [&](int64_t integer) { p << integer; }); | |||
1450 | p << "]}"; | |||
1451 | } | |||
1452 | } | |||
1453 | ||||
1454 | ParseResult transform::TileOp::parse(OpAsmParser &parser, | |||
1455 | OperationState &result) { | |||
1456 | OpAsmParser::UnresolvedOperand target; | |||
1457 | SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; | |||
1458 | DenseI64ArrayAttr staticSizes; | |||
1459 | auto pdlOperationType = pdl::OperationType::get(parser.getContext()); | |||
1460 | if (parser.parseOperand(target) || | |||
1461 | parser.resolveOperand(target, pdlOperationType, result.operands) || | |||
1462 | parseDynamicIndexList(parser, dynamicSizes, staticSizes) || | |||
1463 | parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) | |||
1464 | return ParseResult::failure(); | |||
1465 | ||||
1466 | // Parse optional interchange. | |||
1467 | if (failed(parseOptionalInterchange(parser, result))) | |||
1468 | return ParseResult::failure(); | |||
1469 | result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); | |||
1470 | size_t numExpectedLoops = | |||
1471 | staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); | |||
1472 | result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); | |||
1473 | return success(); | |||
1474 | } | |||
1475 | ||||
1476 | void TileOp::print(OpAsmPrinter &p) { | |||
1477 | p << ' ' << getTarget(); | |||
1478 | printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); | |||
1479 | printOptionalInterchange(p, getInterchange()); | |||
1480 | } | |||
1481 | ||||
1482 | void transform::TileOp::getEffects( | |||
1483 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
1484 | consumesHandle(getTarget(), effects); | |||
1485 | onlyReadsHandle(getDynamicSizes(), effects); | |||
1486 | producesHandle(getTiledLinalgOp(), effects); | |||
1487 | producesHandle(getLoops(), effects); | |||
1488 | modifiesPayload(effects); | |||
1489 | } | |||
1490 | ||||
1491 | //===----------------------------------------------------------------------===// | |||
1492 | // TileToForeachThreadOp | |||
1493 | //===----------------------------------------------------------------------===// | |||
1494 | ||||
1495 | void transform::TileToForeachThreadOp::build(OpBuilder &builder, | |||
1496 | OperationState &result, | |||
1497 | Value target, | |||
1498 | ArrayRef<int64_t> staticTileSizes, | |||
1499 | transform::TileSizesSpec, | |||
1500 | ArrayAttr mapping) { | |||
1501 | return build(builder, result, | |||
1502 | /*target=*/target, | |||
1503 | /*mixedTileSizes=*/ | |||
1504 | getAsOpFoldResult(builder.getI64ArrayAttr(staticTileSizes)), | |||
1505 | /*_=*/TileSizesSpec(), | |||
1506 | /*mapping=*/mapping); | |||
1507 | } | |||
1508 | ||||
1509 | void transform::TileToForeachThreadOp::build( | |||
1510 | OpBuilder &builder, OperationState &result, Value target, | |||
1511 | ArrayRef<OpFoldResult> mixedTileSizes, transform::TileSizesSpec, | |||
1512 | ArrayAttr mapping) { | |||
1513 | SmallVector<int64_t> staticTileSizes; | |||
1514 | SmallVector<Value> dynamicTileSizes; | |||
1515 | dispatchIndexOpFoldResults(mixedTileSizes, dynamicTileSizes, staticTileSizes); | |||
1516 | // Call the default builder which sets up the proper operands segment sizes | |||
1517 | // attributes for multiple variadic operands. In the absence of this, horrible | |||
1518 | // bugs ensue. | |||
1519 | MLIRContext *ctx = builder.getContext(); | |||
1520 | auto operationType = pdl::OperationType::get(ctx); | |||
1521 | auto staticTileSizesAttr = builder.getDenseI64ArrayAttr(staticTileSizes); | |||
1522 | build(builder, result, | |||
1523 | /*resultTypes=*/TypeRange{operationType, operationType}, | |||
1524 | /*target=*/target, | |||
1525 | /*num_threads=*/ValueRange{}, | |||
1526 | /*tile_sizes=*/dynamicTileSizes, | |||
1527 | /*packed_num_threads=*/Value(), | |||
1528 | /*packed_tile_sizes=*/Value(), | |||
1529 | /*static_num_threads=*/builder.getDenseI64ArrayAttr({}), | |||
1530 | /*static_tile_sizes=*/staticTileSizesAttr, | |||
1531 | /*mapping=*/mapping); | |||
1532 | } | |||
1533 | ||||
1534 | void transform::TileToForeachThreadOp::build(OpBuilder &builder, | |||
1535 | OperationState &result, | |||
1536 | Value target, | |||
1537 | ArrayRef<int64_t> staticNumThreads, | |||
1538 | transform::NumThreadsSpec, | |||
1539 | ArrayAttr mapping) { | |||
1540 | return build(builder, result, target, | |||
1541 | getAsOpFoldResult(builder.getI64ArrayAttr(staticNumThreads)), | |||
1542 | NumThreadsSpec(), mapping); | |||
1543 | } | |||
1544 | ||||
1545 | void transform::TileToForeachThreadOp::build( | |||
1546 | OpBuilder &builder, OperationState &result, Value target, | |||
1547 | ArrayRef<OpFoldResult> mixedNumThreads, transform::NumThreadsSpec, | |||
1548 | ArrayAttr mapping) { | |||
1549 | SmallVector<int64_t> staticNumThreads; | |||
1550 | SmallVector<Value> dynamicNumThreads; | |||
1551 | dispatchIndexOpFoldResults(mixedNumThreads, dynamicNumThreads, | |||
1552 | staticNumThreads); | |||
1553 | // Call the default builder which sets up the proper operands segment sizes | |||
1554 | // attributes for multiple variadic operands. In the absence of this, horrible | |||
1555 | // bugs ensue. | |||
1556 | MLIRContext *ctx = builder.getContext(); | |||
1557 | auto operationType = pdl::OperationType::get(ctx); | |||
1558 | auto staticNumThreadsAttr = builder.getDenseI64ArrayAttr(staticNumThreads); | |||
1559 | build(builder, result, | |||
1560 | /*resultTypes=*/TypeRange{operationType, operationType}, | |||
1561 | /*target=*/target, | |||
1562 | /*num_threads=*/dynamicNumThreads, | |||
1563 | /*tile_sizes=*/ValueRange{}, | |||
1564 | /*packed_num_threads=*/Value(), | |||
1565 | /*packed_tile_sizes=*/Value(), | |||
1566 | /*static_num_threads=*/staticNumThreadsAttr, | |||
1567 | /*static_tile_sizes=*/builder.getDenseI64ArrayAttr({}), | |||
1568 | /*mapping=*/mapping); | |||
1569 | } | |||
1570 | ||||
1571 | /// Assuming that `ofr` is an index attr or a transform dialect handle mapped | |||
1572 | /// to exactly one op with one index result, return that value. | |||
1573 | static DiagnosedSilenceableFailure unpackPDLOperations( | |||
1574 | transform::TransformState &state, TransformOpInterface transformOp, | |||
1575 | SmallVector<OpFoldResult> &result, ArrayRef<OpFoldResult> ofrs) { | |||
1576 | for (OpFoldResult ofr : ofrs) { | |||
1577 | if (ofr.is<Attribute>()) { | |||
1578 | if (!ofr.get<Attribute>().isa<IntegerAttr>()) | |||
1579 | return transformOp.emitDefiniteFailure() << "expected IntegerAttr"; | |||
1580 | result.push_back(ofr); | |||
1581 | continue; | |||
1582 | } | |||
1583 | ArrayRef<Operation *> payloadOps = state.getPayloadOps(ofr.get<Value>()); | |||
1584 | if (payloadOps.size() != 1) { | |||
1585 | DiagnosedSilenceableFailure diag = | |||
1586 | transformOp.emitSilenceableError() | |||
1587 | << "handle must be mapped to exactly one payload op"; | |||
1588 | diag.attachNote(ofr.get<Value>().getLoc()) | |||
1589 | << "mapped to " << payloadOps.size() << " payload ops"; | |||
1590 | return diag; | |||
1591 | } | |||
1592 | ||||
1593 | Operation *op = payloadOps[0]; | |||
1594 | if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { | |||
1595 | DiagnosedSilenceableFailure diag = | |||
1596 | transformOp.emitSilenceableError() | |||
1597 | << "payload op must have exactly 1 index result"; | |||
1598 | diag.attachNote(op->getLoc()) | |||
1599 | << "has " << op->getNumResults() << " results"; | |||
1600 | return diag; | |||
1601 | } | |||
1602 | result.push_back(op->getResult(0)); | |||
1603 | } | |||
1604 | ||||
1605 | return DiagnosedSilenceableFailure::success(); | |||
1606 | } | |||
1607 | ||||
1608 | // Given a list of OpFoldResults that are either index attrs or op | |||
1609 | // handles, return a list of OpFoldResults where all op handles are | |||
1610 | // replaced with the first (and only) OpResult of that payload op. (There | |||
1611 | // must be exactly one mapped payload op and it must have exactly one | |||
1612 | // index result.) | |||
1613 | static DiagnosedSilenceableFailure | |||
1614 | unpackPDLOperations(transform::TransformState &state, | |||
1615 | TransformOpInterface transformOp, | |||
1616 | SmallVector<OpFoldResult> &result, Value packedHandle) { | |||
1617 | ArrayRef<Operation *> payloadOps = state.getPayloadOps(packedHandle); | |||
1618 | for (Operation *op : payloadOps) { | |||
1619 | if (op->getNumResults() != 1 || !op->getResult(0).getType().isIndex()) { | |||
1620 | DiagnosedSilenceableFailure diag = | |||
1621 | transformOp.emitSilenceableError() | |||
1622 | << "payload op must have exactly 1 index result"; | |||
1623 | diag.attachNote(op->getLoc()) | |||
1624 | << "has " << op->getNumResults() << " results"; | |||
1625 | return diag; | |||
1626 | } | |||
1627 | result.push_back(op->getResult(0)); | |||
1628 | } | |||
1629 | ||||
1630 | return DiagnosedSilenceableFailure::success(); | |||
1631 | } | |||
1632 | ||||
1633 | DiagnosedSilenceableFailure transform::tileToForeachThreadOpImpl( | |||
1634 | RewriterBase &rewriter, transform::TransformState &state, | |||
1635 | TransformOpInterface transformOp, ArrayRef<Operation *> targets, | |||
1636 | ArrayRef<OpFoldResult> mixedNumThreads, | |||
1637 | ArrayRef<OpFoldResult> mixedTileSizes, std::optional<ArrayAttr> mapping, | |||
1638 | SmallVector<Operation *> &tileOps, SmallVector<Operation *> &tiledOps) { | |||
1639 | if (targets.empty()) | |||
1640 | return DiagnosedSilenceableFailure::success(); | |||
1641 | ||||
1642 | // Transform all targets one by one. | |||
1643 | for (Operation *target : targets) { | |||
1644 | auto tilableOp = dyn_cast<TilingInterface>(target); | |||
1645 | if (!tilableOp) { | |||
1646 | DiagnosedSilenceableFailure diag = | |||
1647 | transformOp.emitSilenceableError() | |||
1648 | << "only TilingInterface ops are supported"; | |||
1649 | diag.attachNote(target->getLoc()) << "target op"; | |||
1650 | return diag; | |||
1651 | } | |||
1652 | rewriter.setInsertionPoint(tilableOp); | |||
1653 | FailureOr<linalg::ForeachThreadTilingResult> tilingResult = failure(); | |||
1654 | if (!mixedNumThreads.empty()) { | |||
1655 | tilingResult = linalg::tileToForeachThreadOp(rewriter, tilableOp, | |||
1656 | mixedNumThreads, mapping); | |||
1657 | } else { | |||
1658 | tilingResult = linalg::tileToForeachThreadOpUsingTileSizes( | |||
1659 | rewriter, tilableOp, mixedTileSizes, mapping); | |||
1660 | } | |||
1661 | ||||
1662 | if (failed(tilingResult)) | |||
1663 | return transformOp.emitDefaultSilenceableFailure(tilableOp); | |||
1664 | rewriter.replaceOp(tilableOp, tilingResult->tileOp->getResults()); | |||
1665 | ||||
1666 | tileOps.push_back(tilingResult->tileOp); | |||
1667 | tiledOps.push_back(tilingResult->tiledOp); | |||
1668 | } | |||
1669 | return DiagnosedSilenceableFailure::success(); | |||
1670 | } | |||
1671 | ||||
1672 | DiagnosedSilenceableFailure transform::TileToForeachThreadOp::apply( | |||
1673 | transform::TransformResults &transformResults, | |||
1674 | transform::TransformState &state) { | |||
1675 | IRRewriter rewriter(getContext()); | |||
1676 | auto transformOp = cast<TransformOpInterface>(getOperation()); | |||
1677 | ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); | |||
1678 | ||||
1679 | // Result payload ops. | |||
1680 | SmallVector<Operation *> tileOps; | |||
1681 | SmallVector<Operation *> tiledOps; | |||
1682 | ||||
1683 | // Unpack handles. | |||
1684 | SmallVector<OpFoldResult> mixedNumThreads; | |||
1685 | DiagnosedSilenceableFailure status = | |||
1686 | getPackedNumThreads() | |||
1687 | ? unpackPDLOperations(state, transformOp, mixedNumThreads, | |||
1688 | getPackedNumThreads()) | |||
1689 | : unpackPDLOperations(state, transformOp, mixedNumThreads, | |||
1690 | getMixedNumThreads()); | |||
1691 | if (!status.succeeded()) | |||
1692 | return status; | |||
1693 | SmallVector<OpFoldResult> mixedTileSizes; | |||
1694 | status = getPackedTileSizes() | |||
1695 | ? unpackPDLOperations(state, transformOp, mixedTileSizes, | |||
1696 | getPackedTileSizes()) | |||
1697 | : unpackPDLOperations(state, transformOp, mixedTileSizes, | |||
1698 | getMixedTileSizes()); | |||
1699 | if (!status.succeeded()) | |||
1700 | return status; | |||
1701 | ||||
1702 | DiagnosedSilenceableFailure diag = tileToForeachThreadOpImpl( | |||
1703 | rewriter, state, transformOp, targets, mixedNumThreads, mixedTileSizes, | |||
1704 | getMapping(), tileOps, tiledOps); | |||
1705 | ||||
1706 | if (!diag.succeeded()) { | |||
1707 | transformResults.set(getForeachThreadOp().cast<OpResult>(), {}); | |||
1708 | transformResults.set(getTiledOp().cast<OpResult>(), {}); | |||
1709 | return diag; | |||
1710 | } | |||
1711 | ||||
1712 | transformResults.set(getForeachThreadOp().cast<OpResult>(), tileOps); | |||
1713 | transformResults.set(getTiledOp().cast<OpResult>(), tiledOps); | |||
1714 | ||||
1715 | return DiagnosedSilenceableFailure::success(); | |||
1716 | } | |||
1717 | ||||
1718 | void transform::TileToForeachThreadOp::getEffects( | |||
1719 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
1720 | consumesHandle(getTarget(), effects); | |||
1721 | onlyReadsHandle(getTileSizes(), effects); | |||
1722 | onlyReadsHandle(getNumThreads(), effects); | |||
1723 | producesHandle(getResults(), effects); | |||
1724 | } | |||
1725 | ||||
1726 | SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedNumThreads() { | |||
1727 | Builder b(getContext()); | |||
1728 | return getMixedValues(getStaticNumThreads(), getNumThreads(), b); | |||
1729 | } | |||
1730 | ||||
1731 | SmallVector<OpFoldResult> TileToForeachThreadOp::getMixedTileSizes() { | |||
1732 | Builder b(getContext()); | |||
1733 | return getMixedValues(getStaticTileSizes(), getTileSizes(), b); | |||
1734 | } | |||
1735 | ||||
1736 | LogicalResult TileToForeachThreadOp::verify() { | |||
1737 | int numThreadsSpec = static_cast<int>(!getMixedNumThreads().empty()) + | |||
1738 | static_cast<int>(getPackedNumThreads() != Value()); | |||
1739 | if (numThreadsSpec > 1) | |||
1740 | return emitOpError( | |||
1741 | "num_threads and packed_num_threads are mutually exclusive"); | |||
1742 | int tileSizesSpec = static_cast<int>(!getMixedTileSizes().empty()) + | |||
1743 | static_cast<int>(getPackedTileSizes() != Value()); | |||
1744 | if (tileSizesSpec > 1) | |||
1745 | return emitOpError( | |||
1746 | "tile_sizes and packed_tile_sizes are mutually exclusive"); | |||
1747 | if (numThreadsSpec == 0 && tileSizesSpec == 0) | |||
1748 | return emitOpError( | |||
1749 | "either (packed_)num_threads or (packed_)tile_sizes must be specified"); | |||
1750 | return success(); | |||
1751 | } | |||
1752 | ||||
1753 | //===----------------------------------------------------------------------===// | |||
1754 | // TileToScfForOp | |||
1755 | //===----------------------------------------------------------------------===// | |||
1756 | ||||
1757 | DiagnosedSilenceableFailure | |||
1758 | transform::TileToScfForOp::apply(TransformResults &transformResults, | |||
1759 | TransformState &state) { | |||
1760 | ArrayRef<int64_t> tileSizes = getStaticSizes(); | |||
1761 | ||||
1762 | ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); | |||
1763 | SmallVector<ArrayRef<Operation *>> dynamicSizeProducers; | |||
1764 | dynamicSizeProducers.reserve(getDynamicSizes().size()); | |||
1765 | for (Value dynamicSizeProducerHandle : getDynamicSizes()) { | |||
1766 | dynamicSizeProducers.push_back( | |||
1767 | state.getPayloadOps(dynamicSizeProducerHandle)); | |||
1768 | ||||
1769 | if (dynamicSizeProducers.back().size() != targets.size()) { | |||
1770 | DiagnosedSilenceableFailure diag = | |||
1771 | emitSilenceableError() | |||
1772 | << "expected as many dynamic size-producing operations (" | |||
1773 | << dynamicSizeProducers.back().size() << ") as target ops (" | |||
1774 | << targets.size() << ")"; | |||
1775 | diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; | |||
1776 | return diag; | |||
1777 | } | |||
1778 | ||||
1779 | for (Operation *op : dynamicSizeProducers.back()) { | |||
1780 | if (op->getNumResults() == 1 && | |||
1781 | op->getResult(0).getType().isa<IndexType>()) | |||
1782 | continue; | |||
1783 | DiagnosedSilenceableFailure diag = | |||
1784 | emitSilenceableError() << "expected sizes to be produced by ops " | |||
1785 | "with a single index-type result"; | |||
1786 | diag.attachNote(op->getLoc()) << "size producer op"; | |||
1787 | diag.attachNote(dynamicSizeProducerHandle.getLoc()) << "for this handle"; | |||
1788 | return diag; | |||
1789 | } | |||
1790 | } | |||
1791 | ||||
1792 | SmallVector<Operation *> tiled; | |||
1793 | SmallVector<SmallVector<Operation *, 4>, 4> loops; | |||
1794 | loops.resize(getLoops().size()); | |||
1795 | for (auto &en : llvm::enumerate(targets)) { | |||
1796 | auto tilingInterfaceOp = dyn_cast<TilingInterface>(en.value()); | |||
1797 | if (!tilingInterfaceOp) { | |||
1798 | DiagnosedSilenceableFailure diag = | |||
1799 | emitSilenceableError() << "only TilingInterface ops are supported"; | |||
1800 | diag.attachNote(en.value()->getLoc()) << "target op"; | |||
1801 | return diag; | |||
1802 | } | |||
1803 | ||||
1804 | scf::SCFTilingOptions tilingOptions; | |||
1805 | unsigned index = en.index(); | |||
1806 | if (!tileSizes.empty()) { | |||
1807 | tilingOptions.setTileSizeComputationFunction( | |||
1808 | [&, index](OpBuilder &b, Operation *) { | |||
1809 | SmallVector<Value, 4> sizes; | |||
1810 | sizes.reserve(tileSizes.size()); | |||
1811 | unsigned dynamicIdx = 0; | |||
1812 | for (OpFoldResult ofr : getMixedSizes()) { | |||
1813 | if (auto attr = ofr.dyn_cast<Attribute>()) { | |||
1814 | sizes.push_back(b.create<arith::ConstantIndexOp>( | |||
1815 | getLoc(), attr.cast<IntegerAttr>().getInt())); | |||
1816 | } else { | |||
1817 | sizes.push_back( | |||
1818 | dynamicSizeProducers[dynamicIdx++][index]->getResult(0)); | |||
1819 | } | |||
1820 | } | |||
1821 | return sizes; | |||
1822 | }); | |||
1823 | } | |||
1824 | ||||
1825 | tilingOptions.setInterchange(getInterchange()); | |||
1826 | TrivialPatternRewriter rewriter(tilingInterfaceOp.getContext()); | |||
1827 | FailureOr<scf::SCFTilingResult> tilingResult = | |||
1828 | tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); | |||
1829 | if (failed(tilingResult)) | |||
1830 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
1831 | ||||
1832 | rewriter.replaceOp(tilingInterfaceOp, tilingResult->replacements); | |||
1833 | ||||
1834 | tiled.append(tilingResult->tiledOps); | |||
1835 | for (const auto &en2 : llvm::enumerate(tilingResult->loops)) | |||
1836 | loops[en2.index()].push_back(en2.value()); | |||
1837 | } | |||
1838 | ||||
1839 | transformResults.set(getTiledLinalgOp().cast<OpResult>(), tiled); | |||
1840 | for (const auto &en : llvm::enumerate(loops)) | |||
1841 | transformResults.set(getLoops()[en.index()].cast<OpResult>(), en.value()); | |||
1842 | ||||
1843 | return DiagnosedSilenceableFailure::success(); | |||
1844 | } | |||
1845 | ||||
1846 | SmallVector<OpFoldResult> transform::TileToScfForOp::getMixedSizes() { | |||
1847 | ValueRange dynamic = getDynamicSizes(); | |||
1848 | ArrayRef<int64_t> tileSizes = getStaticSizes(); | |||
1849 | SmallVector<OpFoldResult> results; | |||
1850 | results.reserve(tileSizes.size()); | |||
1851 | unsigned dynamicPos = 0; | |||
1852 | Builder builder(getContext()); | |||
1853 | for (int64_t size : tileSizes) { | |||
1854 | if (size == ShapedType::kDynamic) { | |||
1855 | results.push_back(dynamic[dynamicPos++]); | |||
1856 | } else { | |||
1857 | results.push_back(builder.getIndexAttr(size)); | |||
1858 | } | |||
1859 | } | |||
1860 | return results; | |||
1861 | } | |||
1862 | ||||
1863 | ParseResult transform::TileToScfForOp::parse(OpAsmParser &parser, | |||
1864 | OperationState &result) { | |||
1865 | OpAsmParser::UnresolvedOperand target; | |||
1866 | SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizes; | |||
1867 | DenseI64ArrayAttr staticSizes; | |||
1868 | auto pdlOperationType = pdl::OperationType::get(parser.getContext()); | |||
1869 | if (parser.parseOperand(target) || | |||
1870 | parser.resolveOperand(target, pdlOperationType, result.operands) || | |||
1871 | parseDynamicIndexList(parser, dynamicSizes, staticSizes) || | |||
1872 | parser.resolveOperands(dynamicSizes, pdlOperationType, result.operands)) | |||
1873 | return ParseResult::failure(); | |||
1874 | ||||
1875 | // Parse optional interchange. | |||
1876 | if (failed(parseOptionalInterchange(parser, result))) | |||
1877 | return ParseResult::failure(); | |||
1878 | result.addAttribute(getStaticSizesAttrName(result.name), staticSizes); | |||
1879 | size_t numExpectedLoops = | |||
1880 | staticSizes.size() - llvm::count(staticSizes.asArrayRef(), 0); | |||
1881 | result.addTypes(SmallVector<Type>(numExpectedLoops + 1, pdlOperationType)); | |||
1882 | return success(); | |||
1883 | } | |||
1884 | ||||
1885 | void TileToScfForOp::print(OpAsmPrinter &p) { | |||
1886 | p << ' ' << getTarget(); | |||
1887 | printDynamicIndexList(p, getOperation(), getDynamicSizes(), getStaticSizes()); | |||
1888 | printOptionalInterchange(p, getInterchange()); | |||
1889 | } | |||
1890 | ||||
1891 | void transform::TileToScfForOp::getEffects( | |||
1892 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
1893 | consumesHandle(getTarget(), effects); | |||
1894 | onlyReadsHandle(getDynamicSizes(), effects); | |||
1895 | producesHandle(getTiledLinalgOp(), effects); | |||
1896 | producesHandle(getLoops(), effects); | |||
1897 | modifiesPayload(effects); | |||
1898 | } | |||
1899 | ||||
1900 | //===----------------------------------------------------------------------===// | |||
1901 | // VectorizeOp | |||
1902 | //===----------------------------------------------------------------------===// | |||
1903 | ||||
1904 | void transform::VectorizeOp::build(OpBuilder &builder, OperationState &result, | |||
1905 | Value target, bool vectorizePadding, | |||
1906 | bool vectorizeExtract) { | |||
1907 | result.addOperands(target); | |||
1908 | if (vectorizePadding) { | |||
1909 | result.addAttribute(VectorizeOp::getVectorizePaddingAttrName(result.name), | |||
1910 | builder.getUnitAttr()); | |||
1911 | } | |||
1912 | if (vectorizeExtract) { | |||
1913 | result.addAttribute(VectorizeOp::getVectorizeNdExtractAttrName(result.name), | |||
1914 | builder.getUnitAttr()); | |||
1915 | } | |||
1916 | result.addTypes(pdl::OperationType::get(builder.getContext())); | |||
1917 | } | |||
1918 | ||||
1919 | namespace { | |||
1920 | /// This is an helper only to call vectorize via a pattern inside of | |||
1921 | /// VectorizeOp::applyToOne. | |||
1922 | struct VectorizationPattern : public RewritePattern { | |||
1923 | explicit VectorizationPattern(MLIRContext *context, | |||
1924 | bool vectorizeExtract = false) | |||
1925 | : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), | |||
1926 | vectorizeNDExtract(vectorizeExtract) {} | |||
1927 | LogicalResult matchAndRewrite(Operation *op, | |||
1928 | PatternRewriter &rewriter) const override { | |||
1929 | LinalgOp linalgOp = dyn_cast<LinalgOp>(op); | |||
1930 | if (!linalgOp) | |||
1931 | return rewriter.notifyMatchFailure(op, "expected Linalg Op"); | |||
1932 | return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{}, | |||
1933 | vectorizeNDExtract); | |||
1934 | } | |||
1935 | ||||
1936 | private: | |||
1937 | /// Controls whether to vectorize `tensor.extract` when the input tensor is | |||
1938 | /// rank >= 2. | |||
1939 | bool vectorizeNDExtract = false; | |||
1940 | }; | |||
1941 | } // namespace | |||
1942 | ||||
1943 | DiagnosedSilenceableFailure | |||
1944 | transform::VectorizeOp::applyToOne(Operation *target, | |||
1945 | SmallVectorImpl<Operation *> &results, | |||
1946 | transform::TransformState &state) { | |||
1947 | if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) { | |||
1948 | auto diag = this->emitOpError("requires isolated-from-above targets"); | |||
1949 | diag.attachNote(target->getLoc()) << "non-isolated target"; | |||
1950 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
1951 | } | |||
1952 | ||||
1953 | MLIRContext *ctx = getContext(); | |||
1954 | RewritePatternSet patterns(ctx); | |||
1955 | patterns.add<VectorizationPattern>(ctx, getVectorizeNdExtract()); | |||
1956 | ||||
1957 | if (!getDisableTransferPermutationMapLoweringPatterns()) | |||
1958 | vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); | |||
1959 | ||||
1960 | if (!getDisableMultiReductionToContractPatterns()) | |||
1961 | vector::populateVectorReductionToContractPatterns(patterns); | |||
1962 | ||||
1963 | patterns.add<linalg::LinalgCopyVTRForwardingPattern, | |||
1964 | linalg::LinalgCopyVTWForwardingPattern>(ctx, | |||
1965 | /*benefit=*/2); | |||
1966 | vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); | |||
1967 | vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); | |||
1968 | ||||
1969 | patterns.add<CopyVectorizationPattern>(ctx); | |||
1970 | ||||
1971 | if (getVectorizePadding()) | |||
1972 | linalg::populatePadOpVectorizationPatterns(patterns); | |||
1973 | ||||
1974 | if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) | |||
1975 | return emitDefaultDefiniteFailure(target); | |||
1976 | ||||
1977 | results.push_back(target); | |||
1978 | return DiagnosedSilenceableFailure::success(); | |||
1979 | } | |||
1980 | ||||
1981 | //===----------------------------------------------------------------------===// | |||
1982 | // MaskedVectorizeOp | |||
1983 | //===----------------------------------------------------------------------===// | |||
1984 | ||||
1985 | DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply( | |||
1986 | mlir::transform::TransformResults &transformResults, | |||
1987 | mlir::transform::TransformState &state) { | |||
1988 | IRRewriter rewriter(getContext()); | |||
1989 | ArrayRef<Operation *> targets = state.getPayloadOps(getTarget()); | |||
1990 | if (targets.empty()) | |||
1991 | return DiagnosedSilenceableFailure::success(); | |||
1992 | ||||
1993 | SmallVector<int64_t> vectorSizes; | |||
1994 | for (OpFoldResult sz : getMixedVectorSizes()) { | |||
1995 | if (sz.is<Attribute>()) { | |||
1996 | auto attr = sz.get<Attribute>(); | |||
1997 | vectorSizes.push_back(attr.cast<IntegerAttr>().getInt()); | |||
1998 | continue; | |||
1999 | } | |||
2000 | ||||
2001 | ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>()); | |||
2002 | if (szPayloads.size() != 1) { | |||
2003 | auto diag = this->emitOpError( | |||
2004 | "requires vector size handle that is mapped to 1 payload op"); | |||
2005 | diag.attachNote(sz.get<Value>().getLoc()) | |||
2006 | << "mapped to " << szPayloads.size() << " payload ops"; | |||
2007 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
2008 | } | |||
2009 | ||||
2010 | Operation *szPayloadOp = szPayloads[0]; | |||
2011 | if (szPayloadOp->getNumResults() != 1 || | |||
2012 | !szPayloadOp->getResult(0).getType().isIndex()) { | |||
2013 | auto diag = this->emitOpError( | |||
2014 | "requires vector size payload op with 1 index result"); | |||
2015 | diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; | |||
2016 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
2017 | } | |||
2018 | ||||
2019 | IntegerAttr attr; | |||
2020 | if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) { | |||
2021 | auto diag = this->emitOpError("requires constant vector size"); | |||
2022 | diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op"; | |||
2023 | return DiagnosedSilenceableFailure::definiteFailure(); | |||
2024 | } | |||
2025 | ||||
2026 | vectorSizes.push_back(attr.getInt()); | |||
2027 | } | |||
2028 | ||||
2029 | // TODO: Check that the correct number of vectorSizes was provided. | |||
2030 | ||||
2031 | for (Operation *target : targets) { | |||
2032 | auto linalgOp = dyn_cast<LinalgOp>(target); | |||
2033 | if (!linalgOp) { | |||
2034 | Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); | |||
2035 | diag << "cannot vectorize non-Linalg op"; | |||
2036 | return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); | |||
2037 | } | |||
2038 | ||||
2039 | if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) { | |||
2040 | Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error); | |||
2041 | diag << "failed to vectorize op"; | |||
2042 | return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag)); | |||
2043 | } | |||
2044 | } | |||
2045 | ||||
2046 | return DiagnosedSilenceableFailure::success(); | |||
2047 | } | |||
2048 | ||||
2049 | void transform::MaskedVectorizeOp::getEffects( | |||
2050 | SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | |||
2051 | consumesHandle(getTarget(), effects); | |||
2052 | onlyReadsHandle(getVectorSizes(), effects); | |||
2053 | } | |||
2054 | ||||
2055 | SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() { | |||
2056 | OpBuilder b(getContext()); | |||
2057 | return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b); | |||
2058 | } | |||
2059 | ||||
2060 | //===----------------------------------------------------------------------===// | |||
2061 | // Transform op registration | |||
2062 | //===----------------------------------------------------------------------===// | |||
2063 | ||||
2064 | namespace { | |||
2065 | /// Registers new ops and declares PDL as dependent dialect since the | |||
2066 | /// additional ops are using PDL types for operands and results. | |||
2067 | class LinalgTransformDialectExtension | |||
2068 | : public transform::TransformDialectExtension< | |||
2069 | LinalgTransformDialectExtension> { | |||
2070 | public: | |||
2071 | using Base::Base; | |||
2072 | ||||
2073 | void init() { | |||
2074 | declareDependentDialect<pdl::PDLDialect>(); | |||
2075 | declareDependentDialect<LinalgDialect>(); | |||
2076 | declareGeneratedDialect<AffineDialect>(); | |||
2077 | declareGeneratedDialect<arith::ArithDialect>(); | |||
2078 | declareGeneratedDialect<scf::SCFDialect>(); | |||
2079 | declareGeneratedDialect<vector::VectorDialect>(); | |||
2080 | declareGeneratedDialect<gpu::GPUDialect>(); | |||
2081 | ||||
2082 | registerTransformOps< | |||
2083 | #define GET_OP_LIST | |||
2084 | #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" | |||
2085 | >(); | |||
2086 | } | |||
2087 | }; | |||
2088 | } // namespace | |||
2089 | ||||
2090 | #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc" | |||
2091 | ||||
2092 | #define GET_OP_CLASSES | |||
2093 | #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp.inc" | |||
2094 | ||||
2095 | void mlir::linalg::registerTransformDialectExtension( | |||
2096 | DialectRegistry ®istry) { | |||
2097 | registry.addExtensions<LinalgTransformDialectExtension>(); | |||
2098 | } |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file declares classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | class AsmParsedResourceEntry; |
24 | class AsmResourceBuilder; |
25 | class Builder; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // AsmDialectResourceHandle |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// This class represents an opaque handle to a dialect resource entry. |
32 | class AsmDialectResourceHandle { |
33 | public: |
34 | AsmDialectResourceHandle() = default; |
35 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
36 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
37 | bool operator==(const AsmDialectResourceHandle &other) const { |
38 | return resource == other.resource; |
39 | } |
40 | |
41 | /// Return an opaque pointer to the referenced resource. |
42 | void *getResource() const { return resource; } |
43 | |
44 | /// Return the type ID of the resource. |
45 | TypeID getTypeID() const { return opaqueID; } |
46 | |
47 | /// Return the dialect that owns the resource. |
48 | Dialect *getDialect() const { return dialect; } |
49 | |
50 | private: |
51 | /// The opaque handle to the dialect resource. |
52 | void *resource = nullptr; |
53 | /// The type of the resource referenced. |
54 | TypeID opaqueID; |
55 | /// The dialect owning the given resource. |
56 | Dialect *dialect; |
57 | }; |
58 | |
59 | /// This class represents a CRTP base class for dialect resource handles. It |
60 | /// abstracts away various utilities necessary for defined derived resource |
61 | /// handles. |
62 | template <typename DerivedT, typename ResourceT, typename DialectT> |
63 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { |
64 | public: |
65 | using Dialect = DialectT; |
66 | |
67 | /// Construct a handle from a pointer to the resource. The given pointer |
68 | /// should be guaranteed to live beyond the life of this handle. |
69 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) |
70 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} |
71 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) |
72 | : AsmDialectResourceHandle(handle) { |
73 | assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get< DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()" , "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__ __PRETTY_FUNCTION__)); |
74 | } |
75 | |
76 | /// Return the resource referenced by this handle. |
77 | ResourceT *getResource() { |
78 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); |
79 | } |
80 | const ResourceT *getResource() const { |
81 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); |
82 | } |
83 | |
84 | /// Return the dialect that owns the resource. |
85 | DialectT *getDialect() const { |
86 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); |
87 | } |
88 | |
89 | /// Support llvm style casting. |
90 | static bool classof(const AsmDialectResourceHandle *handle) { |
91 | return handle->getTypeID() == TypeID::get<DerivedT>(); |
92 | } |
93 | }; |
94 | |
95 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
96 | return llvm::hash_value(param.getResource()); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // AsmPrinter |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | /// This base class exposes generic asm printer hooks, usable across the various |
104 | /// derived printers. |
105 | class AsmPrinter { |
106 | public: |
107 | /// This class contains the internal default implementation of the base |
108 | /// printer methods. |
109 | class Impl; |
110 | |
111 | /// Initialize the printer with the given internal implementation. |
112 | AsmPrinter(Impl &impl) : impl(&impl) {} |
113 | virtual ~AsmPrinter(); |
114 | |
115 | /// Return the raw output stream used by this printer. |
116 | virtual raw_ostream &getStream() const; |
117 | |
118 | /// Print the given floating point value in a stabilized form that can be |
119 | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
120 | /// hook on the AsmParser. |
121 | virtual void printFloat(const APFloat &value); |
122 | |
123 | virtual void printType(Type type); |
124 | virtual void printAttribute(Attribute attr); |
125 | |
126 | /// Trait to check if `AttrType` provides a `print` method. |
127 | template <typename AttrOrType> |
128 | using has_print_method = |
129 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
130 | template <typename AttrOrType> |
131 | using detect_has_print_method = |
132 | llvm::is_detected<has_print_method, AttrOrType>; |
133 | |
134 | /// Print the provided attribute in the context of an operation custom |
135 | /// printer/parser: this will invoke directly the print method on the |
136 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. |
137 | template <typename AttrOrType, |
138 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
139 | *sfinae = nullptr> |
140 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
141 | if (succeeded(printAlias(attrOrType))) |
142 | return; |
143 | attrOrType.print(*this); |
144 | } |
145 | |
146 | /// Print the provided array of attributes or types in the context of an |
147 | /// operation custom printer/parser: this will invoke directly the print |
148 | /// method on the attribute class and skip the `#dialect.mnemonic` prefix in |
149 | /// most cases. |
150 | template <typename AttrOrType, |
151 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
152 | *sfinae = nullptr> |
153 | void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) { |
154 | llvm::interleaveComma( |
155 | attrOrTypes, getStream(), |
156 | [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); |
157 | } |
158 | |
159 | /// SFINAE for printing the provided attribute in the context of an operation |
160 | /// custom printer in the case where the attribute does not define a print |
161 | /// method. |
162 | template <typename AttrOrType, |
163 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
164 | *sfinae = nullptr> |
165 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
166 | *this << attrOrType; |
167 | } |
168 | |
169 | /// Print the given attribute without its type. The corresponding parser must |
170 | /// provide a valid type for the attribute. |
171 | virtual void printAttributeWithoutType(Attribute attr); |
172 | |
173 | /// Print the given string as a keyword, or a quoted and escaped string if it |
174 | /// has any special or non-printable characters in it. |
175 | virtual void printKeywordOrString(StringRef keyword); |
176 | |
177 | /// Print the given string as a symbol reference, i.e. a form representable by |
178 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed |
179 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any |
180 | /// special or non-printable characters in it. |
181 | virtual void printSymbolName(StringRef symbolRef); |
182 | |
183 | /// Print a handle to the given dialect resource. |
184 | virtual void printResourceHandle(const AsmDialectResourceHandle &resource); |
185 | |
186 | /// Print an optional arrow followed by a type list. |
187 | template <typename TypeRange> |
188 | void printOptionalArrowTypeList(TypeRange &&types) { |
189 | if (types.begin() != types.end()) |
190 | printArrowTypeList(types); |
191 | } |
192 | template <typename TypeRange> |
193 | void printArrowTypeList(TypeRange &&types) { |
194 | auto &os = getStream() << " -> "; |
195 | |
196 | bool wrapped = !llvm::hasSingleElement(types) || |
197 | (*types.begin()).template isa<FunctionType>(); |
198 | if (wrapped) |
199 | os << '('; |
200 | llvm::interleaveComma(types, *this); |
201 | if (wrapped) |
202 | os << ')'; |
203 | } |
204 | |
205 | /// Print the two given type ranges in a functional form. |
206 | template <typename InputRangeT, typename ResultRangeT> |
207 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
208 | auto &os = getStream(); |
209 | os << '('; |
210 | llvm::interleaveComma(inputs, *this); |
211 | os << ')'; |
212 | printArrowTypeList(results); |
213 | } |
214 | |
215 | protected: |
216 | /// Initialize the printer with no internal implementation. In this case, all |
217 | /// virtual methods of this class must be overriden. |
218 | AsmPrinter() = default; |
219 | |
220 | private: |
221 | AsmPrinter(const AsmPrinter &) = delete; |
222 | void operator=(const AsmPrinter &) = delete; |
223 | |
224 | /// Print the alias for the given attribute, return failure if no alias could |
225 | /// be printed. |
226 | virtual LogicalResult printAlias(Attribute attr); |
227 | |
228 | /// Print the alias for the given type, return failure if no alias could |
229 | /// be printed. |
230 | virtual LogicalResult printAlias(Type type); |
231 | |
232 | /// The internal implementation of the printer. |
233 | Impl *impl{nullptr}; |
234 | }; |
235 | |
236 | template <typename AsmPrinterT> |
237 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
238 | AsmPrinterT &> |
239 | operator<<(AsmPrinterT &p, Type type) { |
240 | p.printType(type); |
241 | return p; |
242 | } |
243 | |
244 | template <typename AsmPrinterT> |
245 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
246 | AsmPrinterT &> |
247 | operator<<(AsmPrinterT &p, Attribute attr) { |
248 | p.printAttribute(attr); |
249 | return p; |
250 | } |
251 | |
252 | template <typename AsmPrinterT> |
253 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
254 | AsmPrinterT &> |
255 | operator<<(AsmPrinterT &p, const APFloat &value) { |
256 | p.printFloat(value); |
257 | return p; |
258 | } |
259 | template <typename AsmPrinterT> |
260 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
261 | AsmPrinterT &> |
262 | operator<<(AsmPrinterT &p, float value) { |
263 | return p << APFloat(value); |
264 | } |
265 | template <typename AsmPrinterT> |
266 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
267 | AsmPrinterT &> |
268 | operator<<(AsmPrinterT &p, double value) { |
269 | return p << APFloat(value); |
270 | } |
271 | |
272 | // Support printing anything that isn't convertible to one of the other |
273 | // streamable types, even if it isn't exactly one of them. For example, we want |
274 | // to print FunctionType with the Type version above, not have it match this. |
275 | template <typename AsmPrinterT, typename T, |
276 | std::enable_if_t<!std::is_convertible<T &, Value &>::value && |
277 | !std::is_convertible<T &, Type &>::value && |
278 | !std::is_convertible<T &, Attribute &>::value && |
279 | !std::is_convertible<T &, ValueRange>::value && |
280 | !std::is_convertible<T &, APFloat &>::value && |
281 | !llvm::is_one_of<T, bool, float, double>::value, |
282 | T> * = nullptr> |
283 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
284 | AsmPrinterT &> |
285 | operator<<(AsmPrinterT &p, const T &other) { |
286 | p.getStream() << other; |
287 | return p; |
288 | } |
289 | |
290 | template <typename AsmPrinterT> |
291 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
292 | AsmPrinterT &> |
293 | operator<<(AsmPrinterT &p, bool value) { |
294 | return p << (value ? StringRef("true") : "false"); |
295 | } |
296 | |
297 | template <typename AsmPrinterT, typename ValueRangeT> |
298 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
299 | AsmPrinterT &> |
300 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
301 | llvm::interleaveComma(types, p); |
302 | return p; |
303 | } |
304 | template <typename AsmPrinterT> |
305 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
306 | AsmPrinterT &> |
307 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
308 | llvm::interleaveComma(types, p); |
309 | return p; |
310 | } |
311 | template <typename AsmPrinterT, typename ElementT> |
312 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
313 | AsmPrinterT &> |
314 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
315 | llvm::interleaveComma(types, p); |
316 | return p; |
317 | } |
318 | |
319 | //===----------------------------------------------------------------------===// |
320 | // OpAsmPrinter |
321 | //===----------------------------------------------------------------------===// |
322 | |
323 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
324 | /// necessary to implement a custom print() method. |
325 | class OpAsmPrinter : public AsmPrinter { |
326 | public: |
327 | using AsmPrinter::AsmPrinter; |
328 | ~OpAsmPrinter() override; |
329 | |
330 | /// Print a loc(...) specifier if printing debug info is enabled. |
331 | virtual void printOptionalLocationSpecifier(Location loc) = 0; |
332 | |
333 | /// Print a newline and indent the printer to the start of the current |
334 | /// operation. |
335 | virtual void printNewline() = 0; |
336 | |
337 | /// Increase indentation. |
338 | virtual void increaseIndent() = 0; |
339 | |
340 | /// Decrease indentation. |
341 | virtual void decreaseIndent() = 0; |
342 | |
343 | /// Print a block argument in the usual format of: |
344 | /// %ssaName : type {attr1=42} loc("here") |
345 | /// where location printing is controlled by the standard internal option. |
346 | /// You may pass omitType=true to not print a type, and pass an empty |
347 | /// attribute list if you don't care for attributes. |
348 | virtual void printRegionArgument(BlockArgument arg, |
349 | ArrayRef<NamedAttribute> argAttrs = {}, |
350 | bool omitType = false) = 0; |
351 | |
352 | /// Print implementations for various things an operation contains. |
353 | virtual void printOperand(Value value) = 0; |
354 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
355 | |
356 | /// Print a comma separated list of operands. |
357 | template <typename ContainerType> |
358 | void printOperands(const ContainerType &container) { |
359 | printOperands(container.begin(), container.end()); |
360 | } |
361 | |
362 | /// Print a comma separated list of operands. |
363 | template <typename IteratorType> |
364 | void printOperands(IteratorType it, IteratorType end) { |
365 | llvm::interleaveComma(llvm::make_range(it, end), getStream(), |
366 | [this](Value value) { printOperand(value); }); |
367 | } |
368 | |
369 | /// Print the given successor. |
370 | virtual void printSuccessor(Block *successor) = 0; |
371 | |
372 | /// Print the successor and its operands. |
373 | virtual void printSuccessorAndUseList(Block *successor, |
374 | ValueRange succOperands) = 0; |
375 | |
376 | /// If the specified operation has attributes, print out an attribute |
377 | /// dictionary with their values. elidedAttrs allows the client to ignore |
378 | /// specific well known attributes, commonly used if the attribute value is |
379 | /// printed some other way (like as a fixed operand). |
380 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
381 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
382 | |
383 | /// If the specified operation has attributes, print out an attribute |
384 | /// dictionary prefixed with 'attributes'. |
385 | virtual void |
386 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
387 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
388 | |
389 | /// Prints the entire operation with the custom assembly form, if available, |
390 | /// or the generic assembly form, otherwise. |
391 | virtual void printCustomOrGenericOp(Operation *op) = 0; |
392 | |
393 | /// Print the entire operation with the default generic assembly form. |
394 | /// If `printOpName` is true, then the operation name is printed (the default) |
395 | /// otherwise it is omitted and the print will start with the operand list. |
396 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
397 | |
398 | /// Prints a region. |
399 | /// If 'printEntryBlockArgs' is false, the arguments of the |
400 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
401 | /// operation of the block is not printed. If printEmptyBlock is true, then |
402 | /// the block header is printed even if the block is empty. |
403 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
404 | bool printBlockTerminators = true, |
405 | bool printEmptyBlock = false) = 0; |
406 | |
407 | /// Renumber the arguments for the specified region to the same names as the |
408 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
409 | /// operations. If any entry in namesToUse is null, the corresponding |
410 | /// argument name is left alone. |
411 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
412 | |
413 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
414 | /// of dims/symbols. |
415 | /// Operand values must come from single-result sources, and be valid |
416 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
417 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
418 | ValueRange operands) = 0; |
419 | |
420 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
421 | /// dims and symbols. |
422 | /// Operand values must come from single-result sources, and be valid |
423 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
424 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
425 | ValueRange symOperands) = 0; |
426 | |
427 | /// Print the complete type of an operation in functional form. |
428 | void printFunctionalType(Operation *op); |
429 | using AsmPrinter::printFunctionalType; |
430 | }; |
431 | |
432 | // Make the implementations convenient to use. |
433 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
434 | p.printOperand(value); |
435 | return p; |
436 | } |
437 | |
438 | template <typename T, |
439 | std::enable_if_t<std::is_convertible<T &, ValueRange>::value && |
440 | !std::is_convertible<T &, Value &>::value, |
441 | T> * = nullptr> |
442 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
443 | p.printOperands(values); |
444 | return p; |
445 | } |
446 | |
447 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
448 | p.printSuccessor(value); |
449 | return p; |
450 | } |
451 | |
452 | //===----------------------------------------------------------------------===// |
453 | // AsmParser |
454 | //===----------------------------------------------------------------------===// |
455 | |
456 | /// This base class exposes generic asm parser hooks, usable across the various |
457 | /// derived parsers. |
458 | class AsmParser { |
459 | public: |
460 | AsmParser() = default; |
461 | virtual ~AsmParser(); |
462 | |
463 | MLIRContext *getContext() const; |
464 | |
465 | /// Return the location of the original name token. |
466 | virtual SMLoc getNameLoc() const = 0; |
467 | |
468 | //===--------------------------------------------------------------------===// |
469 | // Utilities |
470 | //===--------------------------------------------------------------------===// |
471 | |
472 | /// Emit a diagnostic at the specified location and return failure. |
473 | virtual InFlightDiagnostic emitError(SMLoc loc, |
474 | const Twine &message = {}) = 0; |
475 | |
476 | /// Return a builder which provides useful access to MLIRContext, global |
477 | /// objects like types and attributes. |
478 | virtual Builder &getBuilder() const = 0; |
479 | |
480 | /// Get the location of the next token and store it into the argument. This |
481 | /// always succeeds. |
482 | virtual SMLoc getCurrentLocation() = 0; |
483 | ParseResult getCurrentLocation(SMLoc *loc) { |
484 | *loc = getCurrentLocation(); |
485 | return success(); |
486 | } |
487 | |
488 | /// Re-encode the given source location as an MLIR location and return it. |
489 | /// Note: This method should only be used when a `Location` is necessary, as |
490 | /// the encoding process is not efficient. |
491 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
492 | |
493 | //===--------------------------------------------------------------------===// |
494 | // Token Parsing |
495 | //===--------------------------------------------------------------------===// |
496 | |
497 | /// Parse a '->' token. |
498 | virtual ParseResult parseArrow() = 0; |
499 | |
500 | /// Parse a '->' token if present |
501 | virtual ParseResult parseOptionalArrow() = 0; |
502 | |
503 | /// Parse a `{` token. |
504 | virtual ParseResult parseLBrace() = 0; |
505 | |
506 | /// Parse a `{` token if present. |
507 | virtual ParseResult parseOptionalLBrace() = 0; |
508 | |
509 | /// Parse a `}` token. |
510 | virtual ParseResult parseRBrace() = 0; |
511 | |
512 | /// Parse a `}` token if present. |
513 | virtual ParseResult parseOptionalRBrace() = 0; |
514 | |
515 | /// Parse a `:` token. |
516 | virtual ParseResult parseColon() = 0; |
517 | |
518 | /// Parse a `:` token if present. |
519 | virtual ParseResult parseOptionalColon() = 0; |
520 | |
521 | /// Parse a `,` token. |
522 | virtual ParseResult parseComma() = 0; |
523 | |
524 | /// Parse a `,` token if present. |
525 | virtual ParseResult parseOptionalComma() = 0; |
526 | |
527 | /// Parse a `=` token. |
528 | virtual ParseResult parseEqual() = 0; |
529 | |
530 | /// Parse a `=` token if present. |
531 | virtual ParseResult parseOptionalEqual() = 0; |
532 | |
533 | /// Parse a '<' token. |
534 | virtual ParseResult parseLess() = 0; |
535 | |
536 | /// Parse a '<' token if present. |
537 | virtual ParseResult parseOptionalLess() = 0; |
538 | |
539 | /// Parse a '>' token. |
540 | virtual ParseResult parseGreater() = 0; |
541 | |
542 | /// Parse a '>' token if present. |
543 | virtual ParseResult parseOptionalGreater() = 0; |
544 | |
545 | /// Parse a '?' token. |
546 | virtual ParseResult parseQuestion() = 0; |
547 | |
548 | /// Parse a '?' token if present. |
549 | virtual ParseResult parseOptionalQuestion() = 0; |
550 | |
551 | /// Parse a '+' token. |
552 | virtual ParseResult parsePlus() = 0; |
553 | |
554 | /// Parse a '+' token if present. |
555 | virtual ParseResult parseOptionalPlus() = 0; |
556 | |
557 | /// Parse a '*' token. |
558 | virtual ParseResult parseStar() = 0; |
559 | |
560 | /// Parse a '*' token if present. |
561 | virtual ParseResult parseOptionalStar() = 0; |
562 | |
563 | /// Parse a '|' token. |
564 | virtual ParseResult parseVerticalBar() = 0; |
565 | |
566 | /// Parse a '|' token if present. |
567 | virtual ParseResult parseOptionalVerticalBar() = 0; |
568 | |
569 | /// Parse a quoted string token. |
570 | ParseResult parseString(std::string *string) { |
571 | auto loc = getCurrentLocation(); |
572 | if (parseOptionalString(string)) |
573 | return emitError(loc, "expected string"); |
574 | return success(); |
575 | } |
576 | |
577 | /// Parse a quoted string token if present. |
578 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
579 | |
580 | /// Parses a Base64 encoded string of bytes. |
581 | virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; |
582 | |
583 | /// Parse a `(` token. |
584 | virtual ParseResult parseLParen() = 0; |
585 | |
586 | /// Parse a `(` token if present. |
587 | virtual ParseResult parseOptionalLParen() = 0; |
588 | |
589 | /// Parse a `)` token. |
590 | virtual ParseResult parseRParen() = 0; |
591 | |
592 | /// Parse a `)` token if present. |
593 | virtual ParseResult parseOptionalRParen() = 0; |
594 | |
595 | /// Parse a `[` token. |
596 | virtual ParseResult parseLSquare() = 0; |
597 | |
598 | /// Parse a `[` token if present. |
599 | virtual ParseResult parseOptionalLSquare() = 0; |
600 | |
601 | /// Parse a `]` token. |
602 | virtual ParseResult parseRSquare() = 0; |
603 | |
604 | /// Parse a `]` token if present. |
605 | virtual ParseResult parseOptionalRSquare() = 0; |
606 | |
607 | /// Parse a `...` token. |
608 | virtual ParseResult parseEllipsis() = 0; |
609 | |
610 | /// Parse a `...` token if present; |
611 | virtual ParseResult parseOptionalEllipsis() = 0; |
612 | |
613 | /// Parse a floating point value from the stream. |
614 | virtual ParseResult parseFloat(double &result) = 0; |
615 | |
616 | /// Parse an integer value from the stream. |
617 | template <typename IntT> |
618 | ParseResult parseInteger(IntT &result) { |
619 | auto loc = getCurrentLocation(); |
620 | OptionalParseResult parseResult = parseOptionalInteger(result); |
621 | if (!parseResult.has_value()) |
622 | return emitError(loc, "expected integer value"); |
623 | return *parseResult; |
624 | } |
625 | |
626 | /// Parse an optional integer value from the stream. |
627 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
628 | |
629 | template <typename IntT> |
630 | OptionalParseResult parseOptionalInteger(IntT &result) { |
631 | auto loc = getCurrentLocation(); |
632 | |
633 | // Parse the unsigned variant. |
634 | APInt uintResult; |
635 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
636 | if (!parseResult.has_value() || failed(*parseResult)) |
637 | return parseResult; |
638 | |
639 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
640 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
641 | // zero for non-negated integers. |
642 | result = |
643 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue(); |
644 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
645 | return emitError(loc, "integer value too large"); |
646 | return success(); |
647 | } |
648 | |
649 | /// These are the supported delimiters around operand lists and region |
650 | /// argument lists, used by parseOperandList. |
651 | enum class Delimiter { |
652 | /// Zero or more operands with no delimiters. |
653 | None, |
654 | /// Parens surrounding zero or more operands. |
655 | Paren, |
656 | /// Square brackets surrounding zero or more operands. |
657 | Square, |
658 | /// <> brackets surrounding zero or more operands. |
659 | LessGreater, |
660 | /// {} brackets surrounding zero or more operands. |
661 | Braces, |
662 | /// Parens supporting zero or more operands, or nothing. |
663 | OptionalParen, |
664 | /// Square brackets supporting zero or more ops, or nothing. |
665 | OptionalSquare, |
666 | /// <> brackets supporting zero or more ops, or nothing. |
667 | OptionalLessGreater, |
668 | /// {} brackets surrounding zero or more operands, or nothing. |
669 | OptionalBraces, |
670 | }; |
671 | |
672 | /// Parse a list of comma-separated items with an optional delimiter. If a |
673 | /// delimiter is provided, then an empty list is allowed. If not, then at |
674 | /// least one element will be parsed. |
675 | /// |
676 | /// contextMessage is an optional message appended to "expected '('" sorts of |
677 | /// diagnostics when parsing the delimeters. |
678 | virtual ParseResult |
679 | parseCommaSeparatedList(Delimiter delimiter, |
680 | function_ref<ParseResult()> parseElementFn, |
681 | StringRef contextMessage = StringRef()) = 0; |
682 | |
683 | /// Parse a comma separated list of elements that must have at least one entry |
684 | /// in it. |
685 | ParseResult |
686 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
687 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); |
688 | } |
689 | |
690 | //===--------------------------------------------------------------------===// |
691 | // Keyword Parsing |
692 | //===--------------------------------------------------------------------===// |
693 | |
694 | /// This class represents a StringSwitch like class that is useful for parsing |
695 | /// expected keywords. On construction, it invokes `parseKeyword` and |
696 | /// processes each of the provided cases statements until a match is hit. The |
697 | /// provided `ResultT` must be assignable from `failure()`. |
698 | template <typename ResultT = ParseResult> |
699 | class KeywordSwitch { |
700 | public: |
701 | KeywordSwitch(AsmParser &parser) |
702 | : parser(parser), loc(parser.getCurrentLocation()) { |
703 | if (failed(parser.parseKeywordOrCompletion(&keyword))) |
704 | result = failure(); |
705 | } |
706 | |
707 | /// Case that uses the provided value when true. |
708 | KeywordSwitch &Case(StringLiteral str, ResultT value) { |
709 | return Case(str, [&](StringRef, SMLoc) { return std::move(value); }); |
710 | } |
711 | KeywordSwitch &Default(ResultT value) { |
712 | return Default([&](StringRef, SMLoc) { return std::move(value); }); |
713 | } |
714 | /// Case that invokes the provided functor when true. The parameters passed |
715 | /// to the functor are the keyword, and the location of the keyword (in case |
716 | /// any errors need to be emitted). |
717 | template <typename FnT> |
718 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
719 | Case(StringLiteral str, FnT &&fn) { |
720 | if (result) |
721 | return *this; |
722 | |
723 | // If the word was empty, record this as a completion. |
724 | if (keyword.empty()) |
725 | parser.codeCompleteExpectedTokens(str); |
726 | else if (keyword == str) |
727 | result.emplace(std::move(fn(keyword, loc))); |
728 | return *this; |
729 | } |
730 | template <typename FnT> |
731 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
732 | Default(FnT &&fn) { |
733 | if (!result) |
734 | result.emplace(fn(keyword, loc)); |
735 | return *this; |
736 | } |
737 | |
738 | /// Returns true if this switch has a value yet. |
739 | bool hasValue() const { return result.has_value(); } |
740 | |
741 | /// Return the result of the switch. |
742 | [[nodiscard]] operator ResultT() { |
743 | if (!result) |
744 | return parser.emitError(loc, "unexpected keyword: ") << keyword; |
745 | return std::move(*result); |
746 | } |
747 | |
748 | private: |
749 | /// The parser used to construct this switch. |
750 | AsmParser &parser; |
751 | |
752 | /// The location of the keyword, used to emit errors as necessary. |
753 | SMLoc loc; |
754 | |
755 | /// The parsed keyword itself. |
756 | StringRef keyword; |
757 | |
758 | /// The result of the switch statement or none if currently unknown. |
759 | Optional<ResultT> result; |
760 | }; |
761 | |
762 | /// Parse a given keyword. |
763 | ParseResult parseKeyword(StringRef keyword) { |
764 | return parseKeyword(keyword, ""); |
765 | } |
766 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; |
767 | |
768 | /// Parse a keyword into 'keyword'. |
769 | ParseResult parseKeyword(StringRef *keyword) { |
770 | auto loc = getCurrentLocation(); |
771 | if (parseOptionalKeyword(keyword)) |
772 | return emitError(loc, "expected valid keyword"); |
773 | return success(); |
774 | } |
775 | |
776 | /// Parse the given keyword if present. |
777 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
778 | |
779 | /// Parse a keyword, if present, into 'keyword'. |
780 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
781 | |
782 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
783 | /// into 'keyword' |
784 | virtual ParseResult |
785 | parseOptionalKeyword(StringRef *keyword, |
786 | ArrayRef<StringRef> allowedValues) = 0; |
787 | |
788 | /// Parse a keyword or a quoted string. |
789 | ParseResult parseKeywordOrString(std::string *result) { |
790 | if (failed(parseOptionalKeywordOrString(result))) |
791 | return emitError(getCurrentLocation()) |
792 | << "expected valid keyword or string"; |
793 | return success(); |
794 | } |
795 | |
796 | /// Parse an optional keyword or string. |
797 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
798 | |
799 | //===--------------------------------------------------------------------===// |
800 | // Attribute/Type Parsing |
801 | //===--------------------------------------------------------------------===// |
802 | |
803 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
804 | /// the provided location to emit errors in the case of failure. Note that |
805 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
806 | /// context parameter. |
807 | template <typename T, typename... ParamsT> |
808 | auto getChecked(SMLoc loc, ParamsT &&...params) { |
809 | return T::getChecked([&] { return emitError(loc); }, |
810 | std::forward<ParamsT>(params)...); |
811 | } |
812 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
813 | /// errors. |
814 | template <typename T, typename... ParamsT> |
815 | auto getChecked(ParamsT &&...params) { |
816 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
817 | std::forward<ParamsT>(params)...); |
818 | } |
819 | |
820 | //===--------------------------------------------------------------------===// |
821 | // Attribute Parsing |
822 | //===--------------------------------------------------------------------===// |
823 | |
824 | /// Parse an arbitrary attribute of a given type and return it in result. |
825 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
826 | |
827 | /// Parse a custom attribute with the provided callback, unless the next |
828 | /// token is `#`, in which case the generic parser is invoked. |
829 | virtual ParseResult parseCustomAttributeWithFallback( |
830 | Attribute &result, Type type, |
831 | function_ref<ParseResult(Attribute &result, Type type)> |
832 | parseAttribute) = 0; |
833 | |
834 | /// Parse an attribute of a specific kind and type. |
835 | template <typename AttrType> |
836 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
837 | SMLoc loc = getCurrentLocation(); |
838 | |
839 | // Parse any kind of attribute. |
840 | Attribute attr; |
841 | if (parseAttribute(attr, type)) |
842 | return failure(); |
843 | |
844 | // Check for the right kind of attribute. |
845 | if (!(result = attr.dyn_cast<AttrType>())) |
846 | return emitError(loc, "invalid kind of attribute specified"); |
847 | |
848 | return success(); |
849 | } |
850 | |
851 | /// Parse an arbitrary attribute and return it in result. This also adds the |
852 | /// attribute to the specified attribute list with the specified name. |
853 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
854 | NamedAttrList &attrs) { |
855 | return parseAttribute(result, Type(), attrName, attrs); |
856 | } |
857 | |
858 | /// Parse an attribute of a specific kind and type. |
859 | template <typename AttrType> |
860 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
861 | NamedAttrList &attrs) { |
862 | return parseAttribute(result, Type(), attrName, attrs); |
863 | } |
864 | |
865 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
866 | /// This also adds the attribute to the specified attribute list with the |
867 | /// specified name. |
868 | template <typename AttrType> |
869 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
870 | NamedAttrList &attrs) { |
871 | SMLoc loc = getCurrentLocation(); |
872 | |
873 | // Parse any kind of attribute. |
874 | Attribute attr; |
875 | if (parseAttribute(attr, type)) |
876 | return failure(); |
877 | |
878 | // Check for the right kind of attribute. |
879 | result = attr.dyn_cast<AttrType>(); |
880 | if (!result) |
881 | return emitError(loc, "invalid kind of attribute specified"); |
882 | |
883 | attrs.append(attrName, result); |
884 | return success(); |
885 | } |
886 | |
887 | /// Trait to check if `AttrType` provides a `parse` method. |
888 | template <typename AttrType> |
889 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
890 | std::declval<Type>())); |
891 | template <typename AttrType> |
892 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
893 | |
894 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
895 | /// which case the generic parser is invoked. The parsed attribute is |
896 | /// populated in `result` and also added to the specified attribute list with |
897 | /// the specified name. |
898 | template <typename AttrType> |
899 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
900 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
901 | StringRef attrName, NamedAttrList &attrs) { |
902 | SMLoc loc = getCurrentLocation(); |
903 | |
904 | // Parse any kind of attribute. |
905 | Attribute attr; |
906 | if (parseCustomAttributeWithFallback( |
907 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
908 | result = AttrType::parse(*this, type); |
909 | if (!result) |
910 | return failure(); |
911 | return success(); |
912 | })) |
913 | return failure(); |
914 | |
915 | // Check for the right kind of attribute. |
916 | result = attr.dyn_cast<AttrType>(); |
917 | if (!result) |
918 | return emitError(loc, "invalid kind of attribute specified"); |
919 | |
920 | attrs.append(attrName, result); |
921 | return success(); |
922 | } |
923 | |
924 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
925 | template <typename AttrType> |
926 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
927 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
928 | StringRef attrName, NamedAttrList &attrs) { |
929 | return parseAttribute(result, type, attrName, attrs); |
930 | } |
931 | |
932 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
933 | /// which case the generic parser is invoked. The parsed attribute is |
934 | /// populated in `result`. |
935 | template <typename AttrType> |
936 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
937 | parseCustomAttributeWithFallback(AttrType &result) { |
938 | SMLoc loc = getCurrentLocation(); |
939 | |
940 | // Parse any kind of attribute. |
941 | Attribute attr; |
942 | if (parseCustomAttributeWithFallback( |
943 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
944 | result = AttrType::parse(*this, type); |
945 | return success(!!result); |
946 | })) |
947 | return failure(); |
948 | |
949 | // Check for the right kind of attribute. |
950 | result = attr.dyn_cast<AttrType>(); |
951 | if (!result) |
952 | return emitError(loc, "invalid kind of attribute specified"); |
953 | return success(); |
954 | } |
955 | |
956 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
957 | template <typename AttrType> |
958 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
959 | parseCustomAttributeWithFallback(AttrType &result) { |
960 | return parseAttribute(result); |
961 | } |
962 | |
963 | /// Parse an arbitrary optional attribute of a given type and return it in |
964 | /// result. |
965 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
966 | Type type = {}) = 0; |
967 | |
968 | /// Parse an optional array attribute and return it in result. |
969 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
970 | Type type = {}) = 0; |
971 | |
972 | /// Parse an optional string attribute and return it in result. |
973 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
974 | Type type = {}) = 0; |
975 | |
976 | /// Parse an optional symbol ref attribute and return it in result. |
977 | virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result, |
978 | Type type = {}) = 0; |
979 | |
980 | /// Parse an optional attribute of a specific type and add it to the list with |
981 | /// the specified name. |
982 | template <typename AttrType> |
983 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
984 | StringRef attrName, |
985 | NamedAttrList &attrs) { |
986 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
987 | } |
988 | |
989 | /// Parse an optional attribute of a specific type and add it to the list with |
990 | /// the specified name. |
991 | template <typename AttrType> |
992 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
993 | StringRef attrName, |
994 | NamedAttrList &attrs) { |
995 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
996 | if (parseResult.has_value() && succeeded(*parseResult)) |
997 | attrs.append(attrName, result); |
998 | return parseResult; |
999 | } |
1000 | |
1001 | /// Parse a named dictionary into 'result' if it is present. |
1002 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
1003 | |
1004 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
1005 | /// present. |
1006 | virtual ParseResult |
1007 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
1008 | |
1009 | /// Parse an affine map instance into 'map'. |
1010 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
1011 | |
1012 | /// Parse an integer set instance into 'set'. |
1013 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
1014 | |
1015 | //===--------------------------------------------------------------------===// |
1016 | // Identifier Parsing |
1017 | //===--------------------------------------------------------------------===// |
1018 | |
1019 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1020 | /// attribute. |
1021 | ParseResult parseSymbolName(StringAttr &result) { |
1022 | if (failed(parseOptionalSymbolName(result))) |
1023 | return emitError(getCurrentLocation()) |
1024 | << "expected valid '@'-identifier for symbol name"; |
1025 | return success(); |
1026 | } |
1027 | |
1028 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1029 | /// attribute named 'attrName'. |
1030 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1031 | NamedAttrList &attrs) { |
1032 | if (parseSymbolName(result)) |
1033 | return failure(); |
1034 | attrs.append(attrName, result); |
1035 | return success(); |
1036 | } |
1037 | |
1038 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1039 | /// string attribute. |
1040 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; |
1041 | |
1042 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1043 | /// string attribute named 'attrName'. |
1044 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
1045 | NamedAttrList &attrs) { |
1046 | if (succeeded(parseOptionalSymbolName(result))) { |
1047 | attrs.append(attrName, result); |
1048 | return success(); |
1049 | } |
1050 | return failure(); |
1051 | } |
1052 | |
1053 | //===--------------------------------------------------------------------===// |
1054 | // Resource Parsing |
1055 | //===--------------------------------------------------------------------===// |
1056 | |
1057 | /// Parse a handle to a resource within the assembly format. |
1058 | template <typename ResourceT> |
1059 | FailureOr<ResourceT> parseResourceHandle() { |
1060 | SMLoc handleLoc = getCurrentLocation(); |
1061 | |
1062 | // Try to load the dialect that owns the handle. |
1063 | auto *dialect = |
1064 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); |
1065 | if (!dialect) { |
1066 | return emitError(handleLoc) |
1067 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() |
1068 | << "' is unknown"; |
1069 | } |
1070 | |
1071 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); |
1072 | if (failed(handle)) |
1073 | return failure(); |
1074 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
1075 | return std::move(*result); |
1076 | return emitError(handleLoc) << "provided resource handle differs from the " |
1077 | "expected resource type"; |
1078 | } |
1079 | |
1080 | //===--------------------------------------------------------------------===// |
1081 | // Type Parsing |
1082 | //===--------------------------------------------------------------------===// |
1083 | |
1084 | /// Parse a type. |
1085 | virtual ParseResult parseType(Type &result) = 0; |
1086 | |
1087 | /// Parse a custom type with the provided callback, unless the next |
1088 | /// token is `#`, in which case the generic parser is invoked. |
1089 | virtual ParseResult parseCustomTypeWithFallback( |
1090 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
1091 | |
1092 | /// Parse an optional type. |
1093 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
1094 | |
1095 | /// Parse a type of a specific type. |
1096 | template <typename TypeT> |
1097 | ParseResult parseType(TypeT &result) { |
1098 | SMLoc loc = getCurrentLocation(); |
1099 | |
1100 | // Parse any kind of type. |
1101 | Type type; |
1102 | if (parseType(type)) |
1103 | return failure(); |
1104 | |
1105 | // Check for the right kind of type. |
1106 | result = type.dyn_cast<TypeT>(); |
1107 | if (!result) |
1108 | return emitError(loc, "invalid kind of type specified"); |
1109 | |
1110 | return success(); |
1111 | } |
1112 | |
1113 | /// Trait to check if `TypeT` provides a `parse` method. |
1114 | template <typename TypeT> |
1115 | using type_has_parse_method = |
1116 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
1117 | template <typename TypeT> |
1118 | using detect_type_has_parse_method = |
1119 | llvm::is_detected<type_has_parse_method, TypeT>; |
1120 | |
1121 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1122 | /// which case the generic parser is invoked. The parsed Type is |
1123 | /// populated in `result`. |
1124 | template <typename TypeT> |
1125 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1126 | parseCustomTypeWithFallback(TypeT &result) { |
1127 | SMLoc loc = getCurrentLocation(); |
1128 | |
1129 | // Parse any kind of Type. |
1130 | Type type; |
1131 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1132 | result = TypeT::parse(*this); |
1133 | return success(!!result); |
1134 | })) |
1135 | return failure(); |
1136 | |
1137 | // Check for the right kind of Type. |
1138 | result = type.dyn_cast<TypeT>(); |
1139 | if (!result) |
1140 | return emitError(loc, "invalid kind of Type specified"); |
1141 | return success(); |
1142 | } |
1143 | |
1144 | /// SFINAE parsing method for Type that don't implement a parse method. |
1145 | template <typename TypeT> |
1146 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
1147 | parseCustomTypeWithFallback(TypeT &result) { |
1148 | return parseType(result); |
1149 | } |
1150 | |
1151 | /// Parse a type list. |
1152 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
1153 | return parseCommaSeparatedList( |
1154 | [&]() { return parseType(result.emplace_back()); }); |
1155 | } |
1156 | |
1157 | /// Parse an arrow followed by a type list. |
1158 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1159 | |
1160 | /// Parse an optional arrow followed by a type list. |
1161 | virtual ParseResult |
1162 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1163 | |
1164 | /// Parse a colon followed by a type. |
1165 | virtual ParseResult parseColonType(Type &result) = 0; |
1166 | |
1167 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1168 | template <typename TypeType> |
1169 | ParseResult parseColonType(TypeType &result) { |
1170 | SMLoc loc = getCurrentLocation(); |
1171 | |
1172 | // Parse any kind of type. |
1173 | Type type; |
1174 | if (parseColonType(type)) |
1175 | return failure(); |
1176 | |
1177 | // Check for the right kind of type. |
1178 | result = type.dyn_cast<TypeType>(); |
1179 | if (!result) |
1180 | return emitError(loc, "invalid kind of type specified"); |
1181 | |
1182 | return success(); |
1183 | } |
1184 | |
1185 | /// Parse a colon followed by a type list, which must have at least one type. |
1186 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1187 | |
1188 | /// Parse an optional colon followed by a type list, which if present must |
1189 | /// have at least one type. |
1190 | virtual ParseResult |
1191 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1192 | |
1193 | /// Parse a keyword followed by a type. |
1194 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1195 | return failure(parseKeyword(keyword) || parseType(result)); |
1196 | } |
1197 | |
1198 | /// Add the specified type to the end of the specified type list and return |
1199 | /// success. This is a helper designed to allow parse methods to be simple |
1200 | /// and chain through || operators. |
1201 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
1202 | result.push_back(type); |
1203 | return success(); |
1204 | } |
1205 | |
1206 | /// Add the specified types to the end of the specified type list and return |
1207 | /// success. This is a helper designed to allow parse methods to be simple |
1208 | /// and chain through || operators. |
1209 | ParseResult addTypesToList(ArrayRef<Type> types, |
1210 | SmallVectorImpl<Type> &result) { |
1211 | result.append(types.begin(), types.end()); |
1212 | return success(); |
1213 | } |
1214 | |
1215 | /// Parse a dimension list of a tensor or memref type. This populates the |
1216 | /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set |
1217 | /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable. |
1218 | /// |
1219 | /// dimension-list ::= eps | dimension (`x` dimension)* |
1220 | /// dimension-list-with-trailing-x ::= (dimension `x`)* |
1221 | /// dimension ::= `?` | decimal-literal |
1222 | /// |
1223 | /// When `allowDynamic` is not set, this is used to parse: |
1224 | /// |
1225 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* |
1226 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* |
1227 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1228 | bool allowDynamic = true, |
1229 | bool withTrailingX = true) = 0; |
1230 | |
1231 | /// Parse an 'x' token in a dimension list, handling the case where the x is |
1232 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
1233 | /// next token. |
1234 | virtual ParseResult parseXInDimensionList() = 0; |
1235 | |
1236 | protected: |
1237 | /// Parse a handle to a resource within the assembly format for the given |
1238 | /// dialect. |
1239 | virtual FailureOr<AsmDialectResourceHandle> |
1240 | parseResourceHandle(Dialect *dialect) = 0; |
1241 | |
1242 | //===--------------------------------------------------------------------===// |
1243 | // Code Completion |
1244 | //===--------------------------------------------------------------------===// |
1245 | |
1246 | /// Parse a keyword, or an empty string if the current location signals a code |
1247 | /// completion. |
1248 | virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; |
1249 | |
1250 | /// Signal the code completion of a set of expected tokens. |
1251 | virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0; |
1252 | |
1253 | private: |
1254 | AsmParser(const AsmParser &) = delete; |
1255 | void operator=(const AsmParser &) = delete; |
1256 | }; |
1257 | |
1258 | //===----------------------------------------------------------------------===// |
1259 | // OpAsmParser |
1260 | //===----------------------------------------------------------------------===// |
1261 | |
1262 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1263 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1264 | /// that is designed to reduce/constrain syntax innovation in individual |
1265 | /// operations. |
1266 | /// |
1267 | /// For example, consider an op like this: |
1268 | /// |
1269 | /// %x = load %p[%1, %2] : memref<...> |
1270 | /// |
1271 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
1272 | /// custom op parser. This can be supported by calling `parseOperandList` to |
1273 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to |
1274 | /// parse the indices, then calling `parseColonTypeList` to parse the result |
1275 | /// type. |
1276 | /// |
1277 | class OpAsmParser : public AsmParser { |
1278 | public: |
1279 | using AsmParser::AsmParser; |
1280 | ~OpAsmParser() override; |
1281 | |
1282 | /// Parse a loc(...) specifier if present, filling in result if so. |
1283 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1284 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1285 | /// completes. |
1286 | virtual ParseResult |
1287 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1288 | |
1289 | /// Return the name of the specified result in the specified syntax, as well |
1290 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1291 | /// invalid result numbers. For example, in this operation: |
1292 | /// |
1293 | /// %x, %y:2, %z = foo.op |
1294 | /// |
1295 | /// getResultName(0) == {"x", 0 } |
1296 | /// getResultName(1) == {"y", 0 } |
1297 | /// getResultName(2) == {"y", 1 } |
1298 | /// getResultName(3) == {"z", 0 } |
1299 | /// getResultName(4) == {"", ~0U } |
1300 | virtual std::pair<StringRef, unsigned> |
1301 | getResultName(unsigned resultNo) const = 0; |
1302 | |
1303 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1304 | /// example in the comment for `getResultName`. |
1305 | virtual size_t getNumResults() const = 0; |
1306 | |
1307 | // These methods emit an error and return failure or success. This allows |
1308 | // these to be chained together into a linear sequence of || expressions in |
1309 | // many cases. |
1310 | |
1311 | /// Parse an operation in its generic form. |
1312 | /// The parsed operation is parsed in the current context and inserted in the |
1313 | /// provided block and insertion point. The results produced by this operation |
1314 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1315 | /// failure. |
1316 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1317 | Block::iterator insertPt) = 0; |
1318 | |
1319 | /// Parse the name of an operation, in the custom form. On success, return |
1320 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1321 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1322 | |
1323 | //===--------------------------------------------------------------------===// |
1324 | // Operand Parsing |
1325 | //===--------------------------------------------------------------------===// |
1326 | |
1327 | /// This is the representation of an operand reference. |
1328 | struct UnresolvedOperand { |
1329 | SMLoc location; // Location of the token. |
1330 | StringRef name; // Value name, e.g. %42 or %abc |
1331 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1332 | }; |
1333 | |
1334 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1335 | /// region(s), attribute(s) and function-type, of the generic form of an |
1336 | /// operation instance and populate the input operation-state 'result' with |
1337 | /// those components. If any of the components is explicitly provided, then |
1338 | /// skip parsing that component. |
1339 | virtual ParseResult parseGenericOperationAfterOpName( |
1340 | OperationState &result, |
1341 | Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt, |
1342 | Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt, |
1343 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1344 | std::nullopt, |
1345 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt, |
1346 | Optional<FunctionType> parsedFnType = std::nullopt) = 0; |
1347 | |
1348 | /// Parse a single SSA value operand name along with a result number if |
1349 | /// `allowResultNumber` is true. |
1350 | virtual ParseResult parseOperand(UnresolvedOperand &result, |
1351 | bool allowResultNumber = true) = 0; |
1352 | |
1353 | /// Parse a single operand if present. |
1354 | virtual OptionalParseResult |
1355 | parseOptionalOperand(UnresolvedOperand &result, |
1356 | bool allowResultNumber = true) = 0; |
1357 | |
1358 | /// Parse zero or more SSA comma-separated operand references with a specified |
1359 | /// surrounding delimiter, and an optional required operand count. |
1360 | virtual ParseResult |
1361 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1362 | Delimiter delimiter = Delimiter::None, |
1363 | bool allowResultNumber = true, |
1364 | int requiredOperandCount = -1) = 0; |
1365 | |
1366 | /// Parse a specified number of comma separated operands. |
1367 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1368 | int requiredOperandCount, |
1369 | Delimiter delimiter = Delimiter::None) { |
1370 | return parseOperandList(result, delimiter, |
1371 | /*allowResultNumber=*/true, requiredOperandCount); |
1372 | } |
1373 | |
1374 | /// Parse zero or more trailing SSA comma-separated operand |
1375 | /// references with a specified surrounding delimiter, and an optional |
1376 | /// required operand count. A leading comma is expected before the |
1377 | /// operands. |
1378 | ParseResult |
1379 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1380 | Delimiter delimiter = Delimiter::None) { |
1381 | if (failed(parseOptionalComma())) |
1382 | return success(); // The comma is optional. |
1383 | return parseOperandList(result, delimiter); |
1384 | } |
1385 | |
1386 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1387 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1388 | Type type, |
1389 | SmallVectorImpl<Value> &result) = 0; |
1390 | |
1391 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1392 | /// appending the results to the list on success. This method should be used |
1393 | /// when all operands have the same type. |
1394 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1395 | ParseResult resolveOperands(Operands &&operands, Type type, |
1396 | SmallVectorImpl<Value> &result) { |
1397 | for (const UnresolvedOperand &operand : operands) |
1398 | if (resolveOperand(operand, type, result)) |
1399 | return failure(); |
1400 | return success(); |
1401 | } |
1402 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1403 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1404 | SmallVectorImpl<Value> &result) { |
1405 | return resolveOperands(std::forward<Operands>(operands), type, result); |
1406 | } |
1407 | |
1408 | /// Resolve a list of operands and a list of operand types to SSA values, |
1409 | /// emitting an error and returning failure, or appending the results |
1410 | /// to the list on success. |
1411 | template <typename Operands = ArrayRef<UnresolvedOperand>, |
1412 | typename Types = ArrayRef<Type>> |
1413 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1414 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1415 | SmallVectorImpl<Value> &result) { |
1416 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1417 | size_t typeSize = std::distance(types.begin(), types.end()); |
1418 | if (operandSize != typeSize) |
1419 | return emitError(loc) |
1420 | << operandSize << " operands present, but expected " << typeSize; |
1421 | |
1422 | for (auto [operand, type] : llvm::zip(operands, types)) |
1423 | if (resolveOperand(operand, type, result)) |
1424 | return failure(); |
1425 | return success(); |
1426 | } |
1427 | |
1428 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1429 | /// Operand values must come from single-result sources, and be valid |
1430 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1431 | virtual ParseResult |
1432 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1433 | Attribute &map, StringRef attrName, |
1434 | NamedAttrList &attrs, |
1435 | Delimiter delimiter = Delimiter::Square) = 0; |
1436 | |
1437 | /// Parses an affine expression where dims and symbols are SSA operands. |
1438 | /// Operand values must come from single-result sources, and be valid |
1439 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1440 | virtual ParseResult |
1441 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1442 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1443 | AffineExpr &expr) = 0; |
1444 | |
1445 | //===--------------------------------------------------------------------===// |
1446 | // Argument Parsing |
1447 | //===--------------------------------------------------------------------===// |
1448 | |
1449 | struct Argument { |
1450 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1451 | Type type; // Type. |
1452 | DictionaryAttr attrs; // Attributes if present. |
1453 | Optional<Location> sourceLoc; // Source location specifier if present. |
1454 | }; |
1455 | |
1456 | /// Parse a single argument with the following syntax: |
1457 | /// |
1458 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` |
1459 | /// |
1460 | /// If `allowType` is false or `allowAttrs` are false then the respective |
1461 | /// parts of the grammar are not parsed. |
1462 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, |
1463 | bool allowAttrs = false) = 0; |
1464 | |
1465 | /// Parse a single argument if present. |
1466 | virtual OptionalParseResult |
1467 | parseOptionalArgument(Argument &result, bool allowType = false, |
1468 | bool allowAttrs = false) = 0; |
1469 | |
1470 | /// Parse zero or more arguments with a specified surrounding delimiter. |
1471 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, |
1472 | Delimiter delimiter = Delimiter::None, |
1473 | bool allowType = false, |
1474 | bool allowAttrs = false) = 0; |
1475 | |
1476 | //===--------------------------------------------------------------------===// |
1477 | // Region Parsing |
1478 | //===--------------------------------------------------------------------===// |
1479 | |
1480 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1481 | /// moved to the op regions after the op is created. The first block of the |
1482 | /// region takes 'arguments'. |
1483 | /// |
1484 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1485 | /// shadow the names of other existing SSA values defined above the region |
1486 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1487 | /// to operations that are 'IsolatedFromAbove'. |
1488 | virtual ParseResult parseRegion(Region &region, |
1489 | ArrayRef<Argument> arguments = {}, |
1490 | bool enableNameShadowing = false) = 0; |
1491 | |
1492 | /// Parses a region if present. |
1493 | virtual OptionalParseResult |
1494 | parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {}, |
1495 | bool enableNameShadowing = false) = 0; |
1496 | |
1497 | /// Parses a region if present. If the region is present, a new region is |
1498 | /// allocated and placed in `region`. If no region is present or on failure, |
1499 | /// `region` remains untouched. |
1500 | virtual OptionalParseResult |
1501 | parseOptionalRegion(std::unique_ptr<Region> &region, |
1502 | ArrayRef<Argument> arguments = {}, |
1503 | bool enableNameShadowing = false) = 0; |
1504 | |
1505 | //===--------------------------------------------------------------------===// |
1506 | // Successor Parsing |
1507 | //===--------------------------------------------------------------------===// |
1508 | |
1509 | /// Parse a single operation successor. |
1510 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1511 | |
1512 | /// Parse an optional operation successor. |
1513 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1514 | |
1515 | /// Parse a single operation successor and its operand list. |
1516 | virtual ParseResult |
1517 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1518 | |
1519 | //===--------------------------------------------------------------------===// |
1520 | // Type Parsing |
1521 | //===--------------------------------------------------------------------===// |
1522 | |
1523 | /// Parse a list of assignments of the form |
1524 | /// (%x1 = %y1, %x2 = %y2, ...) |
1525 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1526 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1527 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1528 | if (!result.has_value()) |
1529 | return emitError(getCurrentLocation(), "expected '('"); |
1530 | return result.value(); |
1531 | } |
1532 | |
1533 | virtual OptionalParseResult |
1534 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, |
1535 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1536 | }; |
1537 | |
1538 | //===--------------------------------------------------------------------===// |
1539 | // Dialect OpAsm interface. |
1540 | //===--------------------------------------------------------------------===// |
1541 | |
1542 | /// A functor used to set the name of the start of a result group of an |
1543 | /// operation. See 'getAsmResultNames' below for more details. |
1544 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1545 | |
1546 | /// A functor used to set the name of blocks in regions directly nested under |
1547 | /// an operation. |
1548 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1549 | |
1550 | class OpAsmDialectInterface |
1551 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1552 | public: |
1553 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1554 | |
1555 | //===------------------------------------------------------------------===// |
1556 | // Aliases |
1557 | //===------------------------------------------------------------------===// |
1558 | |
1559 | /// Holds the result of `getAlias` hook call. |
1560 | enum class AliasResult { |
1561 | /// The object (type or attribute) is not supported by the hook |
1562 | /// and an alias was not provided. |
1563 | NoAlias, |
1564 | /// An alias was provided, but it might be overridden by other hook. |
1565 | OverridableAlias, |
1566 | /// An alias was provided and it should be used |
1567 | /// (no other hooks will be checked). |
1568 | FinalAlias |
1569 | }; |
1570 | |
1571 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1572 | /// not necessarily a part of this dialect. The identifier is used in place of |
1573 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1574 | /// end with a numeric digit([0-9]+). |
1575 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1576 | return AliasResult::NoAlias; |
1577 | } |
1578 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1579 | return AliasResult::NoAlias; |
1580 | } |
1581 | |
1582 | //===--------------------------------------------------------------------===// |
1583 | // Resources |
1584 | //===--------------------------------------------------------------------===// |
1585 | |
1586 | /// Declare a resource with the given key, returning a handle to use for any |
1587 | /// references of this resource key within the IR during parsing. The result |
1588 | /// of `getResourceKey` on the returned handle is permitted to be different |
1589 | /// than `key`. |
1590 | virtual FailureOr<AsmDialectResourceHandle> |
1591 | declareResource(StringRef key) const { |
1592 | return failure(); |
1593 | } |
1594 | |
1595 | /// Return a key to use for the given resource. This key should uniquely |
1596 | /// identify this resource within the dialect. |
1597 | virtual std::string |
1598 | getResourceKey(const AsmDialectResourceHandle &handle) const { |
1599 | llvm_unreachable( |
1600 | "Dialect must implement `getResourceKey` when defining resources"); |
1601 | } |
1602 | |
1603 | /// Hook for parsing resource entries. Returns failure if the entry was not |
1604 | /// valid, or could otherwise not be processed correctly. Any necessary errors |
1605 | /// can be emitted via the provided entry. |
1606 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; |
1607 | |
1608 | /// Hook for building resources to use during printing. The given `op` may be |
1609 | /// inspected to help determine what information to include. |
1610 | /// `referencedResources` contains all of the resources detected when printing |
1611 | /// 'op'. |
1612 | virtual void |
1613 | buildResources(Operation *op, |
1614 | const SetVector<AsmDialectResourceHandle> &referencedResources, |
1615 | AsmResourceBuilder &builder) const {} |
1616 | }; |
1617 | } // namespace mlir |
1618 | |
1619 | //===--------------------------------------------------------------------===// |
1620 | // Operation OpAsm interface. |
1621 | //===--------------------------------------------------------------------===// |
1622 | |
1623 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1624 | #include "mlir/IR/OpAsmInterface.h.inc" |
1625 | |
1626 | namespace llvm { |
1627 | template <> |
1628 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { |
1629 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { |
1630 | return {DenseMapInfo<void *>::getEmptyKey(), |
1631 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; |
1632 | } |
1633 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { |
1634 | return {DenseMapInfo<void *>::getTombstoneKey(), |
1635 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; |
1636 | } |
1637 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { |
1638 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); |
1639 | } |
1640 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, |
1641 | const mlir::AsmDialectResourceHandle &rhs) { |
1642 | return lhs.getResource() == rhs.getResource(); |
1643 | } |
1644 | }; |
1645 | } // namespace llvm |
1646 | |
1647 | #endif |