File: | build/source/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp |
Warning: | line 433, column 41 The left operand of '>=' is a garbage value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- TileUsingInterface.cpp - Tiling using TilingInterface --------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file implements the tiling using TilingInterface. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" | |||
14 | ||||
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h" | |||
16 | #include "mlir/Dialect/Arith/IR/Arith.h" | |||
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" | |||
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" | |||
19 | #include "mlir/Dialect/SCF/Utils/Utils.h" | |||
20 | #include "mlir/Dialect/Tensor/IR/Tensor.h" | |||
21 | #include "mlir/Dialect/Utils/IndexingUtils.h" | |||
22 | #include "mlir/IR/Matchers.h" | |||
23 | #include "mlir/IR/PatternMatch.h" | |||
24 | #include "mlir/Interfaces/DestinationStyleOpInterface.h" | |||
25 | #include "mlir/Interfaces/TilingInterface.h" | |||
26 | #include "llvm/Support/Debug.h" | |||
27 | #include <optional> | |||
28 | ||||
29 | #define DEBUG_TYPE"tile-using-interface" "tile-using-interface" | |||
30 | ||||
31 | using namespace mlir; | |||
32 | ||||
33 | scf::SCFTilingOptions & | |||
34 | scf::SCFTilingOptions::setTileSizes(ArrayRef<int64_t> ts) { | |||
35 | assert(!tileSizeComputationFunction && "tile sizes already set")(static_cast <bool> (!tileSizeComputationFunction && "tile sizes already set") ? void (0) : __assert_fail ("!tileSizeComputationFunction && \"tile sizes already set\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 35 , __extension__ __PRETTY_FUNCTION__)); | |||
36 | SmallVector<int64_t> tileSizes(ts.begin(), ts.end()); | |||
37 | tileSizeComputationFunction = [tileSizes](OpBuilder &b, Operation *op) { | |||
38 | OpBuilder::InsertionGuard guard(b); | |||
39 | b.setInsertionPointToStart( | |||
40 | &op->getParentOfType<func::FuncOp>().getBody().front()); | |||
41 | return llvm::to_vector<4>(map_range(tileSizes, [&](int64_t s) { | |||
42 | Value v = b.create<arith::ConstantIndexOp>(op->getLoc(), s); | |||
43 | return v; | |||
44 | })); | |||
45 | }; | |||
46 | return *this; | |||
47 | } | |||
48 | ||||
49 | /// Helper method to adjust the interchange vector to match the iteration | |||
50 | /// domain. | |||
51 | static SmallVector<int64_t> | |||
52 | fillInterchangeVector(ArrayRef<int64_t> interchangeVector, | |||
53 | size_t iterationDomainSize) { | |||
54 | SmallVector<int64_t> filledVector = llvm::to_vector(interchangeVector); | |||
55 | if (filledVector.size() < iterationDomainSize) { | |||
56 | auto range = llvm::seq<int64_t>(filledVector.size(), iterationDomainSize); | |||
57 | filledVector.append(range.begin(), range.end()); | |||
58 | } | |||
59 | if (filledVector.size() > iterationDomainSize) | |||
60 | filledVector.resize(iterationDomainSize); | |||
61 | return filledVector; | |||
62 | } | |||
63 | ||||
64 | //===----------------------------------------------------------------------===// | |||
65 | // tileUsingSCFForOp implementation. | |||
66 | //===----------------------------------------------------------------------===// | |||
67 | ||||
68 | // Check if `stride` evenly divides the trip count `size - offset`. | |||
69 | static bool tileDividesIterationDomain(Range loopRange) { | |||
70 | std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset); | |||
71 | if (!offsetAsInt) | |||
72 | return false; | |||
73 | std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size); | |||
74 | if (!sizeAsInt) | |||
75 | return false; | |||
76 | std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride); | |||
77 | if (!strideAsInt) | |||
78 | return false; | |||
79 | return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0); | |||
80 | } | |||
81 | ||||
82 | /// Returns the bounded tile size given the current `iv`, `loopRange` and | |||
83 | /// `tileSize`, i.e., `min(tileSize, range.end() - iv)`. | |||
84 | static OpFoldResult getBoundedTileSize(OpBuilder &b, Location loc, | |||
85 | Range loopRange, Value iv, | |||
86 | Value tileSize) { | |||
87 | std::optional<int64_t> ts = getConstantIntValue(tileSize); | |||
88 | if (ts && ts.value() == 1) | |||
89 | return getAsOpFoldResult(tileSize); | |||
90 | ||||
91 | if (tileDividesIterationDomain( | |||
92 | Range{loopRange.offset, loopRange.size, tileSize})) | |||
93 | return tileSize; | |||
94 | ||||
95 | // The tile size to use (to avoid out of bounds access) is minimum of | |||
96 | // `tileSize` and `ub - iv`, where `iv` is the induction variable of the tiled | |||
97 | // loop. | |||
98 | AffineExpr s0, s1, d0; | |||
99 | bindDims(b.getContext(), d0); | |||
100 | bindSymbols(b.getContext(), s0, s1); | |||
101 | AffineMap minMap = AffineMap::get(1, 2, {s0, s1 - d0}, b.getContext()); | |||
102 | Value size = getValueOrCreateConstantIndexOp(b, loc, loopRange.size); | |||
103 | return makeComposedFoldedAffineMin( | |||
104 | b, loc, minMap, SmallVector<OpFoldResult>{iv, tileSize, size}); | |||
105 | } | |||
106 | ||||
107 | /// Generate an empty loop nest that represents the tiled loop nest shell. | |||
108 | /// - `loopRanges` specifies the lb, ub and step of the untiled iteration space. | |||
109 | /// - `tileSizeVals` is the tile sizes to use. Zero represent untiled loops. | |||
110 | /// - In `offsets` and `sizes` return the multi-dimensional offset and size of | |||
111 | /// the | |||
112 | /// tile processed within the inner most loop. | |||
113 | static SmallVector<scf::ForOp> | |||
114 | generateTileLoopNest(OpBuilder &builder, Location loc, | |||
115 | ArrayRef<Range> loopRanges, ArrayRef<Value> tileSizeVals, | |||
116 | SmallVector<OpFoldResult> &offsets, | |||
117 | SmallVector<OpFoldResult> &sizes) { | |||
118 | assert(!loopRanges.empty() && "expected at least one loop range")(static_cast <bool> (!loopRanges.empty() && "expected at least one loop range" ) ? void (0) : __assert_fail ("!loopRanges.empty() && \"expected at least one loop range\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 118 , __extension__ __PRETTY_FUNCTION__)); | |||
119 | assert(loopRanges.size() == tileSizeVals.size() &&(static_cast <bool> (loopRanges.size() == tileSizeVals. size() && "expected as many tile sizes as loop ranges" ) ? void (0) : __assert_fail ("loopRanges.size() == tileSizeVals.size() && \"expected as many tile sizes as loop ranges\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 120 , __extension__ __PRETTY_FUNCTION__)) | |||
120 | "expected as many tile sizes as loop ranges")(static_cast <bool> (loopRanges.size() == tileSizeVals. size() && "expected as many tile sizes as loop ranges" ) ? void (0) : __assert_fail ("loopRanges.size() == tileSizeVals.size() && \"expected as many tile sizes as loop ranges\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 120 , __extension__ __PRETTY_FUNCTION__)); | |||
121 | OpBuilder::InsertionGuard guard(builder); | |||
122 | SmallVector<scf::ForOp> loops; | |||
123 | offsets.resize(loopRanges.size()); | |||
124 | sizes.resize(loopRanges.size()); | |||
125 | ||||
126 | for (auto loopRange : llvm::enumerate(loopRanges)) { | |||
127 | Value offset = | |||
128 | getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().offset); | |||
129 | Value size = | |||
130 | getValueOrCreateConstantIndexOp(builder, loc, loopRange.value().size); | |||
131 | Value tileSize = tileSizeVals[loopRange.index()]; | |||
132 | // No loops if tile size is zero. Set offset and size to the loop | |||
133 | // offset and size. | |||
134 | if (matchPattern(tileSize, m_Zero())) { | |||
135 | offsets[loopRange.index()] = offset; | |||
136 | sizes[loopRange.index()] = size; | |||
137 | continue; | |||
138 | } | |||
139 | ||||
140 | auto loop = builder.create<scf::ForOp>( | |||
141 | loc, offset, size, tileSize, ValueRange{}, | |||
142 | [&](OpBuilder &bodyBuilder, Location bodyLoc, Value iv, | |||
143 | ValueRange /*iterArgs*/) { | |||
144 | sizes[loopRange.index()] = getBoundedTileSize( | |||
145 | bodyBuilder, bodyLoc, loopRange.value(), iv, tileSize); | |||
146 | builder.create<scf::YieldOp>(loc); | |||
147 | }); | |||
148 | offsets[loopRange.index()] = loop.getInductionVar(); | |||
149 | loops.push_back(loop); | |||
150 | builder.setInsertionPoint(loop.getBody()->getTerminator()); | |||
151 | } | |||
152 | return loops; | |||
153 | } | |||
154 | ||||
155 | /// For a value to be yielded (`yieldedValue`) from within a loop nest `loops`, | |||
156 | /// construct the destructive update pattern that inserts the yielded | |||
157 | /// value into a destination tensor provided by `initValue` at offset | |||
158 | /// `tileOffsets` and size `tileSizes`. For example, | |||
159 | /// | |||
160 | /// ```mlir | |||
161 | /// scf.for %iv0 = ... { | |||
162 | /// %0 = tiled_op | |||
163 | /// } | |||
164 | /// ``` | |||
165 | /// | |||
166 | /// is transformed to | |||
167 | /// | |||
168 | /// ```mlir | |||
169 | /// scf.for %iv0 = ... iter_args(%arg = %0) { | |||
170 | /// %1 = tensor.extract_slice %arg | |||
171 | /// %2 = tiled_op | |||
172 | /// %3 = tensor.insert_slice %2 into %arg | |||
173 | /// scf.yield %3 | |||
174 | /// } | |||
175 | /// ``` | |||
176 | /// TODO: This API can be cleaned up by using `SubsetExtractOpInterface`. | |||
177 | static SmallVector<Value> | |||
178 | yieldTiledValues(RewriterBase &rewriter, ValueRange initValues, | |||
179 | ValueRange yieldedValues, | |||
180 | ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, | |||
181 | ArrayRef<SmallVector<OpFoldResult>> tileSizesList, | |||
182 | MutableArrayRef<scf::ForOp> loops) { | |||
183 | NewYieldValueFn yieldValueFn = | |||
184 | [&](OpBuilder &b, Location loc, | |||
185 | ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> { | |||
186 | SmallVector<Value> inserts; | |||
187 | for (const auto &yieldedValue : llvm::enumerate(yieldedValues)) { | |||
188 | ArrayRef<OpFoldResult> tileOffsets = | |||
189 | tileOffsetsList[yieldedValue.index()]; | |||
190 | ArrayRef<OpFoldResult> tileSizes = tileSizesList[yieldedValue.index()]; | |||
191 | SmallVector<OpFoldResult> tileStrides(tileOffsets.size(), | |||
192 | b.getIndexAttr(1)); | |||
193 | Value insert = b.create<tensor::InsertSliceOp>( | |||
194 | loc, yieldedValue.value(), newBBArgs[yieldedValue.index()], | |||
195 | tileOffsets, tileSizes, tileStrides); | |||
196 | inserts.push_back(insert); | |||
197 | } | |||
198 | return inserts; | |||
199 | }; | |||
200 | ||||
201 | SmallVector<scf::ForOp> newLoops = | |||
202 | replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn, | |||
203 | /*replaceIterOperandsUsesInLoop =*/false); | |||
204 | for (const auto &loop : llvm::enumerate(loops)) { | |||
205 | rewriter.eraseOp(loop.value()); | |||
206 | loops[loop.index()] = newLoops[loop.index()]; | |||
207 | } | |||
208 | return llvm::to_vector(llvm::map_range( | |||
209 | loops.front().getResults().take_back(yieldedValues.size()), | |||
210 | [](OpResult r) -> Value { return r; })); | |||
211 | } | |||
212 | ||||
213 | /// If the tiled operation is destination passing style, update the | |||
214 | /// slice of the destination used (which refers to the untiled destination) | |||
215 | /// to use the corresponding region argument of the innermost loop. | |||
216 | /// | |||
217 | /// ```mlir | |||
218 | /// %0 = | |||
219 | /// scf.for %iv0 = ... iter_args(%arg = %0) { | |||
220 | /// %1 = tensor.extract_slice %0 | |||
221 | /// %2 = tiled_op | |||
222 | /// %3 = tensor.insert_slice %2 into %arg | |||
223 | /// scf.yield %3 | |||
224 | /// } | |||
225 | /// ``` | |||
226 | /// | |||
227 | /// is transformed to | |||
228 | /// | |||
229 | /// ```mlir | |||
230 | /// scf.for %iv0 = ... iter_args(%arg = %0) { | |||
231 | /// %1 = tensor.extract_slice %arg | |||
232 | /// %2 = tiled_op | |||
233 | /// %3 = tensor.insert_slice %2 into %arg | |||
234 | /// scf.yield %3 | |||
235 | /// } | |||
236 | /// ``` | |||
237 | static void | |||
238 | updateDestinationOperandsForTiledOp(OpBuilder &builder, | |||
239 | ValueRange tiledOpDestinationValues, | |||
240 | ValueRange bbArgsList) { | |||
241 | for (const auto &destValue : llvm::enumerate(tiledOpDestinationValues)) { | |||
242 | auto sliceOp = destValue.value().getDefiningOp<tensor::ExtractSliceOp>(); | |||
243 | if (!sliceOp) | |||
244 | continue; | |||
245 | sliceOp.setOperand(0, bbArgsList[destValue.index()]); | |||
246 | } | |||
247 | } | |||
248 | ||||
249 | /// Helper method to yield the values of the tiled op, as well as | |||
250 | /// update the destination operands of the tiled op, if it is | |||
251 | /// a destination passing style op. | |||
252 | static SmallVector<Value> | |||
253 | yieldTiledValues(RewriterBase &rewriter, ArrayRef<Value> initValues, | |||
254 | Operation *tiledOp, | |||
255 | ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList, | |||
256 | ArrayRef<SmallVector<OpFoldResult>> tileSizesList, | |||
257 | MutableArrayRef<scf::ForOp> loops) { | |||
258 | SmallVector<Value> replacements = | |||
259 | yieldTiledValues(rewriter, initValues, tiledOp->getResults(), | |||
260 | tileOffsetsList, tileSizesList, loops); | |||
261 | if (auto dstOp = dyn_cast<DestinationStyleOpInterface>(tiledOp)) { | |||
262 | auto innerMostLoop = loops.back(); | |||
263 | SmallVector<Value> tiledOpDestinationTensors = dstOp.getDpsInitOperands(); | |||
264 | updateDestinationOperandsForTiledOp(rewriter, tiledOpDestinationTensors, | |||
265 | innerMostLoop.getRegionIterArgs()); | |||
266 | } | |||
267 | return replacements; | |||
268 | } | |||
269 | ||||
270 | /// Implementation of tiling transformation of `op` that implements the | |||
271 | /// `TilingInterface` using `scf.for` to iterate over the tiles. | |||
272 | FailureOr<scf::SCFTilingResult> | |||
273 | mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, | |||
274 | const scf::SCFTilingOptions &options) { | |||
275 | OpBuilder::InsertionGuard guard(rewriter); | |||
276 | rewriter.setInsertionPointAfter(op); | |||
277 | ||||
278 | if (!options.tileSizeComputationFunction) { | |||
279 | return rewriter.notifyMatchFailure( | |||
280 | op, "missing tile size computation function"); | |||
281 | } | |||
282 | ||||
283 | // 1. Get the range of the loops that are represented by the operation. | |||
284 | SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); | |||
285 | size_t numLoops = iterationDomain.size(); | |||
286 | if (numLoops == 0) { | |||
287 | return rewriter.notifyMatchFailure( | |||
288 | op, "unable to tile op with no iteration domain"); | |||
289 | } | |||
290 | ||||
291 | // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" | |||
292 | // skips tiling a particular dimension. This convention is significantly | |||
293 | // simpler to handle instead of adjusting affine maps to account for missing | |||
294 | // dimensions. | |||
295 | SmallVector<Value> tileSizeVector = | |||
296 | options.tileSizeComputationFunction(rewriter, op); | |||
297 | if (tileSizeVector.size() < iterationDomain.size()) { | |||
298 | auto zero = rewriter.create<arith::ConstantIndexOp>(op.getLoc(), 0); | |||
299 | tileSizeVector.append(numLoops - tileSizeVector.size(), zero); | |||
300 | } | |||
301 | ||||
302 | scf::SCFTilingResult tilingResult; | |||
303 | SmallVector<OpFoldResult> offsets, sizes; | |||
304 | { | |||
305 | // If there is an interchange specified, permute the iteration domain and | |||
306 | // the tile sizes. | |||
307 | SmallVector<int64_t> interchangeVector; | |||
308 | if (!options.interchangeVector.empty()) { | |||
309 | interchangeVector = fillInterchangeVector(options.interchangeVector, | |||
310 | iterationDomain.size()); | |||
311 | } | |||
312 | if (!interchangeVector.empty()) { | |||
313 | if (!isPermutationVector(interchangeVector)) { | |||
314 | return rewriter.notifyMatchFailure( | |||
315 | op, "invalid intechange vector, not a permutation of the entire " | |||
316 | "iteration space"); | |||
317 | } | |||
318 | ||||
319 | applyPermutationToVector(iterationDomain, interchangeVector); | |||
320 | applyPermutationToVector(tileSizeVector, interchangeVector); | |||
321 | } | |||
322 | ||||
323 | // 3. Materialize an empty loop nest that iterates over the tiles. These | |||
324 | // loops for now do not return any values even if the original operation has | |||
325 | // results. | |||
326 | tilingResult.loops = generateTileLoopNest( | |||
327 | rewriter, op.getLoc(), iterationDomain, tileSizeVector, offsets, sizes); | |||
328 | ||||
329 | if (!interchangeVector.empty()) { | |||
330 | auto inversePermutation = invertPermutationVector(interchangeVector); | |||
331 | applyPermutationToVector(offsets, inversePermutation); | |||
332 | applyPermutationToVector(sizes, inversePermutation); | |||
333 | } | |||
334 | } | |||
335 | ||||
336 | LLVM_DEBUG({do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
337 | if (!tilingResult.loops.empty()) {do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
338 | llvm::dbgs() << "LoopNest shell :\n";do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
339 | tilingResult.loops.front().dump();do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
340 | llvm::dbgs() << "\n";do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
341 | }do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
342 | })do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "LoopNest shell :\n"; tilingResult. loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false); | |||
343 | ||||
344 | // 4. Generate the tiled implementation within the inner most loop. | |||
345 | if (!tilingResult.loops.empty()) | |||
346 | rewriter.setInsertionPoint( | |||
347 | tilingResult.loops.back().getBody()->getTerminator()); | |||
348 | SmallVector<Operation *> tiledImplementation = | |||
349 | op.getTiledImplementation(rewriter, offsets, sizes); | |||
350 | tilingResult.tiledOps.append(tiledImplementation); | |||
351 | if (op->getNumResults() == 0) { | |||
352 | // nothing more to do. | |||
353 | return tilingResult; | |||
354 | } | |||
355 | ||||
356 | // If loops are empty, the tiled op is used as the replacement for the untiled | |||
357 | // op. | |||
358 | if (tilingResult.loops.empty()) { | |||
359 | tilingResult.replacements = llvm::to_vector( | |||
360 | llvm::map_range(tiledImplementation[0]->getResults(), | |||
361 | [](OpResult result) -> Value { return result; })); | |||
362 | return tilingResult; | |||
363 | } | |||
364 | ||||
365 | // 5. Yield all the results of the tiled operation. The surrounding loop | |||
366 | // nest is modified to insert a destructive update pattern to yield | |||
367 | // from the loop nest values to replace the untiled op with. | |||
368 | int64_t numResults = op->getNumResults(); | |||
369 | SmallVector<SmallVector<OpFoldResult>> resultOffsetsList(numResults), | |||
370 | resultSizesList(numResults); | |||
371 | for (const auto &result : llvm::enumerate(op->getResults())) { | |||
372 | if (failed(op.getResultTilePosition(rewriter, result.index(), offsets, | |||
373 | sizes, | |||
374 | resultOffsetsList[result.index()], | |||
375 | resultSizesList[result.index()]))) { | |||
376 | return rewriter.notifyMatchFailure( | |||
377 | op, "failed to get slice of result produced"); | |||
378 | } | |||
379 | } | |||
380 | ||||
381 | SmallVector<Value> destinationTensors; | |||
382 | if (failed(tensor::getOrCreateDestinations(rewriter, op.getLoc(), op, | |||
383 | destinationTensors))) | |||
384 | return rewriter.notifyMatchFailure(op, "failed to get destinations"); | |||
385 | ||||
386 | tilingResult.replacements = yieldTiledValues( | |||
387 | rewriter, destinationTensors, tilingResult.tiledOps.back(), | |||
388 | resultOffsetsList, resultSizesList, tilingResult.loops); | |||
389 | ||||
390 | LLVM_DEBUG({do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
391 | if (!tilingResult.loops.empty()) {do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
392 | llvm::dbgs() << "After tiled implementation :\n";do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
393 | tilingResult.loops.front().dump();do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
394 | llvm::dbgs() << "\n";do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
395 | }do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false) | |||
396 | })do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("tile-using-interface")) { { if (!tilingResult.loops.empty() ) { llvm::dbgs() << "After tiled implementation :\n"; tilingResult .loops.front().dump(); llvm::dbgs() << "\n"; } }; } } while (false); | |||
397 | return tilingResult; | |||
398 | } | |||
399 | ||||
400 | FailureOr<scf::SCFReductionTilingResult> | |||
401 | mlir::scf::tileReductionUsingScf(PatternRewriter &b, | |||
402 | PartialReductionOpInterface op, | |||
403 | ArrayRef<OpFoldResult> tileSize) { | |||
404 | Location loc = op.getLoc(); | |||
405 | // Ops implementing PartialReductionOpInterface are expected to implement | |||
406 | // TilingInterface. | |||
407 | auto tilingInterfaceOp = cast<TilingInterface>(op.getOperation()); | |||
408 | SmallVector<Range> iterationDomain = tilingInterfaceOp.getIterationDomain(b); | |||
409 | SmallVector<Value> tileSizeVector = | |||
410 | getValueOrCreateConstantIndexOp(b, loc, tileSize); | |||
411 | if (tileSizeVector.size() < iterationDomain.size()) { | |||
| ||||
412 | auto zero = b.create<arith::ConstantIndexOp>(loc, 0); | |||
413 | tileSizeVector.append(iterationDomain.size() - tileSizeVector.size(), zero); | |||
414 | } | |||
415 | if (op->getNumResults() != 1) | |||
416 | return b.notifyMatchFailure( | |||
417 | op, "don't support ops with multiple results for now"); | |||
418 | SmallVector<utils::IteratorType> iterators = | |||
419 | tilingInterfaceOp.getLoopIteratorTypes(); | |||
420 | int64_t numReductionDims = llvm::count( | |||
421 | tilingInterfaceOp.getLoopIteratorTypes(), utils::IteratorType::reduction); | |||
422 | if (numReductionDims != 1) | |||
423 | return b.notifyMatchFailure( | |||
424 | op, "only support ops with one reduction dimension."); | |||
425 | int reductionDim; | |||
426 | for (auto &[idx, iteratorType] : | |||
427 | llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) { | |||
428 | if (iteratorType == utils::IteratorType::reduction) { | |||
429 | reductionDim = idx; | |||
430 | break; | |||
431 | } | |||
432 | } | |||
433 | if (static_cast<size_t>(reductionDim) >= tileSize.size()) | |||
| ||||
434 | return b.notifyMatchFailure(op, "reduction dimension must be tiled"); | |||
435 | ||||
436 | // 1. create the inital tensor value. | |||
437 | FailureOr<Operation *> identityTensor = | |||
438 | op.generateInitialTensorForPartialReduction(b, loc, tileSize, | |||
439 | reductionDim); | |||
440 | if (failed(identityTensor)) | |||
441 | return b.notifyMatchFailure(op, | |||
442 | "cannot create a tensor of identity value."); | |||
443 | // 2. Create the nested loops. | |||
444 | SmallVector<OpFoldResult> offsets, sizes; | |||
445 | SmallVector<scf::ForOp> loops = generateTileLoopNest( | |||
446 | b, loc, iterationDomain, tileSizeVector, offsets, sizes); | |||
447 | ||||
448 | // 3. Generate the tiled implementation within the inner most loop. | |||
449 | b.setInsertionPoint(loops.back().getBody()->getTerminator()); | |||
450 | Operation *parallelOp = op.tileToPartialReduction( | |||
451 | b, loc, (*identityTensor)->getResults(), offsets, sizes, reductionDim); | |||
452 | ||||
453 | SmallVector<OpFoldResult> resultSizesList; | |||
454 | for (size_t i = 0; i < offsets.size(); i++) | |||
455 | resultSizesList.push_back( | |||
456 | b.createOrFold<tensor::DimOp>(loc, parallelOp->getResult(0), i)); | |||
457 | SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); | |||
458 | SmallVector<Value> replacements = yieldTiledValues( | |||
459 | b, (*identityTensor)->getResults(), parallelOp->getResults(), outOffsets, | |||
460 | resultSizesList, loops); | |||
461 | ||||
462 | auto dstOp = cast<DestinationStyleOpInterface>(parallelOp); | |||
463 | auto innerMostLoop = loops.back(); | |||
464 | SmallVector<Value> destinationTensors = dstOp.getDpsInitOperands(); | |||
465 | assert(destinationTensors.size() ==(static_cast <bool> (destinationTensors.size() == innerMostLoop .getRegionIterArgs().size() && "unexpected number of outputs" ) ? void (0) : __assert_fail ("destinationTensors.size() == innerMostLoop.getRegionIterArgs().size() && \"unexpected number of outputs\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 467 , __extension__ __PRETTY_FUNCTION__)) | |||
466 | innerMostLoop.getRegionIterArgs().size() &&(static_cast <bool> (destinationTensors.size() == innerMostLoop .getRegionIterArgs().size() && "unexpected number of outputs" ) ? void (0) : __assert_fail ("destinationTensors.size() == innerMostLoop.getRegionIterArgs().size() && \"unexpected number of outputs\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 467 , __extension__ __PRETTY_FUNCTION__)) | |||
467 | "unexpected number of outputs")(static_cast <bool> (destinationTensors.size() == innerMostLoop .getRegionIterArgs().size() && "unexpected number of outputs" ) ? void (0) : __assert_fail ("destinationTensors.size() == innerMostLoop.getRegionIterArgs().size() && \"unexpected number of outputs\"" , "mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp", 467 , __extension__ __PRETTY_FUNCTION__)); | |||
468 | updateDestinationOperandsForTiledOp(b, destinationTensors, | |||
469 | innerMostLoop.getRegionIterArgs()); | |||
470 | ||||
471 | // 4. Apply the merge reduction to combine all the partial values. | |||
472 | b.setInsertionPointAfter(*loops.begin()); | |||
473 | Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDim); | |||
474 | b.replaceOp(op, mergeOp->getResults()); | |||
475 | ||||
476 | SCFReductionTilingResult results; | |||
477 | results.initialOp = *identityTensor; | |||
478 | results.loops = std::move(loops); | |||
479 | results.parallelTiledOp = parallelOp; | |||
480 | results.mergeOp = mergeOp; | |||
481 | return results; | |||
482 | } | |||
483 | //===----------------------------------------------------------------------===// | |||
484 | // tileConsumerAndFuseProducerGreedilyUsingSCFForOp implementation. | |||
485 | //===----------------------------------------------------------------------===// | |||
486 | ||||
487 | /// Return the untiled producer whose slice is used in a tiled consumer. The | |||
488 | /// method traverses the tile loop nest (`loops`) if needed, and returns the | |||
489 | /// `iter_args` of the outer most that is encountered. Traversing the iter_args | |||
490 | /// indicates that this is a destination operand of the consumer. If there was | |||
491 | /// no loop traversal needed, the second value of the returned tuple is empty. | |||
492 | static std::tuple<OpResult, std::optional<OpOperand *>> | |||
493 | getUntiledProducerFromSliceSource(OpOperand *source, | |||
494 | ArrayRef<scf::ForOp> loops) { | |||
495 | std::optional<OpOperand *> destinationIterArg; | |||
496 | auto loopIt = loops.rbegin(); | |||
497 | while (auto iterArg = source->get().dyn_cast<BlockArgument>()) { | |||
498 | scf::ForOp loop = *loopIt; | |||
499 | if (iterArg.getOwner()->getParentOp() != loop) | |||
500 | break; | |||
501 | source = &loop.getOpOperandForRegionIterArg(iterArg); | |||
502 | loopIt++; | |||
503 | } | |||
504 | if (loopIt == loops.rend()) | |||
505 | destinationIterArg = source; | |||
506 | return {source->get().dyn_cast<OpResult>(), destinationIterArg}; | |||
507 | } | |||
508 | ||||
509 | /// Implementation of fusing producer of a single slice by computing the | |||
510 | /// slice of the producer in-place. | |||
511 | std::optional<scf::SCFFuseProducerOfSliceResult> | |||
512 | mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter, | |||
513 | tensor::ExtractSliceOp candidateSliceOp, | |||
514 | MutableArrayRef<scf::ForOp> loops) { | |||
515 | // 1. Get the producer of the source (potentially walking through | |||
516 | // `iter_args` of nested `scf.for`) | |||
517 | auto [fusableProducer, destinationIterArg] = | |||
518 | getUntiledProducerFromSliceSource(&candidateSliceOp->getOpOperand(0), | |||
519 | loops); | |||
520 | if (!fusableProducer) | |||
521 | return std::nullopt; | |||
522 | ||||
523 | // 2. Generate the tiled implementation of the producer of the source | |||
524 | OpBuilder::InsertionGuard g(rewriter); | |||
525 | rewriter.setInsertionPoint(candidateSliceOp); | |||
526 | FailureOr<Value> fusedProducerValue = | |||
527 | tensor::replaceExtractSliceWithTiledProducer(rewriter, candidateSliceOp, | |||
528 | fusableProducer); | |||
529 | if (failed(fusedProducerValue)) | |||
530 | return std::nullopt; | |||
531 | rewriter.replaceAllUsesWith(candidateSliceOp, fusedProducerValue.value()); | |||
532 | ||||
533 | // 3. If the slice is for a destination operand, for example, | |||
534 | // | |||
535 | // ```mlir | |||
536 | // %0 = linalg.init | |||
537 | // %1 = linalg.fill .. outs(%0 : ) | |||
538 | // %2 = scf.for .. iter_args(%arg0 = %1) { | |||
539 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { | |||
540 | // %4 = tensor.extract_slice %arg1 [..] | |||
541 | // .. = linalg.matmul .. outs(%4 : ) | |||
542 | // } | |||
543 | // } | |||
544 | // ``` | |||
545 | // | |||
546 | // the IR is currently | |||
547 | // | |||
548 | // ``` | |||
549 | // %0 = linalg.init | |||
550 | // %1 = linalg.fill | |||
551 | // %2 = scf.for .. iter_args(%arg0 = %1 /* incorrect value */ ) { | |||
552 | // %3 = scf.for .. iter_args(%arg1 = %arg0) { | |||
553 | // %4 = tensor.extract_slice %0 /*incorrect value */ [..] | |||
554 | // %5 = linalg.fill .. outs(%4 : ) | |||
555 | // .. = linalg.matmul .. outs(%5 : ) | |||
556 | // } | |||
557 | // } | |||
558 | // ``` | |||
559 | // | |||
560 | // The untiled `linalg.fill` is still used as the `init_value` since it | |||
561 | // was originally a destination operand of the untiled `linalg.matmul`. | |||
562 | // When fusing an operand that is a destination operand. | |||
563 | // - Update the iter_arg of the outer most loop to use the destination | |||
564 | // of the untiled producer. | |||
565 | // - Update the destination of the slice of the tiled producer generated | |||
566 | // to use the same basic block argument as the slice that was used to | |||
567 | // generate inplace the tiled implementation of the producer. | |||
568 | // With this the IR will be. | |||
569 | // | |||
570 | // ``` | |||
571 | // %0 = linalg.init | |||
572 | // %1 = scf.for .. iter_args(%arg0 = %0 /* corrected value */ ) { | |||
573 | // %2 = scf.for .. iter_args(%arg1 = %arg0) { | |||
574 | // %3 = tensor.extract_slice %arg1 /* corrected value */ [..] | |||
575 | // %4 = linalg.fill .. outs(%3 : ) | |||
576 | // .. = linalg.matmul .. outs(%4 : ) | |||
577 | // } | |||
578 | // } | |||
579 | // ``` | |||
580 | // TODO: This can be modeled better if the `DestinationStyleOpInterface`. | |||
581 | // Update to use that when it does become available. | |||
582 | scf::ForOp outerMostLoop = loops.front(); | |||
583 | std::optional<unsigned> iterArgNumber; | |||
584 | if (destinationIterArg) { | |||
585 | iterArgNumber = | |||
586 | outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value()); | |||
587 | } | |||
588 | if (iterArgNumber) { | |||
589 | int64_t resultNumber = fusableProducer.getResultNumber(); | |||
590 | if (auto dstOp = | |||
591 | dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) { | |||
592 | outerMostLoop.setIterArg(iterArgNumber.value(), | |||
593 | dstOp.getTiedOpOperand(fusableProducer)->get()); | |||
594 | } | |||
595 | if (auto dstOp = fusedProducerValue.value() | |||
596 | .getDefiningOp<DestinationStyleOpInterface>()) { | |||
597 | scf::ForOp innerMostLoop = loops.back(); | |||
598 | updateDestinationOperandsForTiledOp( | |||
599 | rewriter, dstOp.getDpsInitOperand(resultNumber)->get(), | |||
600 | innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]); | |||
601 | } | |||
602 | } | |||
603 | return scf::SCFFuseProducerOfSliceResult{fusableProducer, | |||
604 | fusedProducerValue.value()}; | |||
605 | } | |||
606 | ||||
607 | /// Reconstruct the fused producer from within the tiled-and-fused code. | |||
608 | void mlir::scf::yieldReplacementForFusedProducer( | |||
609 | RewriterBase &rewriter, tensor::ExtractSliceOp sliceOp, | |||
610 | scf::SCFFuseProducerOfSliceResult fusedProducerInfo, | |||
611 | MutableArrayRef<scf::ForOp> loops) { | |||
612 | auto [fusableProducer, fusedProducerValue] = fusedProducerInfo; | |||
613 | SmallVector<Value> initValues; | |||
614 | FailureOr<Value> initValue = tensor::getOrCreateDestination( | |||
615 | rewriter, fusableProducer.getOwner()->getLoc(), fusableProducer); | |||
616 | if (succeeded(initValue)) { | |||
617 | SmallVector<OpFoldResult> resultOffsets = sliceOp.getMixedOffsets(); | |||
618 | SmallVector<OpFoldResult> resultSizes = sliceOp.getMixedSizes(); | |||
619 | SmallVector<Value> yieldedVals = | |||
620 | yieldTiledValues(rewriter, initValue.value(), fusedProducerValue, | |||
621 | resultOffsets, resultSizes, loops); | |||
622 | } | |||
623 | if (auto dstStyleProducer = | |||
624 | fusedProducerValue.getDefiningOp<DestinationStyleOpInterface>()) { | |||
625 | Value dstValue = | |||
626 | dstStyleProducer.getDpsInitOperand(fusableProducer.getResultNumber()) | |||
627 | ->get(); | |||
628 | updateDestinationOperandsForTiledOp( | |||
629 | rewriter, dstValue, loops.back().getRegionIterArgs().back()); | |||
630 | } | |||
631 | } | |||
632 | ||||
633 | /// Implementation of tile consumer and fuse producer greedily. | |||
634 | FailureOr<scf::SCFTileAndFuseResult> | |||
635 | mlir::scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( | |||
636 | RewriterBase &rewriter, TilingInterface consumer, | |||
637 | const scf::SCFTileAndFuseOptions &options) { | |||
638 | // This transformation is only valid for ops that return values (i.e. not | |||
639 | // valid to use with operations that have memref operands). | |||
640 | if (!consumer->getNumResults()) { | |||
641 | return rewriter.notifyMatchFailure( | |||
642 | consumer, "invalid pattern for op with no results"); | |||
643 | } | |||
644 | ||||
645 | // 1. First tile the consumer. | |||
646 | scf::SCFTileAndFuseResult tileAndFuseResult; | |||
647 | llvm::SmallDenseMap<Value, int64_t> yieldedValueToResultNumber; | |||
648 | { | |||
649 | FailureOr<scf::SCFTilingResult> tilingResult = | |||
650 | tileUsingSCFForOp(rewriter, consumer, options.tilingOptions); | |||
651 | if (failed(tilingResult)) | |||
652 | return rewriter.notifyMatchFailure(consumer, "failed to tile consumer"); | |||
653 | for (auto *tiledOp : tilingResult->tiledOps) | |||
654 | tileAndFuseResult.tiledAndFusedOps.insert(tiledOp); | |||
655 | tileAndFuseResult.loops = std::move(tilingResult->loops); | |||
656 | for (const auto &result : llvm::enumerate( | |||
657 | llvm::zip(consumer->getResults(), tilingResult->replacements))) { | |||
658 | tileAndFuseResult.replacements[std::get<0>(result.value())] = | |||
659 | std::get<1>(result.value()); | |||
660 | yieldedValueToResultNumber[tilingResult->tiledOps.back()->getResult( | |||
661 | result.index())] = result.index(); | |||
662 | } | |||
663 | } | |||
664 | ||||
665 | // If there are no loops generated, fusion is immaterial. | |||
666 | if (tileAndFuseResult.loops.empty()) | |||
667 | return tileAndFuseResult; | |||
668 | ||||
669 | // 2. Typically, the operands of the tiled operation are slices of the | |||
670 | // operands of the untiled operation. These are expressed in IR using | |||
671 | // `tensor.extract_slice` operations with source being the operands of the | |||
672 | // untiled operation. Create a worklist of these `tensor.extract_slice` | |||
673 | // operations. If the producers of the source of the `tensor.extract_slice` | |||
674 | // can be tiled such that the tiled value is generated in-place, that | |||
675 | // effectively tiles + fuses the operations. | |||
676 | auto addCandidateSlices = [](Operation *fusedOp, | |||
677 | std::deque<tensor::ExtractSliceOp> &candidates) { | |||
678 | for (Value operand : fusedOp->getOperands()) | |||
679 | if (auto sliceOp = operand.getDefiningOp<tensor::ExtractSliceOp>()) | |||
680 | candidates.push_back(sliceOp); | |||
681 | }; | |||
682 | ||||
683 | std::deque<tensor::ExtractSliceOp> candidates; | |||
684 | addCandidateSlices(tileAndFuseResult.tiledAndFusedOps.back(), candidates); | |||
685 | OpBuilder::InsertionGuard g(rewriter); | |||
686 | while (!candidates.empty()) { | |||
687 | // Traverse the slices in BFS fashion. | |||
688 | tensor::ExtractSliceOp candidateSliceOp = candidates.front(); | |||
689 | candidates.pop_front(); | |||
690 | ||||
691 | // The operands of the fused producer might themselved be slices of | |||
692 | // values produced by operations that implement the `TilingInterface`. | |||
693 | // Add these operations to the worklist. | |||
694 | std::optional<scf::SCFFuseProducerOfSliceResult> fusedProducer = | |||
695 | tileAndFuseProducerOfSlice(rewriter, candidateSliceOp, | |||
696 | tileAndFuseResult.loops); | |||
697 | if (!fusedProducer) | |||
698 | continue; | |||
699 | ||||
700 | if (Operation *tiledAndFusedOp = | |||
701 | fusedProducer->tiledAndFusedProducer.getDefiningOp()) { | |||
702 | tileAndFuseResult.tiledAndFusedOps.insert(tiledAndFusedOp); | |||
703 | addCandidateSlices(tiledAndFusedOp, candidates); | |||
704 | } | |||
705 | } | |||
706 | return tileAndFuseResult; | |||
707 | } | |||
708 | ||||
709 | //===----------------------------------------------------------------------===// | |||
710 | // lowerToLoopsUsingSCFForOp implementation. | |||
711 | //===----------------------------------------------------------------------===// | |||
712 | ||||
713 | FailureOr<SmallVector<scf::ForOp>> | |||
714 | mlir::scf::lowerToLoopsUsingSCFForOp(RewriterBase &rewriter, | |||
715 | TilingInterface op) { | |||
716 | // TODO: Handle cases where the op has results if needed. | |||
717 | if (op->getNumResults() > 0) { | |||
718 | return rewriter.notifyMatchFailure( | |||
719 | op, "unable to lower to loops operations with return values"); | |||
720 | } | |||
721 | ||||
722 | SmallVector<Range> domain = op.getIterationDomain(rewriter); | |||
723 | SmallVector<Value> ivs; | |||
724 | SmallVector<scf::ForOp> loops; | |||
725 | Location loc = op.getLoc(); | |||
726 | for (auto loopRange : domain) { | |||
727 | Value offsetVal = | |||
728 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.offset); | |||
729 | Value sizeVal = | |||
730 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.size); | |||
731 | Value strideVal = | |||
732 | getValueOrCreateConstantIndexOp(rewriter, loc, loopRange.stride); | |||
733 | auto loop = rewriter.create<scf::ForOp>(op.getLoc(), offsetVal, sizeVal, | |||
734 | strideVal, ValueRange{}); | |||
735 | loops.push_back(loop); | |||
736 | ivs.push_back(loop.getInductionVar()); | |||
737 | rewriter.setInsertionPoint(loop.getBody()->getTerminator()); | |||
738 | } | |||
739 | if (failed(op.generateScalarImplementation(rewriter, op.getLoc(), ivs))) { | |||
740 | return failure(); | |||
741 | } | |||
742 | return loops; | |||
743 | } |