File: build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
Warning: line 763, column 36: Called C++ object pointer is null
1 | //===- Sparsification.cpp - Implementation of sparsification --------------===//
2 | //
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 | // See https://llvm.org/LICENSE.txt for license information.
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 | //
7 | //===----------------------------------------------------------------------===//
8 | //
9 | // This file implements converting sparse tensor types to actual sparse code.
10 | //
11 | //===----------------------------------------------------------------------===//
12 |
13 | #include "CodegenUtils.h"
14 |
15 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
16 | #include "mlir/Dialect/Arith/IR/Arith.h"
17 | #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
18 | #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
19 | #include "mlir/Dialect/Func/IR/FuncOps.h"
20 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
21 | #include "mlir/Dialect/Linalg/IR/Linalg.h"
22 | #include "mlir/Dialect/Linalg/Utils/Utils.h"
23 | #include "mlir/Dialect/MemRef/IR/MemRef.h"
24 | #include "mlir/Dialect/SCF/IR/SCF.h"
25 | #include "mlir/Dialect/SCF/Transforms/Transforms.h"
26 | #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
27 | #include "mlir/Dialect/SparseTensor/Transforms/Passes.h"
28 | #include "mlir/Dialect/SparseTensor/Utils/Merger.h"
29 | #include "mlir/Dialect/Tensor/IR/Tensor.h"
30 | #include "mlir/Dialect/Vector/IR/VectorOps.h"
31 | #include "mlir/IR/Matchers.h"
32 | #include "mlir/IR/TensorEncoding.h"
33 | #include "llvm/ADT/SmallBitVector.h"
34 |
35 | using namespace mlir;
36 | using namespace mlir::sparse_tensor;
37 |
38 | //===----------------------------------------------------------------------===//
39 | // Declarations of data structures.
40 | //===----------------------------------------------------------------------===//
41 |
42 | namespace {
43 |
44 | // Iteration graph sorting.
45 | enum SortMask {
46 |   kSparseOnly = 0x0,
47 |   kIncludeDense = 0x1,
48 |   kIncludeUndef = 0x2,
49 |   kIncludeAll = 0x3
50 | };
51 |
52 | // Reduction kinds.
53 | enum Reduction { kNoReduc, kSum, kProduct, kAnd, kOr, kXor, kCustom };
54 |
55 | // Code generation.
56 | struct CodeGen {
57 |   CodeGen(SparsificationOptions o, unsigned numTensors, unsigned numLoops,
58 |           OpOperand *op, unsigned nest, std::vector<unsigned> &ts)
59 |       : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
60 |         pointers(numTensors, std::vector<Value>(numLoops)),
61 |         indices(numTensors, std::vector<Value>(numLoops)),
62 |         highs(numTensors, std::vector<Value>(numLoops)),
63 |         pidxs(numTensors, std::vector<Value>(numLoops)),
64 |         idxs(numTensors, std::vector<Value>(numLoops)), sparseOut(op),
65 |         outerParNest(nest), topSort(ts) {}
66 |   /// Sparsification options.
67 |   SparsificationOptions options;
68 |   /// Universal dense indices and upper bounds (by index). The loops array
69 |   /// is updated with the value of the universal dense index in the current
70 |   /// loop. The sizes array is set once with the inferred dimension sizes.
71 |   std::vector<Value> loops;
72 |   std::vector<Value> sizes;
73 |   /// Buffers for storing dense and sparse numerical values (by tensor).
74 |   /// This array is set once during bufferization of all tensors.
75 |   std::vector<Value> buffers;
76 |   /// Sparse storage schemes (1-D): pointers and indices (by tensor and index).
77 |   /// This array is set once during bufferization of all sparse tensors.
78 |   std::vector<std::vector<Value>> pointers;
79 |   std::vector<std::vector<Value>> indices;
80 |   /// Sparse iteration information (by tensor and index). These arrays
81 |   /// are updated to remain current within the current loop.
82 |   std::vector<std::vector<Value>> highs;
83 |   std::vector<std::vector<Value>> pidxs;
84 |   std::vector<std::vector<Value>> idxs;
85 |   /// Current reduction, updated during code generation. When indices of a
86 |   /// reduction are exhausted, all inner loops can use a scalarized reduction.
87 |   unsigned redExp = -1u;
88 |   Value redVal;
89 |   Reduction redKind = kNoReduc;
90 |   unsigned redCustom = -1u;
91 |   // Sparse tensor as output. Implemented either through direct injective
92 |   // insertion in lexicographic index order or through access pattern expansion
93 |   // in the innermost loop nest (`expValues` through `expCount`).
94 |   OpOperand *sparseOut;
95 |   unsigned outerParNest;
96 |   Value expValues;
97 |   Value expFilled;
98 |   Value expAdded;
99 |   Value expCount;
100 |   // Current vector length and mask.
101 |   unsigned curVecLength = 1;
102 |   Value curVecMask;
103 |   // Topsort (reference should remain in scope).
104 |   std::vector<unsigned> &topSort;
105 | };
106 |
107 | } // namespace
108 |
109 | //===----------------------------------------------------------------------===//
110 | // Sparse compiler analysis methods.
111 | //===----------------------------------------------------------------------===//
112 |
113 | /// Helper method to construct a permuted dimension ordering
114 | /// that adheres to the given topological sort.
115 | static AffineMap permute(MLIRContext *context, AffineMap m,
116 |                          std::vector<unsigned> &topSort) {
117 |   unsigned sz = topSort.size();
118 |   assert(m.getNumResults() == sz && "TopoSort/AffineMap size mismatch");
119 |   // Construct the inverse of `m`, to avoid the asymptotic complexity
120 |   // of calling `m.getPermutedPosition` repeatedly.
121 |   SmallVector<unsigned, 4> inv(sz);
122 |   for (unsigned i = 0; i < sz; i++)
123 |     inv[i] = m.getDimPosition(i);
124 |   // Construct the permutation.
125 |   SmallVector<unsigned, 4> perm(sz);
126 |   for (unsigned i = 0; i < sz; i++)
127 |     perm[i] = inv[topSort[i]];
128 |   return AffineMap::getPermutationMap(perm, context);
129 | }
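    | // Editor's note (illustrative trace, not part of the original source):
    | // with m = (d0, d1) -> (d1, d0) and topSort = [1, 0], the inverse is
    | // inv = [1, 0] (inv[i] = m.getDimPosition(i)), so the permutation becomes
    | // perm = [inv[topSort[0]], inv[topSort[1]]] = [inv[1], inv[0]] = [0, 1],
    | // i.e. the identity permutation map over the topologically sorted loops.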
130 |
131 | /// Helper method to obtain the dimension level format from the encoding.
132 | //
133 | // TODO: note that we store, but currently completely *ignore* the properties
134 | //
135 | static DimLevelFormat toDimLevelFormat(const SparseTensorEncodingAttr &enc,
136 |                                        unsigned d) {
137 |   if (enc) {
138 |     switch (enc.getDimLevelType()[d]) {
139 |     case SparseTensorEncodingAttr::DimLevelType::Dense:
140 |       return DimLevelFormat(DimLvlType::kDense);
141 |     case SparseTensorEncodingAttr::DimLevelType::Compressed:
142 |       return DimLevelFormat(DimLvlType::kCompressed);
143 |     case SparseTensorEncodingAttr::DimLevelType::CompressedNu:
144 |       return DimLevelFormat(DimLvlType::kCompressed, true, false);
145 |     case SparseTensorEncodingAttr::DimLevelType::CompressedNo:
146 |       return DimLevelFormat(DimLvlType::kCompressed, false, true);
147 |     case SparseTensorEncodingAttr::DimLevelType::CompressedNuNo:
148 |       return DimLevelFormat(DimLvlType::kCompressed, false, false);
149 |     case SparseTensorEncodingAttr::DimLevelType::Singleton:
150 |       return DimLevelFormat(DimLvlType::kSingleton);
151 |     case SparseTensorEncodingAttr::DimLevelType::SingletonNu:
152 |       return DimLevelFormat(DimLvlType::kSingleton, true, false);
153 |     case SparseTensorEncodingAttr::DimLevelType::SingletonNo:
154 |       return DimLevelFormat(DimLvlType::kSingleton, false, true);
155 |     case SparseTensorEncodingAttr::DimLevelType::SingletonNuNo:
156 |       return DimLevelFormat(DimLvlType::kSingleton, false, false);
157 |     }
158 |   }
159 |   return DimLevelFormat(DimLvlType::kDense);
160 | }
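    | // Editor's note: in the level-type names above, "Nu" marks a level whose
    | // stored indices are not unique and "No" marks a level whose indices are
    | // not ordered; the two boolean arguments to DimLevelFormat appear to encode
    | // exactly these ordered/unique properties, which (per the TODO) are stored
    | // but currently ignored by the sparsifier.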
161 |
162 | /// Helper method to inspect affine expressions. Rejects cases where the
163 | /// same index is used more than once. Also rejects compound affine
164 | /// expressions in sparse dimensions.
165 | static bool findAffine(Merger &merger, unsigned tensor, AffineExpr a,
166 |                        DimLevelFormat dim) {
167 |   switch (a.getKind()) {
168 |   case AffineExprKind::DimId: {
169 |     unsigned idx = a.cast<AffineDimExpr>().getPosition();
170 |     if (!merger.isDimLevelType(tensor, idx, DimLvlType::kUndef))
171 |       return false; // used more than once
172 |     merger.setDimLevelFormat(tensor, idx, dim);
173 |     return true;
174 |   }
175 |   case AffineExprKind::Add:
176 |   case AffineExprKind::Mul: {
177 |     if (dim.levelType != DimLvlType::kDense)
178 |       return false; // compound only in dense dim
179 |     auto binOp = a.cast<AffineBinaryOpExpr>();
180 |     return findAffine(merger, tensor, binOp.getLHS(), dim) &&
181 |            findAffine(merger, tensor, binOp.getRHS(), dim);
182 |   }
183 |   case AffineExprKind::Constant:
184 |     return dim.levelType == DimLvlType::kDense; // const only in dense dim
185 |   default:
186 |     return false;
187 |   }
188 | }
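    | // Editor's note (illustrative): a subscript such as A(i, j) is admissible
    | // for any level format, A(i + j, k) is admissible only when the first
    | // dimension of A is dense, and A(i, i) is always rejected, since index i is
    | // visited twice and its dim level type is no longer kUndef on that second
    | // visit.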
189 |
190 | /// Helper method to inspect sparse encodings in the tensor types.
191 | /// Fills the per-dimension sparsity information for all tensors.
192 | /// Returns true if the sparse annotations and affine subscript
193 | /// expressions of all tensors are admissible. Returns false if
194 | /// no annotations are found or inadmissible constructs occur.
195 | static bool findSparseAnnotations(Merger &merger, linalg::GenericOp op) {
196 |   bool annotated = false;
197 |   for (OpOperand *t : op.getInputAndOutputOperands()) {
198 |     auto map = op.getMatchingIndexingMap(t);
199 |     auto enc = getSparseTensorEncoding(t->get().getType());
200 |     if (enc)
201 |       annotated = true;
202 |     assert(map.getNumResults() == op.getRank(t));
203 |     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
204 |       unsigned tensor = t->getOperandNumber();
205 |       AffineExpr a = map.getResult(toOrigDim(enc, d));
206 |       if (!findAffine(merger, tensor, a, toDimLevelFormat(enc, d)))
207 |         return false; // inadmissible affine expression
208 |     }
209 |   }
210 |   return annotated;
211 | }
212 |
213 | /// A helper to compute a topological sort. O(n^2) time complexity
214 | /// since we use an adjacency matrix for the graph.
215 | /// The sorted result will put the first reduction iterator at the
216 | /// latest possible index.
217 | static bool topSortOptimal(unsigned n, ArrayRef<StringRef> iteratorTypes,
218 |                            std::vector<unsigned> &topSort,
219 |                            std::vector<unsigned> &inDegree,
220 |                            std::vector<std::vector<bool>> &adjM) {
221 |   std::vector<unsigned> redIt; // reduction iterators with in-degree 0
222 |   std::vector<unsigned> parIt; // parallel iterators with in-degree 0
223 |   for (unsigned i = 0; i < n; i++) {
224 |     if (inDegree[i] == 0) {
225 |       if (linalg::isReductionIterator(iteratorTypes[i]))
226 |         redIt.push_back(i);
227 |       else
228 |         parIt.push_back(i);
229 |     }
230 |   }
231 |
232 |   while (!redIt.empty() || !parIt.empty()) {
233 |     // We always choose a parallel iterator if there is any.
234 |     auto &it = !parIt.empty() ? parIt : redIt;
235 |     auto src = it.back();
236 |     topSort.push_back(src);
237 |     it.pop_back();
238 |     // Update in-degrees, and push 0-degree nodes into the worklist.
239 |     for (unsigned dst = 0; dst < n; dst++)
240 |       if (adjM[src][dst] && --inDegree[dst] == 0) {
241 |         if (linalg::isReductionIterator(iteratorTypes[dst]))
242 |           redIt.push_back(dst);
243 |         else
244 |           parIt.push_back(dst);
245 |       }
246 |   }
247 |   return topSort.size() == n;
248 | }
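    | // Editor's note (illustrative trace): with iterator types
    | // [reduction, parallel] and no ordering edges, both nodes start with
    | // in-degree 0; the parallel worklist is drained first, so the result is
    | // topSort = [1, 0], which pushes the reduction loop innermost, as the
    | // comment above promises.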
249 |
250 | /// Helper method to add all constraints from the indices in one affine
251 | /// expression before all indices in the other affine expression. For
252 | /// example i0+i1 < i2+i3+1 yields i0<i2, i0<i3, i1<i2, and i1<i3.
253 | static void addAffineOrderings(std::vector<std::vector<bool>> &adjM,
254 |                                std::vector<unsigned> &inDegree, AffineExpr a,
255 |                                AffineExpr b, unsigned fidx) {
256 |   switch (a.getKind()) {
257 |   case AffineExprKind::DimId: {
258 |     unsigned idx = a.cast<AffineDimExpr>().getPosition();
259 |     if (b)
260 |       addAffineOrderings(adjM, inDegree, b, AffineExpr(), idx);
261 |     else if (!adjM[fidx][idx]) {
262 |       adjM[fidx][idx] = true;
263 |       inDegree[idx]++;
264 |     }
265 |     break;
266 |   }
267 |   case AffineExprKind::Add:
268 |   case AffineExprKind::Mul: {
269 |     auto binOp = a.cast<AffineBinaryOpExpr>();
270 |     addAffineOrderings(adjM, inDegree, binOp.getLHS(), b, fidx);
271 |     addAffineOrderings(adjM, inDegree, binOp.getRHS(), b, fidx);
272 |     break;
273 |   }
274 |   default:
275 |     break;
276 |   }
277 | }
278 |
279 | /// Computes a topologically sorted iteration graph for the linalg operation.
280 | /// Ensures all tensors are visited in natural index order. This is essential
281 | /// for sparse storage formats since these only support access along fixed
282 | /// dimensions. Even for dense storage formats, however, the natural index
283 | /// order yields innermost unit-stride access with better spatial locality.
284 | static bool computeIterationGraph(Merger &merger, linalg::GenericOp op,
285 |                                   std::vector<unsigned> &topSort, unsigned mask,
286 |                                   OpOperand *skip = nullptr) {
287 |   // Set up an n x n from/to adjacency matrix of the iteration graph
288 |   // for the implicit loop indices i_0 .. i_n-1.
289 |   unsigned n = op.getNumLoops();
290 |   std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));
291 |   std::vector<unsigned> inDegree(n, 0); // in-degree of each node.
292 |   auto iteratorTypes = op.getIteratorTypesArray();
293 |   // Iterate over the indexing maps of every tensor in the tensor expression.
294 |   for (OpOperand *t : op.getInputAndOutputOperands()) {
295 |     // Skip tensor during cycle resolution.
296 |     if (t == skip)
297 |       continue;
298 |     // Get map and encoding.
299 |     auto map = op.getMatchingIndexingMap(t);
300 |     auto enc = getSparseTensorEncoding(t->get().getType());
301 |     assert(map.getNumDims() == n);
302 |     // Skip dense tensor constraints when not requested.
303 |     if (!(mask & SortMask::kIncludeDense) && !enc)
304 |       continue;
305 |     // Each tensor expression and optional dimension ordering (row-major
306 |     // by default) puts an ordering constraint on the loop indices. For
307 |     // example, the tensor expression A_ijk forces the ordering i < j < k
308 |     // on the loop indices if no explicit dimension ordering is given.
309 |     for (unsigned d = 1, rank = map.getNumResults(); d < rank; d++) {
310 |       AffineExpr f = map.getResult(toOrigDim(enc, d - 1));
311 |       AffineExpr t = map.getResult(toOrigDim(enc, d));
312 |       addAffineOrderings(adjM, inDegree, f, t, 0);
313 |     }
314 |     // Push unrelated loops into sparse iteration space, so these
315 |     // will be skipped more often.
316 |     if (mask & SortMask::kIncludeUndef) {
317 |       unsigned tensor = t->getOperandNumber();
318 |       for (unsigned i = 0; i < n; i++)
319 |         if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
320 |             merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
321 |           for (unsigned j = 0; j < n; j++)
322 |             if (merger.isDimLevelType(tensor, j, DimLvlType::kUndef)) {
323 |               adjM[i][j] = true;
324 |               inDegree[j]++;
325 |             }
326 |         } else {
327 |           assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
328 |                  merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
329 |         }
330 |     }
331 |   }
332 |   // Topologically sort the iteration graph to determine loop order.
333 |   // Report failure for a cyclic iteration graph.
334 |   topSort.clear();
335 |   topSort.reserve(n);
336 |   return topSortOptimal(n, iteratorTypes, topSort, inDegree, adjM);
337 | }
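    | // Editor's note (illustrative): for matrix-vector multiplication
    | // x(i) = A(i, j) * b(j) with a row-major sparse A, the subscript A(i, j)
    | // contributes the edge i -> j, so every topological order runs loop i
    | // before loop j, matching the fixed access order of the sparse storage.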
338 |
339 | /// Returns true if tensor materializes uninitialized into the computation.
340 | static bool isMaterializing(Value val) {
341 |   return val.getDefiningOp<linalg::InitTensorOp>() ||
342 |          val.getDefiningOp<bufferization::AllocTensorOp>();
343 | }
344 |
345 | /// Returns true when the tensor expression is admissible for codegen.
346 | /// Since all sparse input tensors are admissible, we just need to check
347 | /// whether the out tensor in the tensor expression codegen is admissible.
348 | /// Sets `sparseOut` to the tensor and `outerParNest` to the outer injective
349 | /// nesting depth when a "truly dynamic" sparse tensor output occurs.
350 | static bool isAdmissibleTensorExp(Merger &merger, linalg::GenericOp op,
351 |                                   std::vector<unsigned> &topSort, unsigned exp,
352 |                                   OpOperand **sparseOut,
353 |                                   unsigned &outerParNest) {
354 |   OpOperand *lhs = op.getOutputOperand(0);
355 |   unsigned tensor = lhs->getOperandNumber();
356 |   auto enc = getSparseTensorEncoding(lhs->get().getType());
357 |   // A non-annotated output tensor is assumed dense, and becomes a random
358 |   // access n-dim memref. Admissible since insertions cannot occur.
359 |   if (!enc)
360 |     return true;
361 |   // An all-dense annotated "sparse" output tensor becomes a linearized random
362 |   // access 1-dim memref. Also admissible since insertions cannot occur.
363 |   bool allDense = true;
364 |   auto iteratorTypes = op.getIteratorTypesArray();
365 |   unsigned numLoops = iteratorTypes.size();
366 |   for (unsigned i = 0; i < numLoops; i++)
367 |     if (merger.isDimLevelType(tensor, i, DimLvlType::kCompressed) ||
368 |         merger.isDimLevelType(tensor, i, DimLvlType::kSingleton)) {
369 |       allDense = false;
370 |       break;
371 |     } else {
372 |       assert(merger.isDimLevelType(tensor, i, DimLvlType::kDense) ||
373 |              merger.isDimLevelType(tensor, i, DimLvlType::kUndef));
374 |     }
375 |   if (allDense)
376 |     return true;
377 |   // A tensor expression with a sparse output tensor that changes its values
378 |   // but not its nonzero structure, an operation called "simply dynamic" in
379 |   // [Bik96,Ch9], is also admissible without special codegen.
380 |   if (merger.isSingleCondition(tensor, exp))
381 |     return true;
382 |   // Accept "truly dynamic" if the output tensor materializes uninitialized
383 |   // into the computation and insertions occur in lexicographic index order.
384 |   if (isMaterializing(lhs->get())) {
385 |     unsigned nest = 0;
386 |     for (unsigned i = 0; i < numLoops; i++) {
387 |       if (linalg::isReductionIterator(iteratorTypes[topSort[i]]))
388 |         break; // terminate at first reduction
389 |       nest++;
390 |     }
391 |     // Determine admissible dynamic insertion situations:
392 |     // (1) fully injective, since there are no reductions,
393 |     // (2) admissible 1-d expansion in innermost dimension.
394 |     if (nest >= op.getRank(lhs) - 1) {
395 |       *sparseOut = lhs;
396 |       outerParNest = nest;
397 |       return true;
398 |     }
399 |   }
400 |   return false;
401 | }
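    | // Editor's note (illustrative): scaling a sparse tensor in place, as in
    | // x(i) = x(i) * 2.0, changes only stored values and is "simply dynamic",
    | // whereas y(i) = a(i) + b(i) into an empty sparse y requires true
    | // insertions and is accepted only when y materializes uninitialized
    | // (isMaterializing) and insertions occur in lexicographic index order.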
402 |
403 | //===----------------------------------------------------------------------===//
404 | // Sparse compiler synthesis methods (reductions).
405 | //===----------------------------------------------------------------------===//
406 |
407 | /// Maps reduction kind to vector::CombiningKind.
408 | static vector::CombiningKind getCombiningKind(Reduction kind) {
409 |   switch (kind) {
410 |   case kNoReduc:
411 |   case kCustom:
412 |     break;
413 |   case kSum:
414 |     return vector::CombiningKind::ADD;
415 |   case kProduct:
416 |     return vector::CombiningKind::MUL;
417 |   case kAnd:
418 |     return vector::CombiningKind::AND;
419 |   case kOr:
420 |     return vector::CombiningKind::OR;
421 |   case kXor:
422 |     return vector::CombiningKind::XOR;
423 |   }
424 |   llvm_unreachable("unknown reduction kind");
425 | }
426 |
427 | /// Maps operation to reduction.
428 | static Reduction getReduction(Kind kind) {
429 |   switch (kind) {
430 |   case Kind::kAddF:
431 |   case Kind::kAddC:
432 |   case Kind::kAddI:
433 |   case Kind::kSubF:
434 |   case Kind::kSubC:
435 |   case Kind::kSubI:
436 |     return kSum;
437 |   case Kind::kMulF:
438 |   case Kind::kMulC:
439 |   case Kind::kMulI:
440 |     return kProduct;
441 |   case Kind::kAndI:
442 |     return kAnd;
443 |   case Kind::kOrI:
444 |     return kOr;
445 |   case Kind::kXorI:
446 |     return kXor;
447 |   case Kind::kReduce:
448 |     return kCustom;
449 |   default:
450 |     llvm_unreachable("unexpected reduction operator");
451 |   }
452 | }
453 |
454 | /// Generates an initial value for a vector reduction, following the scheme
455 | /// given in Chapter 5 of "The Software Vectorization Handbook", where the
456 | /// initial scalar value is correctly embedded in the vector reduction value,
457 | /// and a straightforward horizontal reduction will complete the operation.
458 | static Value genVectorReducInit(CodeGen &codegen, OpBuilder &builder,
459 |                                 Location loc, VectorType vtp) {
460 |   Value r = codegen.redVal;
461 |   switch (codegen.redKind) {
462 |   case kNoReduc:
463 |   case kCustom:
464 |     break;
465 |   case kSum:
466 |   case kXor:
467 |     // Initialize reduction vector to: | 0 | .. | 0 | r |
468 |     return builder.create<vector::InsertElementOp>(
469 |         loc, r, constantZero(builder, loc, vtp),
470 |         constantIndex(builder, loc, 0));
471 |   case kProduct:
472 |     // Initialize reduction vector to: | 1 | .. | 1 | r |
473 |     return builder.create<vector::InsertElementOp>(
474 |         loc, r, constantOne(builder, loc, vtp), constantIndex(builder, loc, 0));
475 |   case kAnd:
476 |   case kOr:
477 |     // Initialize reduction vector to: | r | .. | r | r |
478 |     return builder.create<vector::BroadcastOp>(loc, vtp, r);
479 |   }
480 |   llvm_unreachable("unknown reduction kind");
481 | }
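    | // Editor's note (illustrative): for a sum with incoming scalar value r and
    | // vector length 4, the initial vector is <r, 0, 0, 0>; the zero lanes are
    | // neutral for addition, so the final horizontal add yields exactly the
    | // accumulated result. For kAnd/kOr, r is broadcast to every lane instead,
    | // since and/or are idempotent (r AND r == r), which again keeps the
    | // horizontal reduction correct.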
482 |
483 | /// Generates final value for a vector reduction.
484 | static Value genVectorReducEnd(CodeGen &codegen, OpBuilder &builder,
485 |                                Location loc, VectorType vtp) {
486 |   vector::CombiningKind kind = getCombiningKind(codegen.redKind);
487 |   return builder.create<vector::ReductionOp>(loc, kind, codegen.redVal);
488 | }
489 |
490 | /// Updates scalarized reduction value.
491 | static void updateReduc(Merger &merger, CodeGen &codegen, Value reduc) {
492 |   assert(codegen.redKind != kNoReduc);
493 |   codegen.redVal = merger.exp(codegen.redExp).val = reduc;
494 | }
495 |
496 | /// Extracts identity from custom reduce.
497 | static Value getCustomRedId(Operation *op) {
498 |   return dyn_cast<sparse_tensor::ReduceOp>(op).getIdentity();
499 | }
500 |
501 | //===----------------------------------------------------------------------===//
502 | // Sparse compiler synthesis methods (statements and expressions).
503 | //===----------------------------------------------------------------------===//
504 |
505 | /// Generates buffer for the output tensor. Note that all sparse kernels
506 | /// assume that when all elements are written to (viz. x(i) = y(i) * z(i)),
507 | /// the output buffer is already initialized to all zeroes and only nonzero
508 | /// values are computed and written out. For updates (viz. x(i) += y(i) * z(i)),
509 | /// only nonzero values are used for the updates and no assumption on the
510 | /// original contents of the output buffer is necessary.
511 | static Value genOutputBuffer(CodeGen &codegen, OpBuilder &builder,
512 |                              linalg::GenericOp op, MemRefType denseTp,
513 |                              ArrayRef<Value> args) {
514 |   Location loc = op.getLoc();
515 |   OpOperand *lhs = op.getOutputOperand(0);
516 |   Value tensor = lhs->get();
517 |   bool isInit = op.isInitTensor(lhs);
518 |   // An output tensor can simply materialize from the buffer of the tensor that
519 |   // appears in the outs() clause. For updates, this has the advantage that only
520 |   // the nonzero values are involved in the computation, keeping the operation
521 |   // O(nnz). In all other cases, we are forced to zero out the buffer to enforce
522 |   // the assumption above, which may negatively impact running complexity
523 |   // (viz. O(n^2 + nnz) vs. O(nnz) for matrices).
524 |   // TODO: use better analysis to avoid zeroing out the buffer?
525 |   Value init = builder.create<bufferization::ToMemrefOp>(loc, denseTp, tensor);
526 |   if (!isInit) {
527 |     Value zero = constantZero(builder, loc, denseTp.getElementType());
528 |     builder.create<linalg::FillOp>(loc, ValueRange{zero}, ValueRange{init});
529 |   }
530 |   return init;
531 | }
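    | // Editor's note (illustrative): for an update x(i) += y(i) * z(i) the
    | // outs() buffer is reused as-is and the kernel stays O(nnz); for
    | // x(i) = y(i) * z(i) on an uninitialized output, the linalg::FillOp above
    | // first zeros the buffer, adding the O(n) (or O(n^2) for matrices)
    | // initialization term mentioned in the complexity remark.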
532 |
533 | /// Local bufferization of all dense and sparse data structures.
534 | static void genBuffers(Merger &merger, CodeGen &codegen, OpBuilder &builder,
535 |                        linalg::GenericOp op) {
536 |   Location loc = op.getLoc();
537 |   assert(op.getNumInputsAndOutputs() == op.getNumInputs() + 1);
538 |   // For every tensor, find lower and upper bound on dimensions, set the
539 |   // same bounds on loop indices, and obtain dense or sparse buffer(s).
540 |   auto dynShape = {ShapedType::kDynamicSize};
541 |   SmallVector<Value, 4> args;
542 |   for (OpOperand *t : op.getInputAndOutputOperands()) {
543 |     unsigned tensor = t->getOperandNumber();
544 |     auto shape = op.getShape(t);
545 |     auto map = op.getMatchingIndexingMap(t);
546 |     auto enc = getSparseTensorEncoding(t->get().getType());
547 |     // Scan all dimensions of current tensor.
548 |     args.clear();
549 |     for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) {
550 |       AffineExpr a = map.getResult(toOrigDim(enc, d));
551 |       if (a.getKind() != AffineExprKind::DimId)
552 |         continue; // compound
553 |       unsigned idx = a.cast<AffineDimExpr>().getPosition();
554 |       // Handle the different storage schemes.
555 |       if (merger.isDimLevelType(tensor, idx, DimLvlType::kCompressed)) {
556 |         // Compressed dimension, fetch pointer and indices.
557 |         auto ptrTp =
558 |             MemRefType::get(dynShape, getPointerOverheadType(builder, enc));
559 |         auto indTp =
560 |             MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
561 |         auto dim = builder.getIndexAttr(d);
562 |         codegen.pointers[tensor][idx] =
563 |             builder.create<ToPointersOp>(loc, ptrTp, t->get(), dim);
564 |         codegen.indices[tensor][idx] =
565 |             builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
566 |       } else if (merger.isDimLevelType(tensor, idx, DimLvlType::kSingleton)) {
567 |         // Singleton dimension, fetch indices.
568 |         auto indTp =
569 |             MemRefType::get(dynShape, getIndexOverheadType(builder, enc));
570 |         auto dim = builder.getIndexAttr(d);
571 |         codegen.indices[tensor][idx] =
572 |             builder.create<ToIndicesOp>(loc, indTp, t->get(), dim);
573 |       } else {
574 |         // Dense dimension, nothing to fetch.
575 |         assert(merger.isDimLevelType(tensor, idx, DimLvlType::kDense));
576 |       }
577 |       // Find upper bound in current dimension.
578 |       unsigned p = toOrigDim(enc, d);
579 |       Value up = linalg::createOrFoldDimOp(builder, loc, t->get(), p);
580 |       if (ShapedType::isDynamic(shape[p]))
581 |         args.push_back(up);
582 |       assert(codegen.highs[tensor][idx] == nullptr);
583 |       codegen.sizes[idx] = codegen.highs[tensor][idx] = up;
584 |     }
585 |     // Perform the required bufferization. Dense inputs materialize
586 |     // from the input tensors. Dense outputs need special handling.
587 |     // Sparse inputs use sparse primitives to obtain the values.
588 |     Type elementType = getElementTypeOrSelf(t->get().getType());
589 |     if (!enc) {
590 |       // Non-annotated dense tensors.
591 |       auto denseTp = MemRefType::get(shape, elementType);
592 |       if (tensor < op.getNumInputs())
593 |         codegen.buffers[tensor] =
594 |             builder.create<bufferization::ToMemrefOp>(loc, denseTp, t->get());
595 |       else
596 |         codegen.buffers[tensor] =
597 |             genOutputBuffer(codegen, builder, op, denseTp, args);
598 |     } else if (t != codegen.sparseOut) {
599 |       // Annotated sparse tensors (not involved in output).
600 |       auto sparseTp = MemRefType::get(dynShape, elementType);
601 |       codegen.buffers[tensor] =
602 |           builder.create<ToValuesOp>(loc, sparseTp, t->get());
603 |     }
604 |   }
605 | }
606 |
607 | /// Constructs vector type.
608 | static VectorType vectorType(CodeGen &codegen, Type etp) {
609 |   unsigned numScalableDims = codegen.options.enableVLAVectorization;
610 |   return VectorType::get(codegen.curVecLength, etp, numScalableDims);
611 | }
612 |
613 | /// Constructs vector type from pointer.
614 | static VectorType vectorType(CodeGen &codegen, Value ptr) {
615 |   return vectorType(codegen, ptr.getType().cast<MemRefType>().getElementType());
616 | }
617 |
618 | /// Constructs vector iteration mask.
619 | static Value genVectorMask(CodeGen &codegen, OpBuilder &builder, Value iv,
620 |                            Value lo, Value hi, Value step) {
621 |   Location loc = iv.getLoc();
622 |   VectorType mtp = vectorType(codegen, builder.getI1Type());
623 |   // Special case if the vector length evenly divides the trip count (for
624 |   // example, "for i = 0, 128, 16"). A constant all-true mask is generated
625 |   // so that all subsequent masked memory operations are immediately folded
626 |   // into unconditional memory operations.
627 |   IntegerAttr loInt, hiInt, stepInt;
628 |   if (matchPattern(lo, m_Constant(&loInt)) &&
629 |       matchPattern(hi, m_Constant(&hiInt)) &&
630 |       matchPattern(step, m_Constant(&stepInt))) {
631 |     if (((hiInt.getInt() - loInt.getInt()) % stepInt.getInt()) == 0)
632 |       return builder.create<vector::BroadcastOp>(
633 |           loc, mtp, constantI1(builder, loc, true));
634 |   }
635 |   // Otherwise, generate a vector mask that avoids overrunning the upperbound
636 |   // during vector execution. Here we rely on subsequent loop optimizations to
637 |   // avoid executing the mask in all iterations, for example, by splitting the
638 |   // loop into an unconditional vector loop and a scalar cleanup loop.
639 |   auto minMap = AffineMap::get(
640 |       /*dimCount=*/2, /*symbolCount=*/1,
641 |       {builder.getAffineSymbolExpr(0),
642 |        builder.getAffineDimExpr(0) - builder.getAffineDimExpr(1)},
643 |       builder.getContext());
644 |   Value end =
645 |       builder.createOrFold<AffineMinOp>(loc, minMap, ValueRange{hi, iv, step});
646 |   return builder.create<vector::CreateMaskOp>(loc, mtp, end);
647 | }
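    | // Editor's note (illustrative): for "for i = 0, 128, 16" the trip count is
    | // an exact multiple of the step, so a constant all-true mask is emitted;
    | // with an upper bound of 100 instead, the iteration starting at i = 96 gets
    | // the mask length min(step, hi - iv) = min(16, 100 - 96) = 4, enabling only
    | // the first four lanes and protecting the loop tail.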
648 |
649 | /// Generates a vectorized load lhs = a[ind[lo:hi]] or lhs = a[lo:hi].
650 | static Value genVectorLoad(CodeGen &codegen, OpBuilder &builder, Value ptr,
651 |                            ArrayRef<Value> args) {
652 |   Location loc = ptr.getLoc();
653 |   VectorType vtp = vectorType(codegen, ptr);
654 |   Value pass = constantZero(builder, loc, vtp);
655 |   if (args.back().getType().isa<VectorType>()) {
656 |     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
657 |     Value indexVec = args.back();
658 |     scalarArgs.back() = constantIndex(builder, loc, 0);
659 |     return builder.create<vector::GatherOp>(loc, vtp, ptr, scalarArgs, indexVec,
660 |                                             codegen.curVecMask, pass);
661 |   }
662 |   return builder.create<vector::MaskedLoadOp>(loc, vtp, ptr, args,
663 |                                               codegen.curVecMask, pass);
664 | }
665 |
666 | /// Generates a vectorized store a[ind[lo:hi]] = rhs or a[lo:hi] = rhs.
667 | static void genVectorStore(CodeGen &codegen, OpBuilder &builder, Value rhs,
668 |                            Value ptr, ArrayRef<Value> args) {
669 |   Location loc = ptr.getLoc();
670 |   if (args.back().getType().isa<VectorType>()) {
671 |     SmallVector<Value, 4> scalarArgs(args.begin(), args.end());
672 |     Value indexVec = args.back();
673 |     scalarArgs.back() = constantIndex(builder, loc, 0);
674 |     builder.create<vector::ScatterOp>(loc, ptr, scalarArgs, indexVec,
675 |                                       codegen.curVecMask, rhs);
676 |     return;
677 |   }
678 |   builder.create<vector::MaskedStoreOp>(loc, ptr, args, codegen.curVecMask,
679 |                                         rhs);
680 | }
681 |
682 | /// Generates a vectorized invariant. Here we rely on subsequent loop
683 | /// optimizations to hoist the invariant broadcast out of the vector loop.
684 | static Value genVectorInvariantValue(CodeGen &codegen, OpBuilder &builder,
685 |                                      Value val) {
686 |   VectorType vtp = vectorType(codegen, val.getType());
687 |   return builder.create<vector::BroadcastOp>(val.getLoc(), vtp, val);
688 | }
689 |
690 | /// Generates an affine expression.
691 | //
692 | // TODO: generalize for sparse tensor subscripts
693 | //
694 | static Value genAffine(CodeGen &codegen, OpBuilder &builder, AffineExpr a,
695 |                        Location loc) {
696 |   switch (a.getKind()) {
697 |   case AffineExprKind::DimId: {
698 |     unsigned idx = a.cast<AffineDimExpr>().getPosition();
699 |     return codegen.loops[idx]; // universal dense index
700 |   }
701 |   case AffineExprKind::Add: {
702 |     auto binOp = a.cast<AffineBinaryOpExpr>();
703 |     return builder.create<arith::AddIOp>(
704 |         loc, genAffine(codegen, builder, binOp.getLHS(), loc),
705 |         genAffine(codegen, builder, binOp.getRHS(), loc));
706 |   }
707 |   case AffineExprKind::Mul: {
708 |     auto binOp = a.cast<AffineBinaryOpExpr>();
709 |     return builder.create<arith::MulIOp>(
710 |         loc, genAffine(codegen, builder, binOp.getLHS(), loc),
711 |         genAffine(codegen, builder, binOp.getRHS(), loc));
712 |   }
713 |   case AffineExprKind::Constant: {
714 |     int64_t c = a.cast<AffineConstantExpr>().getValue();
715 |     return constantIndex(builder, loc, c);
716 |   }
717 |   default:
718 |     llvm_unreachable("unexpected affine subscript");
719 |   }
720 | }
721 |
722 | /// Generates index for load/store on sparse tensor.
723 | static Value genIndex(CodeGen &codegen, linalg::GenericOp op, OpOperand *t) {
724 |   auto map = op.getMatchingIndexingMap(t);
725 |   auto enc = getSparseTensorEncoding(t->get().getType());
726 |   AffineExpr a = map.getResult(toOrigDim(enc, map.getNumResults() - 1));
727 |   assert(a.getKind() == AffineExprKind::DimId);
728 |   unsigned idx = a.cast<AffineDimExpr>().getPosition();
729 |   return codegen.loops[idx];
730 | }
731 |
732 | /// Generates subscript for load/store on a dense or sparse tensor.
733 | static Value genSubscript(CodeGen &codegen, OpBuilder &builder,
734 |                           linalg::GenericOp op, OpOperand *t,
735 |                           SmallVector<Value, 4> &args) {
736 |   unsigned tensor = t->getOperandNumber();
737 |   auto map = op.getMatchingIndexingMap(t);
738 |   auto enc = getSparseTensorEncoding(t->get().getType());
739 |   unsigned rank = map.getNumResults();
740 |   if (enc) {
741 |     // Note that currently, all sparse subscripts are simple.
742 |     // TODO: accept affine too?
743 |     AffineExpr a = map.getResult(toOrigDim(enc, rank - 1));
744 |     assert(a.getKind() == AffineExprKind::DimId);
745 |     unsigned idx = a.cast<AffineDimExpr>().getPosition();
746 |     assert(codegen.pidxs[tensor][idx] != nullptr);
747 |     args.push_back(codegen.pidxs[tensor][idx]); // position index
748 |   } else {
749 |     for (unsigned d = 0; d < rank; d++) {
750 |       AffineExpr a = map.getResult(d);
751 |       args.push_back(genAffine(codegen, builder, a, op.getLoc()));
752 |     }
753 |   }
754 |   return codegen.buffers[tensor];
755 | }
756 |
757 | /// Generates insertion code to implement dynamic tensor load.
758 | static Value genInsertionLoad(CodeGen &codegen, OpBuilder &builder,
759 |                               linalg::GenericOp op, OpOperand *t) {
760 |   Location loc = op.getLoc();
761 |   // Direct lexicographic index order, tensor loads as zero.
762 |   if (!codegen.expValues) {
763 |     Type tp = getElementTypeOrSelf(t->get().getType());
    |                                    ^ warning: Called C++ object pointer is null
764 |     return constantZero(builder, loc, tp);
765 |   }
766 |   // Load from expanded access pattern.
767 |   Value index = genIndex(codegen, op, t);
768 |   return builder.create<memref::LoadOp>(loc, codegen.expValues, index);
769 | }
770 |
771 | /// Generates insertion code to implement dynamic tensor load for reduction.
772 | static Value genInsertionLoadReduce(Merger &merger, CodeGen &codegen,
773 |                                     OpBuilder &builder, linalg::GenericOp op,
774 |                                     OpOperand *t) {
775 |   Location loc = op.getLoc();
776 |   Value identity = getCustomRedId(merger.exp(codegen.redCustom).op);
777 |   // Direct lexicographic index order, tensor loads as identity.
778 |   if (!codegen.expValues) {
779 |     return identity;
780 |   }
781 |   // Load from expanded access pattern if filled, identity otherwise.
782 |   Value index = genIndex(codegen, op, t);
783 |   Value isFilled =
784 |       builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
785 |   Value valAtIndex =
786 |       builder.create<memref::LoadOp>(loc, codegen.expValues, index);
787 |   return builder.create<arith::SelectOp>(loc, isFilled, valAtIndex, identity);
788 | }
789 |
790 | /// Generates insertion code to implement dynamic tensor store.
791 | static void genInsertionStore(CodeGen &codegen, OpBuilder &builder,
792 |                               linalg::GenericOp op, OpOperand *t, Value rhs) {
793 |   Location loc = op.getLoc();
794 |   // Direct insertion in lexicographic index order.
795 |   if (!codegen.expValues) {
796 |     unsigned rank = op.getRank(t);
797 |     SmallVector<Value, 4> indices;
798 |     for (unsigned i = 0; i < rank; i++) {
799 |       assert(codegen.loops[codegen.topSort[i]]);
800 |       indices.push_back(codegen.loops[codegen.topSort[i]]);
801 |     }
802 |     builder.create<InsertOp>(loc, rhs, t->get(), indices);
803 |     return;
804 |   }
805 |   // Generates insertion code along expanded access pattern.
806 |   //   if (!expFilled[i]) then
807 |   //     expFilled[i] = true
808 |   //     expAdded[inserts++] = i
809 |   //   endif
810 |   //   values[i] = rhs
811 |   Value index = genIndex(codegen, op, t);
812 |   Value fval = constantI1(builder, loc, false);
813 |   Value tval = constantI1(builder, loc, true);
814 |   // If statement.
815 |   Value filled = builder.create<memref::LoadOp>(loc, codegen.expFilled, index);
816 |   Value cond = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq,
817 |                                              filled, fval);
818 |   scf::IfOp ifOp = builder.create<scf::IfOp>(loc, builder.getIndexType(), cond,
819 |                                              /*else=*/true);
820 |   // True branch.
821 |   builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
822 |   builder.create<memref::StoreOp>(loc, tval, codegen.expFilled, index);
823 |   builder.create<memref::StoreOp>(loc, index, codegen.expAdded,
824 |                                   codegen.expCount);
825 |   Value one = constantIndex(builder, loc, 1);
826 |   Value add = builder.create<arith::AddIOp>(loc, codegen.expCount, one);
827 |   builder.create<scf::YieldOp>(loc, add);
828 |   // False branch.
829 |   builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
830 |   builder.create<scf::YieldOp>(loc, codegen.expCount);
831 |   builder.setInsertionPointAfter(ifOp);
832 |   // Value assignment.
833 |   codegen.expCount = ifOp.getResult(0);
834 |   builder.create<memref::StoreOp>(loc, rhs, codegen.expValues, index);
835 | }
836 |
837 | /// Generates a load on a dense or sparse tensor.
838 | static Value genTensorLoad(Merger &merger, CodeGen &codegen, OpBuilder &builder,
839 |                            linalg::GenericOp op, unsigned exp) {
840 |   // Test if the load was hoisted to a higher loop nest.
841 |   Value val = merger.exp(exp).val;
842 |   if (val) {
843 |     if (codegen.curVecLength > 1 && !val.getType().isa<VectorType>())
844 |       return genVectorInvariantValue(codegen, builder, val);
845 |     return val;
846 |   }
847 |   // Load during insertion.
848 |   OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor];
849 |   if (t == codegen.sparseOut) {
850 |     if (codegen.redCustom != -1u)
851 |       return genInsertionLoadReduce(merger, codegen, builder, op, t);
852 |     return genInsertionLoad(codegen, builder, op, t);
853 |   }
854 |   // Actual load.
855 |   SmallVector<Value, 4> args;
856 |   Value ptr = genSubscript(codegen, builder, op, t, args);
857 |   if (codegen.curVecLength > 1)
858 |     return genVectorLoad(codegen, builder, ptr, args);
859 |   return builder.create<memref::LoadOp>(op.getLoc(), ptr, args);
860 | }
861 |
862 | /// Generates a store on a dense or sparse tensor.
863 | static void genTensorStore(Merger &merger, CodeGen &codegen, OpBuilder &builder,
864 |                            linalg::GenericOp op, unsigned exp, Value rhs) {
865 |   Location loc = op.getLoc();
866 |   // Test if this is a scalarized reduction.
867 |   if (codegen.redVal) {
868 |     if (codegen.curVecLength > 1)
869 |       rhs = builder.create<arith::SelectOp>(loc, codegen.curVecMask, rhs,
870 |                                             codegen.redVal);
871 |     updateReduc(merger, codegen, rhs);
872 |     return;
873 |   }
874 |   // Store during insertion.
875 |   OpOperand *t = op.getOutputOperand(0);
876 |   if (t == codegen.sparseOut) {
877 |     if (!rhs) {
878 |       // Only unary and binary are allowed to return uninitialized rhs
879 |       // to indicate missing output.
880 |       assert(merger.exp(exp).kind == kUnary || merger.exp(exp).kind == kBinary);
881 |     } else {
882 |       genInsertionStore(codegen, builder, op, t, rhs);
883 |     }
884 |     return;
885 |   }
886 |   // Actual store.
887 |   SmallVector<Value, 4> args;
888 |   Value ptr = genSubscript(codegen, builder, op, t, args);
889 |   if (codegen.curVecLength > 1)
890 |     genVectorStore(codegen, builder, rhs, ptr, args);
891 |   else
892 |     builder.create<memref::StoreOp>(loc, rhs, ptr, args);
893 | }
894 |
895 | /// Generates a pointer/index load from the sparse storage scheme. Narrower
896 | /// data types need to be zero extended before casting the value into the
897 | /// index type used for looping and indexing.
898 | static Value genLoad(CodeGen &codegen, OpBuilder &builder, Location loc,
899 |                      Value ptr, Value s) {
900 |   // See https://llvm.org/docs/GetElementPtr.html for some background on
901 |   // the complications described below.
902 |   if (codegen.curVecLength > 1) {
903 |     // Since the index vector is used in subsequent gather/scatter operations,
904 |     // which effectively define an unsigned pointer + signed index, we must
905 |     // zero extend the vector to an index width. For 8-bit and 16-bit values,
906 |     // a 32-bit index width suffices. For 32-bit values, zero extending the
907 |     // elements into 64-bit loses some performance since the 32-bit indexed
908 |     // gather/scatter is more efficient than the 64-bit index variant (if the
909 |     // negative 32-bit index space is unused, the enableSIMDIndex32 flag can
910 |     // preserve this performance). For 64-bit values, there is no good way
911 |     // to state that the indices are unsigned, which creates the potential for
912 |     // incorrect address calculations in the unlikely case we need such
913 |     // extremely large offsets.
914 |     Type etp = ptr.getType().cast<MemRefType>().getElementType();
915 |     Value vload = genVectorLoad(codegen, builder, ptr, {s});
916 |     if (!etp.isa<IndexType>()) {
917 |       if (etp.getIntOrFloatBitWidth() < 32)
918 |         vload = builder.create<arith::ExtUIOp>(
919 |             loc, vectorType(codegen, builder.getI32Type()), vload);
920 |       else if (etp.getIntOrFloatBitWidth() < 64 &&
921 |                !codegen.options.enableSIMDIndex32)
922 |         vload = builder.create<arith::ExtUIOp>(
923 |             loc, vectorType(codegen, builder.getI64Type()), vload);
924 |     }
925 |     return vload;
926 |   }
927 |   // For the scalar case, we simply zero extend narrower indices into 64-bit
928 |   // values before casting to index without a performance penalty. Here too,
929 |   // however, indices that already are 64-bit, in theory, cannot express the
930 |   // full range as explained above.
931 |   Value load = builder.create<memref::LoadOp>(loc, ptr, s);
932 |   if (!load.getType().isa<IndexType>()) {
933 |     if (load.getType().getIntOrFloatBitWidth() < 64)
934 |       load = builder.create<arith::ExtUIOp>(loc, builder.getI64Type(), load);
935 |     load =
936 |         builder.create<arith::IndexCastOp>(loc, builder.getIndexType(), load);
937 |   }
938 |   return load;
939 | }
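    | // Editor's note (illustrative): the zero extension above matters because a
    | // 16-bit stored index 0xFFFF must become 65535, not -1; a sign extension
    | // (arith::ExtSIOp) would turn large unsigned indices into negative offsets
    | // and corrupt the address computation described in the comments.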
940 |
941 | /// Generates an invariant value.
942 | static Value genInvariantValue(Merger &merger, CodeGen &codegen,
943 |                                OpBuilder &builder, unsigned exp) {
944 |   Value val = merger.exp(exp).val;
945 |   if (codegen.curVecLength > 1)
946 |     return genVectorInvariantValue(codegen, builder, val);
947 |   return val;
948 | }
949 |
950 | /// Generates an address computation "sz * p + i".
951 | static Value genAddress(CodeGen &codegen, OpBuilder &builder, Location loc,
952 |                         Value size, Value p, Value i) {
953 |   Value mul = builder.create<arith::MulIOp>(loc, size, p);
954 |   if (auto vtp = i.getType().dyn_cast<VectorType>()) {
955 |     Value inv =
956 |         builder.create<arith::IndexCastOp>(loc, vtp.getElementType(), mul);
957 |     mul = genVectorInvariantValue(codegen, builder, inv);
958 |   }
959 |   return builder.create<arith::AddIOp>(loc, mul, i);
960 | }
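    | // Editor's note (illustrative): genAddress linearizes an access into the
    | // 1-D values buffer of an all-dense annotated tensor: for position p in the
    | // enclosing dimension and index i with dimension size sz, the address is
    | // sz * p + i, e.g. row 2, column 3 of a 10-column matrix lands at
    | // 2 * 10 + 3 = 23.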
961 |
962 | /// Generates an index value.
963 | static Value genIndexValue(CodeGen &codegen, OpBuilder &builder, unsigned idx,
964 |                            unsigned ldx) {
965 |   Value ival = codegen.loops[idx];
966 |   Type itype = ival.getType();
967 |   // During vectorization, we either encounter:
968 |   // (1) indices already in vector form, as in ... = ind[lo:hi], good to go, or
969 |   // (2) single index, as in ... = i, must convert to [i, i+1, ...] for inner i.
970 |   unsigned vl = codegen.curVecLength;
971 |   if (vl > 1 && !itype.isa<VectorType>()) {
972 |     Location loc = ival.getLoc();
973 |     VectorType vtp = vectorType(codegen, itype);
974 |     ival = builder.create<vector::BroadcastOp>(loc, vtp, ival);
975 |     if (idx == ldx) {
976 |       Value incr;
977 |       if (vtp.isScalable()) {
978 |         Type stepvty = vectorType(codegen, builder.getI64Type());
979 |         Value stepv = builder.create<LLVM::StepVectorOp>(loc, stepvty);
980 |         incr = builder.create<arith::IndexCastOp>(loc, vtp, stepv);
981 |       } else {
982 |         SmallVector<APInt, 4> integers;
983 |         for (unsigned i = 0; i < vl; i++)
984 |           integers.push_back(APInt(/*width=*/64, i));
985 |         auto values = DenseElementsAttr::get(vtp, integers);
986 |         incr = builder.create<arith::ConstantOp>(loc, vtp, values);
987 |       }
988 |       ival = builder.create<arith::AddIOp>(loc, ival, incr);
989 |     }
990 |   }
991 |   return ival;
992 | }
993 | ||||
994 | /// Semi-ring branches are simply inlined by the sparse compiler. Prior | |||
995 | /// analysis has verified that all computations are "local" to the inlined | |||
996 | /// branch or otherwise invariantly defined outside the loop nest, with the | |||
997 | /// exception of index computations, which need to be relinked to actual | |||
998 | /// inlined cloned code. | |||
999 | static Value relinkBranch(CodeGen &codegen, RewriterBase &rewriter, | |||
1000 | Block *block, Value e, unsigned ldx) { | |||
1001 | if (Operation *def = e.getDefiningOp()) { | |||
1002 | if (auto indexOp = dyn_cast<linalg::IndexOp>(def)) | |||
1003 | return genIndexValue(codegen, rewriter, indexOp.getDim(), ldx); | |||
1004 | if (def->getBlock() == block) { | |||
1005 | for (unsigned i = 0, n = def->getNumOperands(); i < n; i++) | |||
1006 | def->setOperand( | |||
1007 | i, relinkBranch(codegen, rewriter, block, def->getOperand(i), ldx)); | |||
1008 | } | |||
1009 | } | |||
1010 | return e; | |||
1011 | } | |||
1012 | ||||
1013 | /// Recursively generates tensor expression. | |||
1014 | static Value genExp(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, | |||
1015 | linalg::GenericOp op, unsigned exp, unsigned ldx) { | |||
1016 | Location loc = op.getLoc(); | |||
1017 | if (exp == -1u) | |||
1018 | return Value(); | |||
1019 | if (merger.exp(exp).kind == Kind::kTensor) | |||
1020 | return genTensorLoad(merger, codegen, rewriter, op, exp); | |||
1021 | if (merger.exp(exp).kind == Kind::kInvariant) | |||
1022 | return genInvariantValue(merger, codegen, rewriter, exp); | |||
1023 | if (merger.exp(exp).kind == Kind::kIndex) | |||
1024 | return genIndexValue(codegen, rewriter, merger.exp(exp).index, ldx); | |||
1025 | ||||
1026 | if (merger.exp(exp).kind == Kind::kReduce) { | |||
1027 | // Make custom reduction identity accessible for expanded access pattern. | |||
1028 | assert(codegen.redCustom == -1u)(static_cast <bool> (codegen.redCustom == -1u) ? void ( 0) : __assert_fail ("codegen.redCustom == -1u", "mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp" , 1028, __extension__ __PRETTY_FUNCTION__)); | |||
1029 | codegen.redCustom = exp; | |||
1030 | } | |||
1031 | ||||
1032 | Value v0 = | |||
1033 | genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e0, ldx); | |||
1034 | Value v1 = | |||
1035 | genExp(merger, codegen, rewriter, op, merger.exp(exp).children.e1, ldx); | |||
1036 | Value ee = merger.buildExp(rewriter, loc, exp, v0, v1); | |||
1037 | if (ee && (merger.exp(exp).kind == Kind::kUnary || | |||
1038 | merger.exp(exp).kind == Kind::kBinary || | |||
1039 | merger.exp(exp).kind == Kind::kBinaryBranch || | |||
1040 | merger.exp(exp).kind == Kind::kReduce)) | |||
1041 | ee = relinkBranch(codegen, rewriter, ee.getParentBlock(), ee, ldx); | |||
1042 | ||||
1043 | if (merger.exp(exp).kind == Kind::kReduce) { | |||
1044 | assert(codegen.redCustom != -1u); | |||
1045 | codegen.redCustom = -1u; | |||
1046 | } | |||
1047 | ||||
1048 | return ee; | |||
1049 | } | |||
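// Schematic recursion for the function above: a tensor expression such
// as x(i) = a(i) * b(i) + c(i) is visited bottom-up,
//   kTensor(a)   kTensor(b)   kTensor(c)
//        \          /            /
//         mul(v0, v1)           /
//              \               /
//               add(v0, v1)
// with genTensorLoad at the leaves and merger.buildExp combining the
// intermediate values at each inner node (names are schematic).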
1050 | ||||
1051 | /// Determines if affine expression is invariant. | |||
1052 | static bool isInvariantAffine(const CodeGen &codegen, AffineExpr a, | |||
1053 | unsigned ldx, bool &atLevel) { | |||
1054 | switch (a.getKind()) { | |||
1055 | case AffineExprKind::DimId: { | |||
1056 | unsigned idx = a.cast<AffineDimExpr>().getPosition(); | |||
1057 | if (idx == ldx) | |||
1058 | atLevel = true; | |||
1059 | return codegen.loops[idx] != nullptr; // no longer in play? | |||
1060 | } | |||
1061 | case AffineExprKind::Add: | |||
1062 | case AffineExprKind::Mul: { | |||
1063 | auto binOp = a.cast<AffineBinaryOpExpr>(); | |||
1064 | return isInvariantAffine(codegen, binOp.getLHS(), ldx, atLevel) && | |||
1065 | isInvariantAffine(codegen, binOp.getRHS(), ldx, atLevel); | |||
1066 | } | |||
1067 | default: | |||
1068 | return true; | |||
1069 | } | |||
1070 | } | |||
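// Example for the invariance test above (a sketch): with loop order
// (i, j) and ldx = j, the subscript expression i + j is invariant once
// codegen.loops[i] and codegen.loops[j] are both set, and atLevel is
// flagged because dimension j participates exactly at this loop level.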
1071 | ||||
1072 | /// Hoists loop invariant tensor loads for which indices have been exhausted. | |||
1073 | static void genInvariants(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1074 | linalg::GenericOp op, unsigned exp, unsigned ldx, | |||
1075 | bool atStart, unsigned last = -1u) { | |||
1076 | if (exp == -1u) | |||
1077 | return; | |||
1078 | if (merger.exp(exp).kind == Kind::kTensor) { | |||
1079 | // Inspect tensor indices. | |||
1080 | bool atLevel = ldx == -1u; | |||
1081 | OpOperand *t = op.getInputAndOutputOperands()[merger.exp(exp).tensor]; | |||
1082 | auto map = op.getMatchingIndexingMap(t); | |||
1083 | auto enc = getSparseTensorEncoding(t->get().getType()); | |||
1084 | for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { | |||
1085 | AffineExpr a = map.getResult(toOrigDim(enc, d)); | |||
1086 | if (!isInvariantAffine(codegen, a, ldx, atLevel)) | |||
1087 | return; // still in play | |||
1088 | } | |||
1089 | // All exhausted at this level (atLevel denotes exactly at this level). | |||
1090 | if (!atLevel) | |||
1091 | return; | |||
1092 | OpOperand *lhs = op.getOutputOperand(0); | |||
1093 | if (lhs == t) { | |||
1094 | // Start or end a scalarized reduction | |||
1095 | if (atStart) { | |||
1096 | Kind kind = merger.exp(last).kind; | |||
1097 | Value load = kind == Kind::kReduce | |||
1098 | ? getCustomRedId(merger.exp(last).op) | |||
1099 | : genTensorLoad(merger, codegen, builder, op, exp); | |||
1100 | codegen.redKind = getReduction(kind); | |||
1101 | codegen.redExp = exp; | |||
1102 | updateReduc(merger, codegen, load); | |||
1103 | } else { | |||
1104 | Value redVal = codegen.redVal; | |||
1105 | updateReduc(merger, codegen, Value()); | |||
1106 | codegen.redExp = -1u; | |||
1107 | codegen.redKind = kNoReduc; | |||
1108 | genTensorStore(merger, codegen, builder, op, exp, redVal); | |||
1109 | } | |||
1110 | } else { | |||
1111 | // Start or end loop invariant hoisting of a tensor load. | |||
1112 | merger.exp(exp).val = | |||
1113 | atStart ? genTensorLoad(merger, codegen, builder, op, exp) : Value(); | |||
1114 | } | |||
1115 | } else if (merger.exp(exp).kind != Kind::kInvariant && | |||
1116 | merger.exp(exp).kind != Kind::kIndex) { | |||
1117 | // Traverse into the binary operations. Note that we only hoist | |||
1118 | // tensor loads, since subsequent MLIR/LLVM passes know how to | |||
1119 | // deal with all other kinds of derived loop invariants. | |||
1120 | unsigned e0 = merger.exp(exp).children.e0; | |||
1121 | unsigned e1 = merger.exp(exp).children.e1; | |||
1122 | genInvariants(merger, codegen, builder, op, e0, ldx, atStart, exp); | |||
1123 | genInvariants(merger, codegen, builder, op, e1, ldx, atStart, exp); | |||
1124 | } | |||
1125 | } | |||
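// Hoisting example for the function above (a sketch): for
// x(j, i) = a(j) * b(j, i) with loop order (j, i), the indices of a(j)
// are exhausted at level j, so its load is generated once at the start
// of the j loop sequence and cached in merger.exp(exp).val, then
// cleared again when the sequence ends (atStart == false).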
1126 | ||||
1127 | /// Generates an expanded access pattern in innermost dimension. | |||
1128 | static void genExpansion(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1129 | linalg::GenericOp op, unsigned at, bool atStart) { | |||
1130 | OpOperand *lhs = codegen.sparseOut; | |||
1131 | if (!lhs || codegen.outerParNest != op.getRank(lhs) - 1 || | |||
1132 | at != codegen.outerParNest) | |||
1133 | return; // not needed at this level | |||
1134 | // Generate start or end of an expanded access pattern. | |||
1135 | Value tensor = lhs->get(); | |||
1136 | Location loc = op.getLoc(); | |||
1137 | if (atStart) { | |||
1138 | auto dynShape = {ShapedType::kDynamicSize}; | |||
1139 | Type etp = tensor.getType().cast<ShapedType>().getElementType(); | |||
1140 | Type t1 = MemRefType::get(dynShape, etp); | |||
1141 | Type t2 = MemRefType::get(dynShape, builder.getI1Type()); | |||
1142 | Type t3 = MemRefType::get(dynShape, builder.getIndexType()); | |||
1143 | Type t4 = builder.getIndexType(); | |||
1144 | auto res = | |||
1145 | builder.create<ExpandOp>(loc, TypeRange({t1, t2, t3, t4}), tensor); | |||
1146 | assert(res.getNumResults() == 4); | |||
1147 | assert(!codegen.expValues); | |||
1148 | codegen.expValues = res.getResult(0); | |||
1149 | codegen.expFilled = res.getResult(1); | |||
1150 | codegen.expAdded = res.getResult(2); | |||
1151 | codegen.expCount = res.getResult(3); | |||
1152 | } else { | |||
1153 | assert(codegen.expValues); | |||
1154 | SmallVector<Value, 4> indices; | |||
1155 | for (unsigned i = 0; i < at; i++) { | |||
1156 | assert(codegen.loops[codegen.topSort[i]]); | |||
1157 | indices.push_back(codegen.loops[codegen.topSort[i]]); | |||
1158 | } | |||
1159 | builder.create<CompressOp>(loc, codegen.expValues, codegen.expFilled, | |||
1160 | codegen.expAdded, codegen.expCount, tensor, | |||
1161 | indices); | |||
1162 | codegen.expValues = codegen.expFilled = codegen.expAdded = | |||
1163 | codegen.expCount = Value(); | |||
1164 | } | |||
1165 | } | |||
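// Schematically (illustrative assembly, exact syntax may differ), the
// expansion above brackets the innermost loop nest with
//   %values, %filled, %added, %count = sparse_tensor.expand %tensor
//   ...
//   sparse_tensor.compress %values, %filled, %added, %count, %tensor[...]
// so that random access into the expanded row is compressed back into
// the sparse output when the nest finishes.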
1166 | ||||
1167 | /// Generates initialization code for the subsequent loop sequence at | |||
1168 | /// current index level. Returns true if the loop sequence needs to | |||
1169 | /// maintain the universal index. | |||
1170 | static bool genInit(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1171 | linalg::GenericOp op, unsigned at, BitVector &inits) { | |||
1172 | std::vector<unsigned> &topSort(codegen.topSort); | |||
1173 | bool needsUniv = false; | |||
1174 | Location loc = op.getLoc(); | |||
1175 | unsigned idx = topSort[at]; | |||
1176 | ||||
1177 | // Initialize sparse positions. | |||
1178 | for (unsigned b = 0, be = inits.size(); b < be; b++) { | |||
1179 | if (!inits[b]) | |||
1180 | continue; | |||
1181 | unsigned tensor = merger.tensor(b); | |||
1182 | assert(idx == merger.index(b)); | |||
1183 | if (merger.isDimLevelType(b, DimLvlType::kCompressed)) { | |||
1184 | // Initialize sparse index that will implement the iteration: | |||
1185 | // for pidx_idx = pointers(pidx_idx-1), pointers(1+pidx_idx-1) | |||
1186 | unsigned pat = at; | |||
1187 | for (; pat != 0; pat--) { | |||
1188 | if (codegen.pidxs[tensor][topSort[pat - 1]]) | |||
1189 | break; | |||
1190 | } | |||
1191 | Value ptr = codegen.pointers[tensor][idx]; | |||
1192 | Value one = constantIndex(builder, loc, 1); | |||
1193 | Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) | |||
1194 | : codegen.pidxs[tensor][topSort[pat - 1]]; | |||
1195 | codegen.pidxs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p0); | |||
1196 | Value p1 = builder.create<arith::AddIOp>(loc, p0, one); | |||
1197 | codegen.highs[tensor][idx] = genLoad(codegen, builder, loc, ptr, p1); | |||
1198 | } else if (merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1199 | // Initialize sparse index that will implement the "iteration": | |||
1200 | // for pidx_idx = pidx_idx-1, 1+pidx_idx-1 | |||
1201 | // We rely on subsequent loop unrolling to get rid of the loop | |||
1202 | // if it is not involved in co-iteration with anything else. | |||
1203 | unsigned pat = at; | |||
1204 | for (; pat != 0; pat--) { | |||
1205 | if (codegen.pidxs[tensor][topSort[pat - 1]]) | |||
1206 | break; | |||
1207 | } | |||
1208 | Value one = constantIndex(builder, loc, 1); | |||
1209 | Value p0 = (pat == 0) ? constantIndex(builder, loc, 0) | |||
1210 | : codegen.pidxs[tensor][topSort[pat - 1]]; | |||
1211 | codegen.pidxs[tensor][idx] = p0; | |||
1212 | codegen.highs[tensor][idx] = builder.create<arith::AddIOp>(loc, p0, one); | |||
1213 | } else { | |||
1214 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1215 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1216 | // Dense index still in play. | |||
1217 | needsUniv = true; | |||
1218 | } | |||
1219 | } | |||
1220 | ||||
1221 | // Initialize the universal dense index. | |||
1222 | codegen.loops[idx] = constantIndex(builder, loc, 0); | |||
1223 | return needsUniv; | |||
1224 | } | |||
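// Concrete illustration for the compressed case above (a sketch for a
// CSR matrix, parent position %p0):
//   %lo = memref.load %pointers[%p0]            // codegen.pidxs[t][idx]
//   %p1 = arith.addi %p0, %c1
//   %hi = memref.load %pointers[%p1]            // codegen.highs[t][idx]
// so the subsequent loop for this index iterates over positions
// [%lo, %hi) of the indices/values arrays.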
1225 | ||||
1226 | /// Returns vectorization strategy. Any implicit inner loop in the Linalg | |||
1227 | /// operation is a candidate. Whether it is actually converted to SIMD code | |||
1228 | /// depends on the requested strategy. | |||
1229 | static bool isVectorFor(CodeGen &codegen, bool isInner, bool isReduction, | |||
1230 | bool isSparse) { | |||
1231 | // Reject vectorization of sparse output, unless innermost is reduction. | |||
1232 | if (codegen.sparseOut && !isReduction) | |||
1233 | return false; | |||
1234 | // Inspect strategy. | |||
1235 | switch (codegen.options.vectorizationStrategy) { | |||
1236 | case SparseVectorizationStrategy::kNone: | |||
1237 | return false; | |||
1238 | case SparseVectorizationStrategy::kDenseInnerLoop: | |||
1239 | return isInner && !isSparse; | |||
1240 | case SparseVectorizationStrategy::kAnyStorageInnerLoop: | |||
1241 | return isInner; | |||
1242 | } | |||
1243 | llvm_unreachable("unexpected vectorization strategy"); | |||
1244 | } | |||
1245 | ||||
1246 | /// Returns parallelization strategy. Any implicit loop in the Linalg operation | |||
1247 | /// that is marked "parallel" is a candidate. Whether it is actually converted | |||
1248 | /// to a parallel operation depends on the requested strategy. | |||
1249 | static bool isParallelFor(CodeGen &codegen, bool isOuter, bool isReduction, | |||
1250 | bool isSparse, bool isVector) { | |||
1251 | // Reject parallelization of sparse output. | |||
1252 | if (codegen.sparseOut) | |||
1253 | return false; | |||
1254 | // Inspect strategy. | |||
1255 | switch (codegen.options.parallelizationStrategy) { | |||
1256 | case SparseParallelizationStrategy::kNone: | |||
1257 | return false; | |||
1258 | case SparseParallelizationStrategy::kDenseOuterLoop: | |||
1259 | return isOuter && !isSparse && !isReduction && !isVector; | |||
1260 | case SparseParallelizationStrategy::kAnyStorageOuterLoop: | |||
1261 | return isOuter && !isReduction && !isVector; | |||
1262 | case SparseParallelizationStrategy::kDenseAnyLoop: | |||
1263 | return !isSparse && !isReduction && !isVector; | |||
1264 | case SparseParallelizationStrategy::kAnyStorageAnyLoop: | |||
1265 | return !isReduction && !isVector; | |||
1266 | } | |||
1267 | llvm_unreachable("unexpected parallelization strategy"); | |||
1268 | } | |||
1269 | ||||
1270 | /// Checks unit stride for dense tensors. The iteration graph may have ignored | |||
1271 | /// dense access patterns in order to avoid cycles (sparse access patterns are | |||
1272 | /// always placed innermost), but that means dense access has become strided. | |||
1273 | /// This prevents effective vectorization. | |||
1274 | static bool denseUnitStrides(Merger &merger, linalg::GenericOp op, | |||
1275 | unsigned idx) { | |||
1276 | for (OpOperand *t : op.getInputAndOutputOperands()) { | |||
1277 | if (!getSparseTensorEncoding(t->get().getType())) { | |||
1278 | auto map = op.getMatchingIndexingMap(t); | |||
1279 | for (unsigned d = 0, rank = map.getNumResults(); d < rank; d++) { | |||
1280 | AffineExpr a = map.getResult(d); | |||
1281 | // Report non-unit stride if innermost index appears at an outer | |||
1282 | // dimension (true non-unit stride) or if the innermost index appears | |||
1283 | // in a compound subscript in the innermost dimension. Even if the | |||
1284 | // latter is unit stride, it does not play well with scatter/gather. | |||
1285 | // TODO: accept unit stride affine innermost like a[i,j+k+1]? | |||
1286 | if (a.isFunctionOfDim(idx) && | |||
1287 | ((d != rank - 1) || (a.getKind() != AffineExprKind::DimId))) | |||
1288 | return false; | |||
1289 | } | |||
1290 | } | |||
1291 | } | |||
1292 | return true; | |||
1293 | } | |||
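// Examples for the stride check above (schematic): with innermost loop
// index i, a dense access a[i][j] is rejected because i occurs at an
// outer dimension (d != rank - 1), and a[j][i + k] is rejected because
// the innermost subscript is a compound expression rather than a plain
// DimId, even though its stride happens to be one.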
1294 | ||||
1295 | /// Generates a for-loop on a single index. | |||
1296 | static Operation *genFor(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1297 | linalg::GenericOp op, bool isOuter, bool isInner, | |||
1298 | unsigned idx, BitVector &indices) { | |||
1299 | unsigned fb = indices.find_first(); | |||
1300 | unsigned tensor = merger.tensor(fb); | |||
1301 | assert(idx == merger.index(fb)); | |||
1302 | auto iteratorTypes = op.getIteratorTypesArray(); | |||
1303 | bool isReduction = linalg::isReductionIterator(iteratorTypes[idx]); | |||
1304 | bool isSparse = merger.isDimLevelType(fb, DimLvlType::kCompressed) || | |||
1305 | merger.isDimLevelType(fb, DimLvlType::kSingleton); | |||
1306 | bool isVector = isVectorFor(codegen, isInner, isReduction, isSparse) && | |||
1307 | denseUnitStrides(merger, op, idx); | |||
1308 | bool isParallel = | |||
1309 | isParallelFor(codegen, isOuter, isReduction, isSparse, isVector); | |||
1310 | ||||
1311 | // Prepare vector length. | |||
1312 | if (isVector) | |||
1313 | codegen.curVecLength = codegen.options.vectorLength; | |||
1314 | ||||
1315 | // Loop bounds and increment. | |||
1316 | Location loc = op.getLoc(); | |||
1317 | Value lo = isSparse ? codegen.pidxs[tensor][idx] : codegen.loops[idx]; | |||
1318 | Value hi = isSparse ? codegen.highs[tensor][idx] : codegen.sizes[idx]; | |||
1319 | Value step = constantIndex(builder, loc, codegen.curVecLength); | |||
1320 | if (isVector && codegen.options.enableVLAVectorization) { | |||
1321 | Value vscale = builder.create<vector::VectorScaleOp>( | |||
1322 | loc, IndexType::get(builder.getContext())); | |||
1323 | step = builder.create<arith::MulIOp>(loc, vscale, step); | |||
1324 | } | |||
1325 | ||||
1326 | // Emit a parallel loop. | |||
1327 | if (isParallel) { | |||
1328 | assert(!isVector); | |||
1329 | scf::ParallelOp parOp = builder.create<scf::ParallelOp>(loc, lo, hi, step); | |||
1330 | if (isSparse) | |||
1331 | codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0]; | |||
1332 | else | |||
1333 | codegen.loops[idx] = parOp.getInductionVars()[0]; | |||
1334 | builder.setInsertionPointToStart(parOp.getBody()); | |||
1335 | return parOp; | |||
1336 | } | |||
1337 | ||||
1338 | // Emit a sequential or vector loop. | |||
1339 | SmallVector<Value, 4> operands; | |||
1340 | if (codegen.redVal) { | |||
1341 | // In a vector loop, bring reduction into SIMD form, if not already. | |||
1342 | if (isVector && !codegen.redVal.getType().isa<VectorType>()) { | |||
1343 | VectorType vtp = vectorType(codegen, codegen.redVal.getType()); | |||
1344 | Value vred = genVectorReducInit(codegen, builder, loc, vtp); | |||
1345 | updateReduc(merger, codegen, vred); | |||
1346 | } | |||
1347 | operands.push_back(codegen.redVal); | |||
1348 | } | |||
1349 | if (codegen.expValues) | |||
1350 | operands.push_back(codegen.expCount); | |||
1351 | scf::ForOp forOp = builder.create<scf::ForOp>(loc, lo, hi, step, operands); | |||
1352 | if (codegen.redVal) | |||
1353 | updateReduc(merger, codegen, forOp.getRegionIterArgs().front()); | |||
1354 | if (codegen.expValues) | |||
1355 | codegen.expCount = forOp.getRegionIterArgs().back(); | |||
1356 | // Assign induction variable to sparse or dense index. | |||
1357 | Value iv = forOp.getInductionVar(); | |||
1358 | if (isSparse) | |||
1359 | codegen.pidxs[tensor][idx] = iv; | |||
1360 | else | |||
1361 | codegen.loops[idx] = iv; | |||
1362 | builder.setInsertionPointToStart(forOp.getBody()); | |||
1363 | // Share vector iteration mask between all subsequent loads/stores. | |||
1364 | if (isVector) | |||
1365 | codegen.curVecMask = genVectorMask(codegen, builder, iv, lo, hi, step); | |||
1366 | return forOp; | |||
1367 | } | |||
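// A sketch of the sequential output of the function above when a scalar
// reduction is live (illustrative IR):
//   %r = scf.for %iv = %lo to %hi step %step iter_args(%red = %r0) {
//     ...
//     scf.yield %red_next
//   }
// The vector variant first widens %r0 via genVectorReducInit and guards
// subsequent loads/stores with the shared mask in codegen.curVecMask.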
1368 | ||||
1369 | /// Emits a while-loop for co-iteration over multiple indices. | |||
1370 | static Operation *genWhile(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1371 | linalg::GenericOp op, unsigned idx, bool needsUniv, | |||
1372 | BitVector &indices) { | |||
1373 | SmallVector<Type, 4> types; | |||
1374 | SmallVector<Value, 4> operands; | |||
1375 | // Construct the while-loop with a parameter for each index. | |||
1376 | Type indexType = builder.getIndexType(); | |||
1377 | for (unsigned b = 0, be = indices.size(); b < be; b++) { | |||
1378 | if (!indices[b]) | |||
1379 | continue; | |||
1380 | if (merger.isDimLevelType(b, DimLvlType::kCompressed) || | |||
1381 | merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1382 | unsigned tensor = merger.tensor(b); | |||
1383 | assert(idx == merger.index(b)); | |||
1384 | types.push_back(indexType); | |||
1385 | operands.push_back(codegen.pidxs[tensor][idx]); | |||
1386 | } else { | |||
1387 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1388 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1389 | } | |||
1390 | } | |||
1391 | if (codegen.redVal) { | |||
1392 | types.push_back(codegen.redVal.getType()); | |||
1393 | operands.push_back(codegen.redVal); | |||
1394 | } | |||
1395 | if (codegen.expValues) { | |||
1396 | types.push_back(indexType); | |||
1397 | operands.push_back(codegen.expCount); | |||
1398 | } | |||
1399 | if (needsUniv) { | |||
1400 | types.push_back(indexType); | |||
1401 | operands.push_back(codegen.loops[idx]); | |||
1402 | } | |||
1403 | assert(types.size() == operands.size()); | |||
1404 | Location loc = op.getLoc(); | |||
1405 | scf::WhileOp whileOp = builder.create<scf::WhileOp>(loc, types, operands); | |||
1406 | ||||
1407 | SmallVector<Location> locs(types.size(), loc); | |||
1408 | Block *before = builder.createBlock(&whileOp.getBefore(), {}, types, locs); | |||
1409 | Block *after = builder.createBlock(&whileOp.getAfter(), {}, types, locs); | |||
1410 | ||||
1411 | // Build the "before" region, which effectively consists | |||
1412 | // of a conjunction of "i < upper" tests on all induction variables. | |||
1413 | builder.setInsertionPointToStart(&whileOp.getBefore().front()); | |||
1414 | Value cond; | |||
1415 | unsigned o = 0; | |||
1416 | for (unsigned b = 0, be = indices.size(); b < be; b++) { | |||
1417 | if (!indices[b]) | |||
1418 | continue; | |||
1419 | if (merger.isDimLevelType(b, DimLvlType::kCompressed) || | |||
1420 | merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1421 | unsigned tensor = merger.tensor(b); | |||
1422 | assert(idx == merger.index(b)); | |||
1423 | Value op1 = before->getArgument(o); | |||
1424 | Value op2 = codegen.highs[tensor][idx]; | |||
1425 | Value opc = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ult, | |||
1426 | op1, op2); | |||
1427 | cond = cond ? builder.create<arith::AndIOp>(loc, cond, opc) : opc; | |||
1428 | codegen.pidxs[tensor][idx] = after->getArgument(o++); | |||
1429 | } else { | |||
1430 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1431 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1432 | } | |||
1433 | } | |||
1434 | if (codegen.redVal) | |||
1435 | updateReduc(merger, codegen, after->getArgument(o++)); | |||
1436 | if (codegen.expValues) | |||
1437 | codegen.expCount = after->getArgument(o++); | |||
1438 | if (needsUniv) | |||
1439 | codegen.loops[idx] = after->getArgument(o++); | |||
1440 | assert(o == operands.size()); | |||
1441 | builder.create<scf::ConditionOp>(loc, cond, before->getArguments()); | |||
1442 | builder.setInsertionPointToStart(&whileOp.getAfter().front()); | |||
1443 | return whileOp; | |||
1444 | } | |||
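// Sketch of the while-loop emitted above for co-iterating two sparse
// tensors a and b (schematic names):
//   %res:2 = scf.while (%pa = %pa0, %pb = %pb0) {
//     %ca = arith.cmpi ult, %pa, %ha
//     %cb = arith.cmpi ult, %pb, %hb
//     %c = arith.andi %ca, %cb
//     scf.condition(%c) %pa, %pb
//   } do { ... }
// i.e. the loop keeps running while every active position is in bounds.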
1445 | ||||
1446 | /// Generates a for-loop or a while-loop, depending on whether it implements | |||
1447 | /// singleton iteration or co-iteration over the given conjunction. | |||
1448 | static Operation *genLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1449 | linalg::GenericOp op, unsigned at, bool needsUniv, | |||
1450 | BitVector &indices) { | |||
1451 | unsigned idx = codegen.topSort[at]; | |||
1452 | if (indices.count() == 1) { | |||
1453 | bool isOuter = at == 0; | |||
1454 | bool isInner = at == codegen.topSort.size() - 1; | |||
1455 | return genFor(merger, codegen, builder, op, isOuter, isInner, idx, indices); | |||
1456 | } | |||
1457 | return genWhile(merger, codegen, builder, op, idx, needsUniv, indices); | |||
1458 | } | |||
1459 | ||||
1460 | /// Generates the local variables for this loop, consisting of the sparse | |||
1461 | /// indices, restored universal dense index, and dense positions. | |||
1462 | static void genLocals(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1463 | linalg::GenericOp op, unsigned at, bool needsUniv, | |||
1464 | BitVector &locals) { | |||
1465 | std::vector<unsigned> &topSort(codegen.topSort); | |||
1466 | Location loc = op.getLoc(); | |||
1467 | unsigned idx = topSort[at]; | |||
1468 | ||||
1469 | // Initialize sparse indices. | |||
1470 | Value min; | |||
1471 | for (unsigned b = 0, be = locals.size(); b < be; b++) { | |||
1472 | if (!locals[b]) | |||
1473 | continue; | |||
1474 | if (merger.isDimLevelType(b, DimLvlType::kCompressed) || | |||
1475 | merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1476 | unsigned tensor = merger.tensor(b); | |||
1477 | assert(idx == merger.index(b)); | |||
1478 | Value ptr = codegen.indices[tensor][idx]; | |||
1479 | Value s = codegen.pidxs[tensor][idx]; | |||
1480 | Value load = genLoad(codegen, builder, loc, ptr, s); | |||
1481 | codegen.idxs[tensor][idx] = load; | |||
1482 | if (!needsUniv) { | |||
1483 | if (min) { | |||
1484 | Value cmp = builder.create<arith::CmpIOp>( | |||
1485 | loc, arith::CmpIPredicate::ult, load, min); | |||
1486 | min = builder.create<arith::SelectOp>(loc, cmp, load, min); | |||
1487 | } else { | |||
1488 | min = load; | |||
1489 | } | |||
1490 | } | |||
1491 | } else { | |||
1492 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1493 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1494 | } | |||
1495 | } | |||
1496 | ||||
1497 | // Merge dense universal index over minimum. | |||
1498 | if (min) { | |||
1499 | assert(!needsUniv); | |||
1500 | codegen.loops[idx] = min; | |||
1501 | } | |||
1502 | ||||
1503 | // Initialize dense positions. Note that we generate dense indices of the | |||
1504 | // output tensor unconditionally, since they may not appear in the lattice, | |||
1505 | // but may be needed for linearized codegen. | |||
1506 | for (unsigned b = 0, be = locals.size(); b < be; b++) { | |||
1507 | if ((locals[b] || merger.isOutTensor(b, idx)) && | |||
1508 | merger.isDimLevelType(b, DimLvlType::kDense)) { | |||
1509 | unsigned tensor = merger.tensor(b); | |||
1510 | assert(idx == merger.index(b)); | |||
1511 | unsigned pat = at; | |||
1512 | for (; pat != 0; pat--) | |||
1513 | if (codegen.pidxs[tensor][topSort[pat - 1]]) | |||
1514 | break; | |||
1515 | Value p = (pat == 0) ? constantIndex(builder, loc, 0) | |||
1516 | : codegen.pidxs[tensor][topSort[pat - 1]]; | |||
1517 | codegen.pidxs[tensor][idx] = genAddress( | |||
1518 | codegen, builder, loc, codegen.sizes[idx], p, codegen.loops[idx]); | |||
1519 | } | |||
1520 | } | |||
1521 | } | |||
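// Continuing the two-tensor co-iteration sketch, the locals generated
// above amount to
//   %ia = memref.load %indices_a[%pa]
//   %ib = memref.load %indices_b[%pb]
//   %c = arith.cmpi ult, %ia, %ib
//   %min = arith.select %c, %ia, %ib        // becomes codegen.loops[idx]
// recovering the current loop index as the minimum of all sparse indices
// (unless the universal index is carried instead).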
1522 | ||||
1523 | /// Generates the induction structure for a while-loop. | |||
1524 | static void genWhileInduction(Merger &merger, CodeGen &codegen, | |||
1525 | OpBuilder &builder, linalg::GenericOp op, | |||
1526 | unsigned idx, bool needsUniv, | |||
1527 | BitVector &induction, scf::WhileOp whileOp) { | |||
1528 | Location loc = op.getLoc(); | |||
1529 | // Finalize each else branch of all if statements. | |||
1530 | if (codegen.redVal || codegen.expValues) { | |||
1531 | while (auto ifOp = dyn_cast_or_null<scf::IfOp>( | |||
1532 | builder.getInsertionBlock()->getParentOp())) { | |||
1533 | unsigned y = 0; | |||
1534 | SmallVector<Value, 4> yields; | |||
1535 | if (codegen.redVal) { | |||
1536 | yields.push_back(codegen.redVal); | |||
1537 | updateReduc(merger, codegen, ifOp.getResult(y++)); | |||
1538 | } | |||
1539 | if (codegen.expValues) { | |||
1540 | yields.push_back(codegen.expCount); | |||
1541 | codegen.expCount = ifOp->getResult(y++); | |||
1542 | } | |||
1543 | assert(y == yields.size()); | |||
1544 | builder.create<scf::YieldOp>(loc, yields); | |||
1545 | builder.setInsertionPointAfter(ifOp); | |||
1546 | } | |||
1547 | } | |||
1548 | builder.setInsertionPointToEnd(&whileOp.getAfter().front()); | |||
1549 | // Finalize the induction. Note that the induction could be performed | |||
1550 | // in the individual if-branches to avoid re-evaluating the conditions. | |||
1551 | // However, that would result in a rather elaborate forest of yield | |||
1552 | // instructions during code generation. Moreover, performing the induction | |||
1553 | // after the if-statements more closely resembles code generated by TACO. | |||
1554 | unsigned o = 0; | |||
1555 | SmallVector<Value, 4> operands; | |||
1556 | Value one = constantIndex(builder, loc, 1); | |||
1557 | for (unsigned b = 0, be = induction.size(); b < be; b++) { | |||
1558 | if (!induction[b]) | |||
1559 | continue; | |||
1560 | if (merger.isDimLevelType(b, DimLvlType::kCompressed) || | |||
1561 | merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1562 | unsigned tensor = merger.tensor(b); | |||
1563 | assert(idx == merger.index(b)); | |||
1564 | Value op1 = codegen.idxs[tensor][idx]; | |||
1565 | Value op2 = codegen.loops[idx]; | |||
1566 | Value op3 = codegen.pidxs[tensor][idx]; | |||
1567 | Value cmp = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, | |||
1568 | op1, op2); | |||
1569 | Value add = builder.create<arith::AddIOp>(loc, op3, one); | |||
1570 | operands.push_back(builder.create<arith::SelectOp>(loc, cmp, add, op3)); | |||
1571 | codegen.pidxs[tensor][idx] = whileOp->getResult(o++); | |||
1572 | } else { | |||
1573 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1574 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1575 | } | |||
1576 | } | |||
1577 | if (codegen.redVal) { | |||
1578 | operands.push_back(codegen.redVal); | |||
1579 | updateReduc(merger, codegen, whileOp->getResult(o++)); | |||
1580 | } | |||
1581 | if (codegen.expValues) { | |||
1582 | operands.push_back(codegen.expCount); | |||
1583 | codegen.expCount = whileOp->getResult(o++); | |||
1584 | } | |||
1585 | if (needsUniv) { | |||
1586 | operands.push_back( | |||
1587 | builder.create<arith::AddIOp>(loc, codegen.loops[idx], one)); | |||
1588 | codegen.loops[idx] = whileOp->getResult(o++); | |||
1589 | } | |||
1590 | assert(o == operands.size()); | |||
1591 | builder.create<scf::YieldOp>(loc, operands); | |||
1592 | builder.setInsertionPointAfter(whileOp); | |||
1593 | } | |||
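// In the same sketch, the induction emitted above advances only the
// positions whose index matched at this iteration:
//   %eq = arith.cmpi eq, %ia, %min
//   %inc = arith.addi %pa, %c1
//   %pa_next = arith.select %eq, %inc, %pa
// with one such compare/add/select triple per active sparse tensor.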
1594 | ||||
1595 | /// Generates the induction structure for a for-loop. | |||
1596 | static void genForInduction(Merger &merger, CodeGen &codegen, | |||
1597 | OpBuilder &builder, linalg::GenericOp op, | |||
1598 | Operation *loop) { | |||
1599 | Location loc = op.getLoc(); | |||
1600 | unsigned o = 0; | |||
1601 | SmallVector<Value, 4> operands; | |||
1602 | if (codegen.redVal) { | |||
1603 | operands.push_back(codegen.redVal); | |||
1604 | updateReduc(merger, codegen, loop->getResult(o++)); | |||
1605 | } | |||
1606 | if (codegen.expValues) { | |||
1607 | operands.push_back(codegen.expCount); | |||
1608 | codegen.expCount = loop->getResult(o++); | |||
1609 | } | |||
1610 | assert(o == operands.size()); | |||
1611 | if (o > 0) | |||
1612 | builder.create<scf::YieldOp>(loc, operands); | |||
1613 | builder.setInsertionPointAfter(loop); | |||
1614 | } | |||
1615 | ||||
1616 | /// Generates a single if-statement within a while-loop. | |||
1617 | static scf::IfOp genIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1618 | linalg::GenericOp op, unsigned idx, | |||
1619 | BitVector &conditions) { | |||
1620 | Location loc = op.getLoc(); | |||
1621 | SmallVector<Type, 4> types; | |||
1622 | Value cond; | |||
1623 | for (unsigned b = 0, be = conditions.size(); b < be; b++) { | |||
1624 | if (!conditions[b]) | |||
1625 | continue; | |||
1626 | unsigned tensor = merger.tensor(b); | |||
1627 | assert(idx == merger.index(b)); | |||
1628 | Value clause; | |||
1629 | if (merger.isDimLevelType(b, DimLvlType::kCompressed) || | |||
1630 | merger.isDimLevelType(b, DimLvlType::kSingleton)) { | |||
1631 | Value op1 = codegen.idxs[tensor][idx]; | |||
1632 | Value op2 = codegen.loops[idx]; | |||
1633 | clause = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, op1, | |||
1634 | op2); | |||
1635 | } else { | |||
1636 | assert(merger.isDimLevelType(b, DimLvlType::kDense) || | |||
1637 |        merger.isDimLevelType(b, DimLvlType::kUndef)); | |||
1638 | clause = constantI1(builder, loc, true); | |||
1639 | } | |||
1640 | cond = cond ? builder.create<arith::AndIOp>(loc, cond, clause) : clause; | |||
1641 | } | |||
1642 | if (codegen.redVal) | |||
1643 | types.push_back(codegen.redVal.getType()); | |||
1644 | if (codegen.expValues) | |||
1645 | types.push_back(builder.getIndexType()); | |||
1646 | scf::IfOp ifOp = builder.create<scf::IfOp>(loc, types, cond, /*else=*/true); | |||
1647 | builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); | |||
1648 | return ifOp; | |||
1649 | } | |||
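// Sketch of a branch produced above: a case that only applies when
// tensor a participates at the current index tests
//   %p = arith.cmpi eq, %ia, %min
//   scf.if %p -> (...) { ... } else { ... }
// where the optional results thread the reduction value and the
// expansion count through both branches.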
1650 | ||||
1651 | /// Generates end of true branch of if-statement within a while-loop. | |||
1652 | static void endIf(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1653 | linalg::GenericOp op, scf::IfOp ifOp, Operation *loop, | |||
1654 | Value redInput, Value cntInput) { | |||
1655 | SmallVector<Value, 4> operands; | |||
1656 | if (codegen.redVal) { | |||
1657 | operands.push_back(codegen.redVal); | |||
1658 | updateReduc(merger, codegen, redInput); | |||
1659 | } | |||
1660 | if (codegen.expValues) { | |||
1661 | operands.push_back(codegen.expCount); | |||
1662 | codegen.expCount = cntInput; | |||
1663 | } | |||
1664 | if (!operands.empty()) | |||
1665 | builder.create<scf::YieldOp>(op.getLoc(), operands); | |||
1666 | builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); | |||
1667 | } | |||
1668 | ||||
1669 | //===----------------------------------------------------------------------===// | |||
1670 | // Sparse compiler synthesis methods (loop sequence). | |||
1671 | //===----------------------------------------------------------------------===// | |||
1672 | ||||
1673 | /// Starts a loop sequence at given level. Returns true if | |||
1674 | /// the universal loop index must be maintained at this level. | |||
1675 | static bool startLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1676 | linalg::GenericOp op, unsigned exp, unsigned at, | |||
1677 | unsigned idx, unsigned ldx, unsigned lts) { | |||
1678 | assert(codegen.curVecLength == 1); | |||
1679 | assert(!codegen.loops[idx]); | |||
1680 | // Emit invariants at this loop sequence level. | |||
1681 | genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/true); | |||
1682 | // Emit access pattern expansion for sparse tensor output. | |||
1683 | genExpansion(merger, codegen, builder, op, at, /*atStart=*/true); | |||
1684 | // Emit further initialization at this loop sequence level. | |||
1685 | unsigned l0 = merger.set(lts)[0]; | |||
1686 | bool needsUniv = | |||
1687 | genInit(merger, codegen, builder, op, at, merger.lat(l0).bits); | |||
1688 | // Maintain the universal index only if it is actually | |||
1689 | // consumed by a subsequent lattice point. | |||
1690 | if (needsUniv) { | |||
1691 | unsigned lsize = merger.set(lts).size(); | |||
1692 | for (unsigned i = 1; i < lsize; i++) { | |||
1693 | unsigned li = merger.set(lts)[i]; | |||
1694 | if (!merger.hasAnySparse(merger.lat(li).simple)) | |||
1695 | return true; | |||
1696 | } | |||
1697 | } | |||
1698 | return false; | |||
1699 | } | |||
1700 | ||||
1701 | /// Starts a single loop in the current sequence. | |||
1702 | static Operation *startLoop(Merger &merger, CodeGen &codegen, | |||
1703 | OpBuilder &builder, linalg::GenericOp op, | |||
1704 | unsigned at, unsigned li, bool needsUniv) { | |||
1705 | assert(codegen.curVecLength == 1); | |||
1706 | // Emit the for/while-loop control. | |||
1707 | Operation *loop = genLoop(merger, codegen, builder, op, at, needsUniv, | |||
1708 | merger.lat(li).simple); | |||
1709 | // Emit the locals for this loop. | |||
1710 | genLocals(merger, codegen, builder, op, at, needsUniv, merger.lat(li).bits); | |||
1711 | return loop; | |||
1712 | } | |||
1713 | ||||
1714 | /// Ends a single loop in the current sequence. Returns the new value for needsUniv. | |||
1715 | static bool endLoop(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1716 | linalg::GenericOp op, Operation *loop, unsigned idx, | |||
1717 | unsigned li, bool needsUniv) { | |||
1718 | codegen.curVecLength = 1; | |||
1719 | // End a while-loop. | |||
1720 | if (auto whileOp = dyn_cast<scf::WhileOp>(loop)) { | |||
1721 | genWhileInduction(merger, codegen, builder, op, idx, needsUniv, | |||
1722 | merger.lat(li).bits, whileOp); | |||
1723 | return needsUniv; | |||
1724 | } | |||
1725 | // End a for-loop. | |||
1726 | genForInduction(merger, codegen, builder, op, loop); | |||
1727 | return false; | |||
1728 | } | |||
1729 | ||||
1730 | /// Ends a loop sequence at given level. | |||
1731 | static void endLoopSeq(Merger &merger, CodeGen &codegen, OpBuilder &builder, | |||
1732 | linalg::GenericOp op, unsigned exp, unsigned at, | |||
1733 | unsigned idx, unsigned ldx) { | |||
1734 | assert(codegen.curVecLength == 1); | |||
1735 | assert(codegen.loops[idx]); | |||
1736 | codegen.loops[idx] = Value(); | |||
1737 | // Bring a pending reduction back from SIMD form when sequence ends. | |||
1738 | if (codegen.redVal) | |||
1739 | if (auto vtp = codegen.redVal.getType().dyn_cast<VectorType>()) | |||
1740 | updateReduc(merger, codegen, | |||
1741 | genVectorReducEnd(codegen, builder, op.getLoc(), vtp)); | |||
1742 | // Unmark bookkeeping of invariants and loop index. | |||
1743 | genInvariants(merger, codegen, builder, op, exp, ldx, /*atStart=*/false); | |||
1744 | // Finalize access pattern expansion for sparse tensor output. | |||
1745 | genExpansion(merger, codegen, builder, op, at, /*atStart=*/false); | |||
1746 | } | |||
1747 | ||||
1748 | /// Recursively generates code while computing iteration lattices in order | |||
1749 | /// to manage the complexity of implementing co-iteration over unions | |||
1750 | /// and intersections of sparse iteration spaces. | |||
1751 | static void genStmt(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, | |||
1752 | linalg::GenericOp op, unsigned exp, unsigned at) { | |||
1753 | // At each leaf, assign remaining tensor (sub)expression to output tensor. | |||
1754 | if (at == codegen.topSort.size()) { | |||
1755 | unsigned ldx = codegen.topSort[at - 1]; | |||
1756 | Value rhs = genExp(merger, codegen, rewriter, op, exp, ldx); | |||
1757 | genTensorStore(merger, codegen, rewriter, op, exp, rhs); | |||
1758 | return; | |||
1759 | } | |||
1760 | ||||
1761 | // Construct iteration lattices for current loop index, with L0 at top. | |||
1762 | unsigned idx = codegen.topSort[at]; | |||
1763 | unsigned ldx = at == 0 ? -1u : codegen.topSort[at - 1]; | |||
1764 | unsigned lts = merger.optimizeSet(merger.buildLattices(exp, idx)); | |||
1765 | ||||
1766 | // Start a loop sequence. | |||
1767 | bool needsUniv = | |||
1768 | startLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx, lts); | |||
1769 | ||||
1770 | // Emit a loop for every lattice point L0 >= Li in this loop sequence. | |||
1771 | unsigned lsize = merger.set(lts).size(); | |||
1772 | for (unsigned i = 0; i < lsize; i++) { | |||
1773 | // Start a loop. | |||
1774 | unsigned li = merger.set(lts)[i]; | |||
1775 | Operation *loop = | |||
1776 | startLoop(merger, codegen, rewriter, op, at, li, needsUniv); | |||
1777 | ||||
1778 | // Visit all lattices points with Li >= Lj to generate the | |||
1779 | // loop-body, possibly with if statements for coiteration. | |||
1780 | Value redInput = codegen.redVal; | |||
1781 | Value cntInput = codegen.expCount; | |||
1782 | bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr; | |||
1783 | for (unsigned j = 0; j < lsize; j++) { | |||
1784 | unsigned lj = merger.set(lts)[j]; | |||
1785 | unsigned ej = merger.lat(lj).exp; | |||
1786 | if (li == lj || merger.latGT(li, lj)) { | |||
1787 | // Recurse into body of each branch. | |||
1788 | if (isWhile) { | |||
1789 | scf::IfOp ifOp = | |||
1790 | genIf(merger, codegen, rewriter, op, idx, merger.lat(lj).simple); | |||
1791 | genStmt(merger, codegen, rewriter, op, ej, at + 1); | |||
1792 | endIf(merger, codegen, rewriter, op, ifOp, loop, redInput, cntInput); | |||
1793 | } else { | |||
1794 | genStmt(merger, codegen, rewriter, op, ej, at + 1); | |||
1795 | } | |||
1796 | } | |||
1797 | } | |||
1798 | ||||
1799 | // End a loop. | |||
1800 | needsUniv = | |||
1801 | endLoop(merger, codegen, rewriter, op, loop, idx, li, needsUniv); | |||
1802 | } | |||
1803 | ||||
1804 | // End a loop sequence. | |||
1805 | endLoopSeq(merger, codegen, rewriter, op, exp, at, idx, ldx); | |||
1806 | } | |||
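// End-to-end illustration for the recursion above (a sketch): sparse
// vector addition x(i) = a(i) + b(i) yields the familiar loop sequence
// over the lattice points {a,b}, {a}, {b}:
//   while (pa < ha && pb < hb) {
//     if (ia == min && ib == min) ...     // both present
//     else if (ia == min) ...             // only a
//     else ...                            // only b
//   }
//   while (pa < ha) { ... }   // remaining a-only entries
//   while (pb < hb) { ... }   // remaining b-only entries
// where each lattice point contributes one loop and the if-statements
// implement the co-iteration cases within the first loop.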
1807 | ||||
1808 | /// Converts the result computed by the sparse kernel into the required form. | |||
1809 | static void genResult(Merger &merger, CodeGen &codegen, RewriterBase &rewriter, | |||
1810 | linalg::GenericOp op) { | |||
1811 | OpOperand *lhs = op.getOutputOperand(0); | |||
1812 | Type resType = lhs->get().getType(); | |||
1813 | if (getSparseTensorEncoding(resType)) { | |||
1814 | // The sparse tensor rematerializes from the original sparse tensor's | |||
1815 | // underlying sparse storage format. | |||
1816 | rewriter.replaceOpWithNewOp<LoadOp>(op, resType, lhs->get(), | |||
1817 | codegen.sparseOut == lhs); | |||
1818 | } else { | |||
1819 | // To rematerialize a non-annotated tensor, simply load it | |||
1820 | // from the bufferized value. | |||
1821 | Value val = codegen.buffers.back(); // value array | |||
1822 | rewriter.replaceOpWithNewOp<bufferization::ToTensorOp>(op, resType, val); | |||
1823 | } | |||
1824 | } | |||
1825 | ||||
1826 | //===----------------------------------------------------------------------===// | |||
1827 | // Sparse compiler rewriting methods. | |||
1828 | //===----------------------------------------------------------------------===// | |||
1829 | ||||
1830 | namespace { | |||
1831 | ||||
1832 | /// Sparse rewriting rule for generic Linalg operation. | |||
1833 | struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> { | |||
1834 | public: | |||
1835 | GenericOpSparsifier(MLIRContext *context, SparsificationOptions o) | |||
1836 | : OpRewritePattern<linalg::GenericOp>(context), options(o) {} | |||
1837 | ||||
1838 | LogicalResult matchAndRewrite(linalg::GenericOp op, | |||
1839 | PatternRewriter &rewriter) const override { | |||
1840 | // Detects sparse annotations and translates the per-dimension sparsity | |||
1841 | // information for all tensors to loop indices in the kernel. | |||
1842 | assert(op.getNumOutputs() == 1); | |||
1843 | unsigned numTensors = op.getNumInputsAndOutputs(); | |||
1844 | unsigned numLoops = op.iterator_types().getValue().size(); | |||
1845 | Merger merger(numTensors, numLoops); | |||
1846 | if (!findSparseAnnotations(merger, op)) | |||
1847 | return failure(); | |||
1848 | ||||
1849 | // Builds the tensor expression for the Linalg operation in SSA form. | |||
1850 | Optional<unsigned> optExp = merger.buildTensorExpFromLinalg(op); | |||
1851 | if (!optExp.has_value()) | |||
1852 | return failure(); | |||
1853 | ||||
1854 | unsigned exp = optExp.value(); | |||
1855 | OpOperand *sparseOut = nullptr; | |||
1856 | unsigned outerParNest = 0; | |||
1857 | // Computes a topologically sorted iteration graph to ensure tensors | |||
1858 | // are visited in natural index order. Gradually relaxes the considered | |||
1859 | // constraints until an acyclic iteration graph results, such that sparse | |||
1860 | // code generation can proceed. As a last resort, an attempt is made | |||
1861 | // to resolve cycles by inserting a conversion. | |||
1862 | std::vector<unsigned> topSort; | |||
1863 | // Whether the current GenericOp is admissible. | |||
1864 | bool isAdmissible = false; | |||
1865 | bool hasCycle = true; | |||
1866 | // A const list of all masks used for iteration graph computation. | |||
1867 | // Must be ordered from strict to loose. | |||
1868 | const auto allMask = {SortMask::kIncludeAll, SortMask::kIncludeUndef, | |||
1869 | SortMask::kIncludeDense, SortMask::kSparseOnly}; | |||
1870 | for (auto mask : allMask) | |||
1871 | if (computeIterationGraph(merger, op, topSort, mask)) { | |||
1872 | hasCycle = false; | |||
1873 | if (isAdmissibleTensorExp(merger, op, topSort, exp, &sparseOut, | |||
1874 | outerParNest)) { | |||
1875 | isAdmissible = true; | |||
1876 | break; | |||
1877 | } | |||
1878 | // else try a set of less strict constraints. | |||
1879 | } | |||
1880 | ||||
1881 | if (hasCycle) | |||
1882 | // Give it one last shot to resolve the cycle. | |||
1883 | return resolveCycle(merger, rewriter, op); | |||
1884 | if (!isAdmissible) | |||
1885 | // Inadmissible expression, reject. | |||
1886 | return failure(); | |||
1887 | ||||
1888 | // Recursively generates code if admissible. | |||
1889 | merger.setHasSparseOut(sparseOut != nullptr); | |||
1890 | CodeGen codegen(options, numTensors, numLoops, sparseOut, outerParNest, | |||
1891 | topSort); | |||
1892 | genBuffers(merger, codegen, rewriter, op); | |||
1893 | genStmt(merger, codegen, rewriter, op, exp, 0); | |||
1894 | genResult(merger, codegen, rewriter, op); | |||
1895 | return success(); | |||
1896 | } | |||
1897 | ||||
1898 | private: | |||
1899 | // Last resort cycle resolution. | |||
1900 | LogicalResult resolveCycle(Merger &merger, PatternRewriter &rewriter, | |||
1901 | linalg::GenericOp op) const { | |||
1902 | // Compute topological sort while leaving out every | |||
1903 | // sparse input tensor in succession until an acyclic | |||
1904 | // iteration graph results. | |||
1905 | std::vector<unsigned> topSort; | |||
1906 | for (OpOperand *t : op.getInputOperands()) { | |||
1907 | unsigned tensor = t->getOperandNumber(); | |||
1908 | Value tval = t->get(); | |||
1909 | auto srcEnc = getSparseTensorEncoding(tval.getType()); | |||
1910 | if (!srcEnc || | |||
1911 | !computeIterationGraph(merger, op, topSort, SortMask::kSparseOnly, t)) | |||
1912 | continue; | |||
1913 | // Found an input tensor that resolves the cycle by inserting a | |||
1914 | // conversion into a sparse tensor that adheres to the iteration | |||
1915 | // graph order. Also releases the temporary sparse tensor. | |||
1916 | // | |||
1917 | // TODO: investigate fusing the conversion with computation, | |||
1918 | // especially if it is a direct yield! | |||
1919 | // | |||
1920 | auto srcTp = tval.getType().cast<RankedTensorType>(); | |||
1921 | auto dstEnc = SparseTensorEncodingAttr::get( | |||
1922 | op->getContext(), srcEnc.getDimLevelType(), | |||
1923 | permute(getContext(), op.getMatchingIndexingMap(t), | |||
1924 | topSort), // new order | |||
1925 | srcEnc.getPointerBitWidth(), srcEnc.getIndexBitWidth()); | |||
1926 | auto dstTp = RankedTensorType::get(srcTp.getShape(), | |||
1927 | srcTp.getElementType(), dstEnc); | |||
1928 | auto convert = rewriter.create<ConvertOp>(tval.getLoc(), dstTp, tval); | |||
1929 | op->setOperand(tensor, convert); | |||
1930 | rewriter.setInsertionPointAfter(op); | |||
1931 | rewriter.create<bufferization::DeallocTensorOp>(tval.getLoc(), convert); | |||
1932 | return success(); | |||
1933 | } | |||
1934 | // Cannot be resolved with a single conversion. | |||
1935 | // TODO: convert more than one? | |||
1936 | return failure(); | |||
1937 | } | |||
1938 | ||||
1939 | /// Options to control sparse code generation. | |||
1940 | SparsificationOptions options; | |||
1941 | }; | |||
1942 | ||||
1943 | } // namespace | |||
1944 | ||||
1945 | /// Populates the given patterns list with rewriting rules required for | |||
1946 | /// the sparsification of linear algebra operations. | |||
1947 | void mlir::populateSparsificationPatterns( | |||
1948 | RewritePatternSet &patterns, const SparsificationOptions &options) { | |||
1949 | patterns.add<GenericOpSparsifier>(patterns.getContext(), options); | |||
1950 | } |