File: | build/source/mlir/lib/Dialect/Vector/IR/VectorOps.cpp |
Warning: | line 5547, column 23 1st function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file implements convenience types for working with super-vectorization | |||
10 | // operations, in particular super-vector loads and stores. | |||
11 | // | |||
12 | //===----------------------------------------------------------------------===// | |||
13 | ||||
14 | #include "mlir/Dialect/Vector/IR/VectorOps.h" | |||
15 | ||||
16 | #include "mlir/Dialect/Arith/IR/Arith.h" | |||
17 | #include "mlir/Dialect/Arith/Utils/Utils.h" | |||
18 | #include "mlir/Dialect/MemRef/IR/MemRef.h" | |||
19 | #include "mlir/Dialect/Tensor/IR/Tensor.h" | |||
20 | #include "mlir/Dialect/Utils/IndexingUtils.h" | |||
21 | #include "mlir/Dialect/Utils/StructuredOpsUtils.h" | |||
22 | #include "mlir/IR/AffineExpr.h" | |||
23 | #include "mlir/IR/AffineMap.h" | |||
24 | #include "mlir/IR/BlockAndValueMapping.h" | |||
25 | #include "mlir/IR/Builders.h" | |||
26 | #include "mlir/IR/BuiltinAttributes.h" | |||
27 | #include "mlir/IR/BuiltinOps.h" | |||
28 | #include "mlir/IR/BuiltinTypes.h" | |||
29 | #include "mlir/IR/DialectImplementation.h" | |||
30 | #include "mlir/IR/OpImplementation.h" | |||
31 | #include "mlir/IR/PatternMatch.h" | |||
32 | #include "mlir/IR/TypeUtilities.h" | |||
33 | #include "mlir/Support/LLVM.h" | |||
34 | #include "llvm/ADT/ArrayRef.h" | |||
35 | #include "llvm/ADT/STLExtras.h" | |||
36 | #include "llvm/ADT/SmallVector.h" | |||
37 | #include "llvm/ADT/StringSet.h" | |||
38 | #include "llvm/ADT/TypeSwitch.h" | |||
39 | #include "llvm/ADT/bit.h" | |||
40 | ||||
41 | #include <cassert> | |||
42 | #include <cstdint> | |||
43 | #include <numeric> | |||
44 | ||||
45 | #include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc" | |||
46 | // Pull in all enum type and utility function definitions. | |||
47 | #include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc" | |||
48 | ||||
49 | using namespace mlir; | |||
50 | using namespace mlir::vector; | |||
51 | ||||
/// Helper enum to classify mask value.
enum class MaskFormat {
  AllTrue = 0,  // All mask bits are statically known to be set.
  AllFalse = 1, // All mask bits are statically known to be cleared.
  Unknown = 2,  // Mixed bits, or not statically determinable.
};
58 | ||||
59 | /// Helper method to classify a mask value. Currently, the method | |||
60 | /// looks "under the hood" of a constant value with dense attributes | |||
61 | /// and a constant mask operation (since the client may be called at | |||
62 | /// various stages during progressive lowering). | |||
63 | static MaskFormat getMaskFormat(Value mask) { | |||
64 | if (auto c = mask.getDefiningOp<arith::ConstantOp>()) { | |||
65 | // Inspect constant dense values. We count up for bits that | |||
66 | // are set, count down for bits that are cleared, and bail | |||
67 | // when a mix is detected. | |||
68 | if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) { | |||
69 | int64_t val = 0; | |||
70 | for (bool b : denseElts.getValues<bool>()) | |||
71 | if (b && val >= 0) | |||
72 | val++; | |||
73 | else if (!b && val <= 0) | |||
74 | val--; | |||
75 | else | |||
76 | return MaskFormat::Unknown; | |||
77 | if (val > 0) | |||
78 | return MaskFormat::AllTrue; | |||
79 | if (val < 0) | |||
80 | return MaskFormat::AllFalse; | |||
81 | } | |||
82 | } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) { | |||
83 | // Inspect constant mask index. If the index exceeds the | |||
84 | // dimension size, all bits are set. If the index is zero | |||
85 | // or less, no bits are set. | |||
86 | ArrayAttr masks = m.getMaskDimSizes(); | |||
87 | auto shape = m.getType().getShape(); | |||
88 | bool allTrue = true; | |||
89 | bool allFalse = true; | |||
90 | for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) { | |||
91 | int64_t i = maskIdx.cast<IntegerAttr>().getInt(); | |||
92 | if (i < dimSize) | |||
93 | allTrue = false; | |||
94 | if (i > 0) | |||
95 | allFalse = false; | |||
96 | } | |||
97 | if (allTrue) | |||
98 | return MaskFormat::AllTrue; | |||
99 | if (allFalse) | |||
100 | return MaskFormat::AllFalse; | |||
101 | } | |||
102 | return MaskFormat::Unknown; | |||
103 | } | |||
104 | ||||
/// Default callback to build a region with a 'vector.yield' terminator with no
/// arguments. Used by ops that carry a region when the caller supplies no
/// custom body builder.
void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
  builder.create<vector::YieldOp>(loc);
}
110 | ||||
111 | // Helper for verifying combining kinds in contractions and reductions. | |||
112 | static bool isSupportedCombiningKind(CombiningKind combiningKind, | |||
113 | Type elementType) { | |||
114 | switch (combiningKind) { | |||
115 | case CombiningKind::ADD: | |||
116 | case CombiningKind::MUL: | |||
117 | return elementType.isIntOrIndexOrFloat(); | |||
118 | case CombiningKind::MINUI: | |||
119 | case CombiningKind::MINSI: | |||
120 | case CombiningKind::MAXUI: | |||
121 | case CombiningKind::MAXSI: | |||
122 | case CombiningKind::AND: | |||
123 | case CombiningKind::OR: | |||
124 | case CombiningKind::XOR: | |||
125 | return elementType.isIntOrIndex(); | |||
126 | case CombiningKind::MINF: | |||
127 | case CombiningKind::MAXF: | |||
128 | return elementType.isa<FloatType>(); | |||
129 | } | |||
130 | return false; | |||
131 | } | |||
132 | ||||
133 | /// Return true if the last dimension of the MemRefType has unit stride. Also | |||
134 | /// return true for memrefs with no strides. | |||
135 | bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) { | |||
136 | int64_t offset; | |||
137 | SmallVector<int64_t> strides; | |||
138 | auto successStrides = getStridesAndOffset(type, strides, offset); | |||
139 | return succeeded(successStrides) && (strides.empty() || strides.back() == 1); | |||
140 | } | |||
141 | ||||
142 | AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType, | |||
143 | VectorType vectorType) { | |||
144 | int64_t elementVectorRank = 0; | |||
145 | VectorType elementVectorType = | |||
146 | shapedType.getElementType().dyn_cast<VectorType>(); | |||
147 | if (elementVectorType) | |||
148 | elementVectorRank += elementVectorType.getRank(); | |||
149 | // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>. | |||
150 | // TODO: replace once we have 0-d vectors. | |||
151 | if (shapedType.getRank() == 0 && | |||
152 | vectorType.getShape() == ArrayRef<int64_t>{1}) | |||
153 | return AffineMap::get( | |||
154 | /*numDims=*/0, /*numSymbols=*/0, | |||
155 | getAffineConstantExpr(0, shapedType.getContext())); | |||
156 | return AffineMap::getMinorIdentityMap( | |||
157 | shapedType.getRank(), vectorType.getRank() - elementVectorRank, | |||
158 | shapedType.getContext()); | |||
159 | } | |||
160 | ||||
161 | bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite, | |||
162 | vector::TransferReadOp read) { | |||
163 | return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() && | |||
164 | !read.getMask() && defWrite.getIndices() == read.getIndices() && | |||
165 | defWrite.getVectorType() == read.getVectorType() && | |||
166 | defWrite.getPermutationMap() == read.getPermutationMap(); | |||
167 | } | |||
168 | ||||
169 | bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write, | |||
170 | vector::TransferWriteOp priorWrite) { | |||
171 | return priorWrite.getIndices() == write.getIndices() && | |||
172 | priorWrite.getMask() == write.getMask() && | |||
173 | priorWrite.getVectorType() == write.getVectorType() && | |||
174 | priorWrite.getPermutationMap() == write.getPermutationMap(); | |||
175 | } | |||
176 | ||||
/// Return true when `transferA` and `transferB` can be statically proven to
/// access disjoint element ranges of the same source. Conservatively returns
/// false whenever disjointness cannot be proven (e.g. dynamic indices).
bool mlir::vector::isDisjointTransferIndices(
    VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
  // For simplicity only look at transfer of same type.
  if (transferA.getVectorType() != transferB.getVectorType())
    return false;
  // Leading shaped dims are indexed point-wise; the remaining dims each span
  // an interval whose length is the corresponding vector dimension size.
  unsigned rankOffset = transferA.getLeadingShapedRank();
  for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
    auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
    auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
    // If any of the indices are dynamic we cannot prove anything.
    if (!indexA || !indexB)
      continue;

    if (i < rankOffset) {
      // For leading dimensions, if we can prove that index are different we
      // know we are accessing disjoint slices.
      if (indexA.getValue().cast<IntegerAttr>().getInt() !=
          indexB.getValue().cast<IntegerAttr>().getInt())
        return true;
    } else {
      // For this dimension, we slice a part of the memref we need to make sure
      // the intervals accessed don't overlap. A start-index distance of at
      // least the vector dim size guarantees disjoint intervals.
      int64_t distance =
          std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
                   indexB.getValue().cast<IntegerAttr>().getInt());
      if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
        return true;
    }
  }
  return false;
}
208 | ||||
209 | bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA, | |||
210 | VectorTransferOpInterface transferB) { | |||
211 | if (transferA.source() != transferB.source()) | |||
212 | return false; | |||
213 | return isDisjointTransferIndices(transferA, transferB); | |||
214 | } | |||
215 | ||||
216 | // Helper to iterate over n-D vector slice elements. Calculate the next | |||
217 | // `position` in the n-D vector of size `shape`, applying an offset `offsets`. | |||
218 | // Modifies the `position` in place. Returns a failure when `position` becomes | |||
219 | // the end position. | |||
220 | static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position, | |||
221 | ArrayRef<int64_t> shape, | |||
222 | ArrayRef<int64_t> offsets) { | |||
223 | for (auto [posInDim, dimSize, offsetInDim] : | |||
224 | llvm::reverse(llvm::zip_equal(position, shape, offsets))) { | |||
225 | ++posInDim; | |||
226 | if (posInDim < dimSize + offsetInDim) | |||
227 | return success(); | |||
228 | ||||
229 | // Carry the overflow to the next loop iteration. | |||
230 | posInDim = offsetInDim; | |||
231 | } | |||
232 | ||||
233 | return failure(); | |||
234 | } | |||
235 | ||||
236 | //===----------------------------------------------------------------------===// | |||
237 | // CombiningKindAttr | |||
238 | //===----------------------------------------------------------------------===// | |||
239 | ||||
namespace mlir {
namespace vector {
namespace detail {
/// Attribute storage for bitmask-style enum attributes (used by the
/// CombiningKindAttr machinery in this file). The uniquing key is the raw
/// enum bit value.
struct BitmaskEnumStorage : public AttributeStorage {
  using KeyTy = uint64_t;

  BitmaskEnumStorage(KeyTy val) : value(val) {}

  bool operator==(const KeyTy &key) const { return value == key; }

  static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
                                       const KeyTy &key) {
    // Placement-new into the context-owned allocator so the storage lives as
    // long as the MLIRContext.
    return new (allocator.allocate<BitmaskEnumStorage>())
        BitmaskEnumStorage(key);
  }

  KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir
261 | ||||
262 | //===----------------------------------------------------------------------===// | |||
263 | // VectorDialect | |||
264 | //===----------------------------------------------------------------------===// | |||
265 | ||||
void VectorDialect::initialize() {
  // Register all attributes generated from VectorOpsAttrDefs.td.
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
      >();

  // Register all operations generated from VectorOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
      >();
}
277 | ||||
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Vector constants are materialized as
/// arith.constant ops.
Operation *VectorDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<arith::ConstantOp>(loc, type, value);
}
285 | ||||
/// Returns the integer type used for vector subscripts (i64).
IntegerType vector::getVectorSubscriptType(Builder &builder) {
  return builder.getIntegerType(64);
}
289 | ||||
/// Wraps `values` into an i64 ArrayAttr suitable for vector subscripts.
ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
                                         ArrayRef<int64_t> values) {
  return builder.getI64ArrayAttr(values);
}
294 | ||||
295 | //===----------------------------------------------------------------------===// | |||
296 | // MultiDimReductionOp | |||
297 | //===----------------------------------------------------------------------===// | |||
298 | ||||
299 | void vector::MultiDimReductionOp::build(OpBuilder &builder, | |||
300 | OperationState &result, Value source, | |||
301 | Value acc, ArrayRef<bool> reductionMask, | |||
302 | CombiningKind kind) { | |||
303 | SmallVector<int64_t> reductionDims; | |||
304 | for (const auto &en : llvm::enumerate(reductionMask)) | |||
305 | if (en.value()) | |||
306 | reductionDims.push_back(en.index()); | |||
307 | build(builder, result, kind, source, acc, | |||
308 | builder.getI64ArrayAttr(reductionDims)); | |||
309 | } | |||
310 | ||||
/// Folds a reduction over a rank-1 vector whose single dimension is not
/// reduced: the op is then the identity on its source.
OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
  // Single parallel dim, this is a noop.
  if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
    return getSource();
  return {};
}
317 | ||||
/// Unrolling a multi-dim reduction iterates over the full source vector
/// shape.
std::optional<SmallVector<int64_t, 4>>
MultiDimReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getSourceVectorType().getShape());
}
322 | ||||
323 | LogicalResult MultiDimReductionOp::verify() { | |||
324 | SmallVector<int64_t> targetShape; | |||
325 | Type inferredReturnType; | |||
326 | for (auto it : llvm::enumerate(getSourceVectorType().getShape())) | |||
327 | if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) { | |||
328 | return attr.cast<IntegerAttr>().getValue() == it.index(); | |||
329 | })) | |||
330 | targetShape.push_back(it.value()); | |||
331 | // TODO: update to also allow 0-d vectors when available. | |||
332 | if (targetShape.empty()) | |||
333 | inferredReturnType = getSourceVectorType().getElementType(); | |||
334 | else | |||
335 | inferredReturnType = | |||
336 | VectorType::get(targetShape, getSourceVectorType().getElementType()); | |||
337 | if (getType() != inferredReturnType) | |||
338 | return emitOpError() << "destination type " << getType() | |||
339 | << " is incompatible with source type " | |||
340 | << getSourceVectorType(); | |||
341 | ||||
342 | return success(); | |||
343 | } | |||
344 | ||||
345 | namespace { | |||
346 | // Only unit dimensions that are being reduced are folded. If the dimension is | |||
347 | // unit, but not reduced, it is not folded, thereby keeping the output type the | |||
348 | // same. If not all dimensions which are reduced are of unit dimension, this | |||
349 | // transformation does nothing. This is just a generalization of | |||
350 | // ElideSingleElementReduction for ReduceOp. | |||
351 | struct ElideUnitDimsInMultiDimReduction | |||
352 | : public OpRewritePattern<MultiDimReductionOp> { | |||
353 | using OpRewritePattern::OpRewritePattern; | |||
354 | ||||
355 | LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp, | |||
356 | PatternRewriter &rewriter) const override { | |||
357 | ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape(); | |||
358 | for (const auto &dim : enumerate(shape)) { | |||
359 | if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1) | |||
360 | return failure(); | |||
361 | } | |||
362 | Location loc = reductionOp.getLoc(); | |||
363 | Value acc = reductionOp.getAcc(); | |||
364 | Value cast; | |||
365 | if (reductionOp.getDestType().isa<VectorType>()) { | |||
366 | cast = rewriter.create<vector::ShapeCastOp>( | |||
367 | loc, reductionOp.getDestType(), reductionOp.getSource()); | |||
368 | } else { | |||
369 | // This means we are reducing all the dimensions, and all reduction | |||
370 | // dimensions are of size 1. So a simple extraction would do. | |||
371 | cast = rewriter.create<vector::ExtractOp>( | |||
372 | loc, reductionOp.getDestType(), reductionOp.getSource(), | |||
373 | rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0))); | |||
374 | } | |||
375 | ||||
376 | Value result = vector::makeArithReduction(rewriter, loc, | |||
377 | reductionOp.getKind(), acc, cast); | |||
378 | rewriter.replaceOp(reductionOp, result); | |||
379 | return success(); | |||
380 | } | |||
381 | }; | |||
382 | } // namespace | |||
383 | ||||
void MultiDimReductionOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Fold away reductions whose reduced dimensions are all unit-sized.
  results.add<ElideUnitDimsInMultiDimReduction>(context);
}
388 | ||||
389 | //===----------------------------------------------------------------------===// | |||
390 | // ReductionOp | |||
391 | //===----------------------------------------------------------------------===// | |||
392 | ||||
/// Builds a reduction with no accumulator operand.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector) {
  build(builder, result, kind, vector, /*acc=*/Value());
}
397 | ||||
/// Builds a reduction whose result type is the source vector's element type.
void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
                                CombiningKind kind, Value vector, Value acc) {
  build(builder, result, vector.getType().cast<VectorType>().getElementType(),
        kind, vector, acc);
}
403 | ||||
404 | LogicalResult ReductionOp::verify() { | |||
405 | // Verify for 0-D and 1-D vector. | |||
406 | int64_t rank = getVectorType().getRank(); | |||
407 | if (rank > 1) | |||
408 | return emitOpError("unsupported reduction rank: ") << rank; | |||
409 | ||||
410 | // Verify supported reduction kind. | |||
411 | Type eltType = getDest().getType(); | |||
412 | if (!isSupportedCombiningKind(getKind(), eltType)) | |||
413 | return emitOpError("unsupported reduction type '") | |||
414 | << eltType << "' for kind '" << stringifyCombiningKind(getKind()) | |||
415 | << "'"; | |||
416 | ||||
417 | return success(); | |||
418 | } | |||
419 | ||||
/// Parses `vector.reduction <kind>, %vec (, %acc)? : redType into resType`.
ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
  Type redType;
  Type resType;
  CombiningKindAttr kindAttr;
  // Short-circuit chain: the first failing parse/resolve step aborts. The
  // first operand is the source vector (redType); the optional second operand
  // is the accumulator (resType).
  if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
                                              result.attributes) ||
      parser.parseComma() || parser.parseOperandList(operandsInfo) ||
      parser.parseColonType(redType) ||
      parser.parseKeywordType("into", resType) ||
      (!operandsInfo.empty() &&
       parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
      (operandsInfo.size() > 1 &&
       parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types))
    return failure();
  // The operand count is validated after parsing so that type-resolution
  // errors are reported first.
  if (operandsInfo.empty() || operandsInfo.size() > 2)
    return parser.emitError(parser.getNameLoc(),
                            "unsupported number of operands");
  return success();
}
441 | ||||
/// Prints `vector.reduction <kind>, %vec (, %acc)? : vecType into destType`,
/// mirroring the format accepted by ReductionOp::parse.
void ReductionOp::print(OpAsmPrinter &p) {
  p << " ";
  getKindAttr().print(p);
  p << ", " << getVector();
  // The accumulator operand is optional.
  if (getAcc())
    p << ", " << getAcc();
  p << " : " << getVector().getType() << " into " << getDest().getType();
}
450 | ||||
451 | // MaskableOpInterface methods. | |||
452 | ||||
453 | /// Returns the mask type expected by this operation. | |||
454 | Type ReductionOp::getExpectedMaskType() { | |||
455 | auto vecType = getVectorType(); | |||
456 | return vecType.cloneWith(std::nullopt, | |||
457 | IntegerType::get(vecType.getContext(), /*width=*/1)); | |||
458 | } | |||
459 | ||||
/// Creates a `vector.reduction` op that reduces `vector` to a scalar with the
/// combining kind corresponding to the given `arith.atomicrmw` kind. Returns
/// a null Value for unsupported kinds, emitting an optional error at `loc`.
/// Note: created ops use the value's own location; `loc` is only used for the
/// error path.
Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                         OpBuilder &builder, Location loc,
                                         Value vector) {
  switch (op) {
  case arith::AtomicRMWKind::addf:
  case arith::AtomicRMWKind::addi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::ADD, vector);
  case arith::AtomicRMWKind::mulf:
  case arith::AtomicRMWKind::muli:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MUL, vector);
  case arith::AtomicRMWKind::minf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINF, vector);
  case arith::AtomicRMWKind::mins:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINSI, vector);
  case arith::AtomicRMWKind::minu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MINUI, vector);
  case arith::AtomicRMWKind::maxf:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXF, vector);
  case arith::AtomicRMWKind::maxs:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXSI, vector);
  case arith::AtomicRMWKind::maxu:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::MAXUI, vector);
  case arith::AtomicRMWKind::andi:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::AND, vector);
  case arith::AtomicRMWKind::ori:
    return builder.create<vector::ReductionOp>(vector.getLoc(),
                                               CombiningKind::OR, vector);
  // TODO: Add remaining reduction operations.
  default:
    (void)emitOptionalError(loc, "Reduction operation type not supported");
    break;
  }
  return nullptr;
}
503 | ||||
/// Unrolling a reduction iterates over the full source vector shape.
std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
  return llvm::to_vector<4>(getVectorType().getShape());
}
507 | ||||
508 | namespace { | |||
509 | struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> { | |||
510 | using OpRewritePattern::OpRewritePattern; | |||
511 | ||||
512 | LogicalResult matchAndRewrite(ReductionOp reductionOp, | |||
513 | PatternRewriter &rewriter) const override { | |||
514 | if (reductionOp.getVectorType().getDimSize(0) != 1) | |||
515 | return failure(); | |||
516 | ||||
517 | Location loc = reductionOp.getLoc(); | |||
518 | Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(), | |||
519 | reductionOp.getVector(), | |||
520 | rewriter.getI64ArrayAttr(0)); | |||
521 | ||||
522 | if (Value acc = reductionOp.getAcc()) | |||
523 | result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), | |||
524 | result, acc); | |||
525 | ||||
526 | rewriter.replaceOp(reductionOp, result); | |||
527 | return success(); | |||
528 | } | |||
529 | }; | |||
530 | } // namespace | |||
531 | ||||
void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // Fold away reductions over single-element vectors.
  results.add<ElideSingleElementReduction>(context);
}
536 | ||||
537 | //===----------------------------------------------------------------------===// | |||
538 | // ContractionOp | |||
539 | //===----------------------------------------------------------------------===// | |||
540 | ||||
/// Builds a contraction from affine-expression indexing maps and iterator
/// types. The result type always matches the accumulator type.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
                                  ArrayRef<IteratorType> iteratorTypes) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name),
                      builder.getAffineMapArrayAttr(
                          AffineMap::inferFromExprList(indexingExprs)));
  // Wrap each iterator-type enum into an IteratorTypeAttr.
  result.addAttribute(
      getIteratorTypesAttrName(result.name),
      builder.getArrayAttr(llvm::to_vector(llvm::map_range(
          iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
            return IteratorTypeAttr::get(builder.getContext(), t);
          }))));
}
557 | ||||
/// Builds a contraction from attribute arrays, using the default combining
/// kind.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes) {
  build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
        ContractionOp::getDefaultKind());
}
565 | ||||
/// Builds a contraction from pre-built attribute arrays and an explicit
/// combining kind. The result type always matches the accumulator type.
void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
                                  Value lhs, Value rhs, Value acc,
                                  ArrayAttr indexingMaps,
                                  ArrayAttr iteratorTypes, CombiningKind kind) {
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
  result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
  result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
  result.addAttribute(getKindAttrName(result.name),
                      CombiningKindAttr::get(builder.getContext(), kind));
}
577 | ||||
578 | ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) { | |||
579 | OpAsmParser::UnresolvedOperand lhsInfo; | |||
580 | OpAsmParser::UnresolvedOperand rhsInfo; | |||
581 | OpAsmParser::UnresolvedOperand accInfo; | |||
582 | SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo; | |||
583 | SmallVector<Type, 2> types; | |||
584 | Type resultType; | |||
585 | auto loc = parser.getCurrentLocation(); | |||
586 | DictionaryAttr dictAttr; | |||
587 | // TODO: Unify linalg op attribute parsing. | |||
588 | if (parser.parseAttribute(dictAttr, "_", result.attributes) || | |||
589 | parser.parseOperand(lhsInfo) || parser.parseComma() || | |||
590 | parser.parseOperand(rhsInfo) || parser.parseComma() || | |||
591 | parser.parseOperand(accInfo) || | |||
592 | parser.parseTrailingOperandList(masksInfo) || | |||
593 | parser.parseOptionalAttrDict(result.attributes) || | |||
594 | parser.parseColonTypeList(types) || | |||
595 | parser.parseKeywordType("into", resultType) || | |||
596 | parser.resolveOperand(lhsInfo, types[0], result.operands) || | |||
597 | parser.resolveOperand(rhsInfo, types[1], result.operands) || | |||
598 | parser.resolveOperand(accInfo, resultType, result.operands) || | |||
599 | parser.addTypeToList(resultType, result.types)) | |||
600 | return failure(); | |||
601 | result.attributes.assign(dictAttr.getValue().begin(), | |||
602 | dictAttr.getValue().end()); | |||
603 | ||||
604 | // Convert array of string into an array of IteratyType enums. This is needed, | |||
605 | // because tests still use the old format when 'iterator_types' attribute is | |||
606 | // represented as an array of strings. | |||
607 | // TODO: Remove this conversion once tests are fixed. | |||
608 | ArrayAttr iteratorTypes = | |||
609 | result.attributes.get(getIteratorTypesAttrName(result.name)) | |||
610 | .cast<ArrayAttr>(); | |||
611 | ||||
612 | SmallVector<Attribute> iteratorTypeAttrs; | |||
613 | ||||
614 | for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) { | |||
615 | auto maybeIteratorType = symbolizeIteratorType(s); | |||
616 | if (!maybeIteratorType.has_value()) | |||
617 | return parser.emitError(loc) << "unexpected iterator_type (" << s << ")"; | |||
618 | ||||
619 | iteratorTypeAttrs.push_back( | |||
620 | IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value())); | |||
621 | } | |||
622 | result.attributes.set(getIteratorTypesAttrName(result.name), | |||
623 | parser.getBuilder().getArrayAttr(iteratorTypeAttrs)); | |||
624 | ||||
625 | if (!result.attributes.get(getKindAttrName(result.name))) { | |||
626 | result.addAttribute( | |||
627 | getKindAttrName(result.name), | |||
628 | CombiningKindAttr::get(result.getContext(), | |||
629 | ContractionOp::getDefaultKind())); | |||
630 | } | |||
631 | if (masksInfo.empty()) | |||
632 | return success(); | |||
633 | if (masksInfo.size() != 2) | |||
634 | return parser.emitError(parser.getNameLoc(), | |||
635 | "expected zero or exactly 2 vector mask operands"); | |||
636 | auto lhsType = types[0].cast<VectorType>(); | |||
637 | auto rhsType = types[1].cast<VectorType>(); | |||
638 | auto maskElementType = parser.getBuilder().getI1Type(); | |||
639 | std::array<Type, 2> maskTypes = { | |||
640 | VectorType::Builder(lhsType).setElementType(maskElementType), | |||
641 | VectorType::Builder(rhsType).setElementType(maskElementType)}; | |||
642 | if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands)) | |||
643 | return failure(); | |||
644 | return success(); | |||
645 | } | |||
646 | ||||
/// Prints the contraction in the form accepted by ContractionOp::parse: the
/// trait attributes as a leading dictionary, then operands, optional masks,
/// remaining attributes, and the operand/result types.
void ContractionOp::print(OpAsmPrinter &p) {
  // TODO: Unify printing code with linalg ops.
  auto attrNames = getTraitAttrNames();
  llvm::StringSet<> traitAttrsSet;
  traitAttrsSet.insert(attrNames.begin(), attrNames.end());
  SmallVector<NamedAttribute, 8> attrs;
  // Collect only trait attributes into the leading dictionary.
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          attr.getValue()
              .cast<ArrayAttr>()
              .getAsValueRange<IteratorTypeAttr, IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
          llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
            return StringAttr::get(getContext(), stringifyIteratorType(t));
          }));

      attrs.emplace_back(getIteratorTypesAttrName(),
                         ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
      attrs.push_back(attr);
  }

  auto dictAttr = DictionaryAttr::get(getContext(), attrs);
  p << " " << dictAttr << " " << getLhs() << ", ";
  p << getRhs() << ", " << getAcc();
  // Masks are printed only when both are present.
  if (getMasks().size() == 2)
    p << ", " << getMasks();

  // Trait attributes already appear in the leading dictionary; elide them
  // from the trailing attr-dict.
  p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
  p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
    << getResultType();
}
684 | ||||
685 | static bool verifyDimMap(VectorType lhsType, VectorType rhsType, | |||
686 | const std::vector<std::pair<int64_t, int64_t>> &map) { | |||
687 | for (auto &dimPair : map) { | |||
688 | if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() || | |||
689 | dimPair.second < 0 || dimPair.second >= rhsType.getRank() || | |||
690 | lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second)) | |||
691 | return false; | |||
692 | } | |||
693 | return true; | |||
694 | } | |||
695 | ||||
/// Verify that the accumulator ('accType') and result ('resType') types of
/// `op` match the shape implied by the lhs/rhs vector types together with the
/// contracting and batch dimension maps.
static LogicalResult verifyOutputShape(
    ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
    Type resType,
    const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
    const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
  // Collect the lhs/rhs dimensions that are contracted away, plus the rhs
  // batch dimensions (a batch dim contributes to the result only once, via
  // the lhs side).
  DenseSet<int64_t> lhsContractingDimSet;
  DenseSet<int64_t> rhsContractingDimSet;
  for (auto &dimPair : contractingDimMap) {
    lhsContractingDimSet.insert(dimPair.first);
    rhsContractingDimSet.insert(dimPair.second);
  }
  DenseSet<int64_t> rhsBatchDimSet;
  for (auto &dimPair : batchDimMap)
    rhsBatchDimSet.insert(dimPair.second);

  // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
  SmallVector<int64_t, 4> expectedResultDims;
  for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
    if (lhsContractingDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(lhsType.getDimSize(i));
  }

  // Add free dimensions from 'rhsType' to 'expectedResultDims'.
  for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
    if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
      continue;
    expectedResultDims.push_back(rhsType.getDimSize(i));
  }

  // Verify 'expectedResultDims'.
  if (expectedResultDims.empty()) {
    // No batch or free dimension implies a scalar result.
    if (resType.isa<VectorType>() || accType.isa<VectorType>())
      return op.emitOpError("invalid accumulator/result vector shape");
  } else {
    // At least one batch or free dimension implies a vector result.
    auto resVectorType = resType.dyn_cast<VectorType>();
    auto accVectorType = accType.dyn_cast<VectorType>();
    if (!resVectorType || !accVectorType)
      return op.emitOpError("invalid accumulator/result vector shape");

    // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
    // types fully define the result vector type. This assumes the affine maps
    // are well-formed, which must have been verified already.
    MLIRContext *ctx = op.getContext();
    AffineMap lhsMap = op.getIndexingMapsArray()[0];
    AffineMap rhsMap = op.getIndexingMapsArray()[1];
    if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
      return op.emitOpError(
          "expected all dimensions to be either a LHS or a RHS dimension");
    // Compute, per iteration dimension, its constant extent, taken from the
    // first operand (lhs, then rhs) that uses the dimension.
    SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
    for (auto pair :
         {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
      VectorType v = pair.first;
      auto map = pair.second;
      for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
        unsigned pos = map.getDimPosition(idx);
        if (!extents[pos])
          extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
      }
    }
    // Every iteration dimension must have received an extent.
    if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
      return op.emitOpError("expected all dimensions to get an extent as "
                            "either a LHS or a RHS dimension");

    AffineMap resMap = op.getIndexingMapsArray()[2];
    auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
                                     /*symCount=*/0, extents, ctx);
    // Compose the resMap with the extentsMap, which is a constant map.
    AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
    assert(llvm::all_of((static_cast <bool> (llvm::all_of( expectedMap.getResults (), [](AffineExpr e) { return e.isa<AffineConstantExpr> (); }) && "expected constant extent along all dimensions." ) ? void (0) : __assert_fail ("llvm::all_of( expectedMap.getResults(), [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) && \"expected constant extent along all dimensions.\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 770, __extension__ __PRETTY_FUNCTION__))
           expectedMap.getResults(),(static_cast <bool> (llvm::all_of( expectedMap.getResults (), [](AffineExpr e) { return e.isa<AffineConstantExpr> (); }) && "expected constant extent along all dimensions." ) ? void (0) : __assert_fail ("llvm::all_of( expectedMap.getResults(), [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) && \"expected constant extent along all dimensions.\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 770, __extension__ __PRETTY_FUNCTION__))
           [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&(static_cast <bool> (llvm::all_of( expectedMap.getResults (), [](AffineExpr e) { return e.isa<AffineConstantExpr> (); }) && "expected constant extent along all dimensions." ) ? void (0) : __assert_fail ("llvm::all_of( expectedMap.getResults(), [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) && \"expected constant extent along all dimensions.\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 770, __extension__ __PRETTY_FUNCTION__))
           "expected constant extent along all dimensions.")(static_cast <bool> (llvm::all_of( expectedMap.getResults (), [](AffineExpr e) { return e.isa<AffineConstantExpr> (); }) && "expected constant extent along all dimensions." ) ? void (0) : __assert_fail ("llvm::all_of( expectedMap.getResults(), [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) && \"expected constant extent along all dimensions.\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 770, __extension__ __PRETTY_FUNCTION__));
    // Extract the expected shape and build the type.
    auto expectedShape = llvm::to_vector<4>(
        llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
          return e.cast<AffineConstantExpr>().getValue();
        }));
    auto expected =
        VectorType::get(expectedShape, resVectorType.getElementType());
    if (resVectorType != expected || accVectorType != expected)
      return op.emitOpError(
                 "invalid accumulator/result vector shape, expected: ")
             << expected;
  }
  return success();
}
785 | ||||
/// Verify the well-formedness of a vector.contract op: one indexing map per
/// operand, maps with the right arity that are projected permutations, valid
/// contracting/batch dimension maps, a matching acc/result shape, balanced
/// optional masks, and a supported combining kind.
LogicalResult ContractionOp::verify() {
  auto lhsType = getLhsType();
  auto rhsType = getRhsType();
  auto accType = getAccType();
  auto resType = getResultType();

  // Verify that an indexing map was specified for each vector operand.
  if (getIndexingMapsArray().size() != 3)
    return emitOpError("expected an indexing map for each vector operand");

  // Verify that each index map has 'numIterators' inputs, no symbols, and
  // that the number of map outputs equals the rank of its associated
  // vector operand.
  unsigned numIterators = getIteratorTypes().getValue().size();
  for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
    auto index = it.index();
    auto map = it.value();
    if (map.getNumSymbols() != 0)
      return emitOpError("expected indexing map ")
             << index << " to have no symbols";
    // A scalar operand (e.g. the acc of a rank-reducing contract) has rank 0.
    auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
    unsigned rank = vectorType ? vectorType.getShape().size() : 0;
    // Verify that the map has the right number of inputs, outputs, and indices.
    // This also correctly accounts for (..) -> () for rank-0 results.
    if (map.getNumDims() != numIterators)
      return emitOpError("expected indexing map ")
             << index << " to have " << numIterators << " number of inputs";
    if (map.getNumResults() != rank)
      return emitOpError("expected indexing map ")
             << index << " to have " << rank << " number of outputs";
    if (!map.isProjectedPermutation())
      return emitOpError("expected indexing map ")
             << index << " to be a projected permutation of its inputs";
  }

  auto contractingDimMap = getContractingDimMap();
  auto batchDimMap = getBatchDimMap();

  // Verify at least one contracting dimension pair was specified.
  if (contractingDimMap.empty())
    return emitOpError("expected at least one contracting dimension pair");

  // Verify contracting dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
    return emitOpError("invalid contracting dimension map");

  // Verify batch dimension map was properly constructed.
  if (!verifyDimMap(lhsType, rhsType, batchDimMap))
    return emitOpError("invalid batch dimension map");

  // Verify 'accType' and 'resType' shape.
  if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
                               contractingDimMap, batchDimMap)))
    return failure();

  // Verify that either two vector masks are set or none are set.
  auto lhsMaskType = getLHSVectorMaskType();
  auto rhsMaskType = getRHSVectorMaskType();
  if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
    return emitOpError("invalid number of vector masks specified");
  if (lhsMaskType && rhsMaskType) {
    // Verify mask rank == argument rank.
    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
        rhsMaskType.getShape().size() != rhsType.getShape().size())
      return emitOpError("invalid vector mask rank");
  }

  // Verify supported combining kind. The kind is checked against the element
  // type of the result (or the scalar result type itself for rank-0 results).
  auto vectorType = resType.dyn_cast<VectorType>();
  auto elementType = vectorType ? vectorType.getElementType() : resType;
  if (!isSupportedCombiningKind(getKind(), elementType))
    return emitOpError("unsupported contraction type");

  return success();
}
861 | ||||
862 | SmallVector<StringRef> ContractionOp::getTraitAttrNames() { | |||
863 | return SmallVector<StringRef>{getIndexingMapsAttrName(), | |||
864 | getIteratorTypesAttrName(), getKindAttrName()}; | |||
865 | } | |||
866 | ||||
867 | static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) { | |||
868 | for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) | |||
869 | if (targetExpr == map.getResult(i)) | |||
870 | return i; | |||
871 | return -1; | |||
872 | } | |||
873 | ||||
874 | static std::vector<std::pair<int64_t, int64_t>> | |||
875 | getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes, | |||
876 | IteratorType targetIteratorType, MLIRContext *context) { | |||
877 | std::vector<std::pair<int64_t, int64_t>> dimMap; | |||
878 | for (const auto &it : llvm::enumerate(iteratorTypes)) { | |||
879 | auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue(); | |||
880 | if (iteratorType != targetIteratorType) | |||
881 | continue; | |||
882 | // Search lhs/rhs map results for 'targetExpr'. | |||
883 | auto targetExpr = getAffineDimExpr(it.index(), context); | |||
884 | int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr); | |||
885 | int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr); | |||
886 | if (lhsDim >= 0 && rhsDim >= 0) | |||
887 | dimMap.emplace_back(lhsDim, rhsDim); | |||
888 | } | |||
889 | return dimMap; | |||
890 | } | |||
891 | ||||
/// Compute the loop trip count of each iteration dimension: reduction dims
/// are sized from the lhs shape, parallel dims from the result vector shape.
void ContractionOp::getIterationBounds(
    SmallVectorImpl<int64_t> &iterationBounds) {
  auto lhsShape = getLhsType().getShape();
  // Null for scalar results; only reduction dims are legal in that case.
  auto resVectorType = getResultType().dyn_cast<VectorType>();
  SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
  SmallVector<int64_t, 2> iterationShape;
  for (const auto &it : llvm::enumerate(getIteratorTypes())) {
    // Search lhs/rhs map results for 'targetExpr'.
    auto targetExpr = getAffineDimExpr(it.index(), getContext());
    auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
    if (iteratorType == IteratorType::reduction) {
      // Get reduction dim size from lhs shape (same size in rhsShape).
      int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
      assert(lhsDimIndex >= 0)(static_cast <bool> (lhsDimIndex >= 0) ? void (0) : __assert_fail ("lhsDimIndex >= 0", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp" , 905, __extension__ __PRETTY_FUNCTION__));
      iterationBounds.push_back(lhsShape[lhsDimIndex]);
      continue;
    }
    // Get parallel dimension size from result shape.
    int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
    assert(resDimIndex >= 0)(static_cast <bool> (resDimIndex >= 0) ? void (0) : __assert_fail ("resDimIndex >= 0", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp" , 911, __extension__ __PRETTY_FUNCTION__));
    assert(resVectorType != nullptr)(static_cast <bool> (resVectorType != nullptr) ? void ( 0) : __assert_fail ("resVectorType != nullptr", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp" , 912, __extension__ __PRETTY_FUNCTION__));
    iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
  }
}
916 | ||||
917 | void ContractionOp::getIterationIndexMap( | |||
918 | std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) { | |||
919 | unsigned numMaps = getIndexingMapsArray().size(); | |||
920 | iterationIndexMap.resize(numMaps); | |||
921 | for (const auto &it : llvm::enumerate(getIndexingMapsArray())) { | |||
922 | auto index = it.index(); | |||
923 | auto map = it.value(); | |||
924 | for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) { | |||
925 | auto dim = map.getResult(i).cast<AffineDimExpr>(); | |||
926 | iterationIndexMap[index][dim.getPosition()] = i; | |||
927 | } | |||
928 | } | |||
929 | } | |||
930 | ||||
931 | std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() { | |||
932 | SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); | |||
933 | return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction, | |||
934 | getContext()); | |||
935 | } | |||
936 | ||||
937 | std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() { | |||
938 | SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray()); | |||
939 | return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel, | |||
940 | getContext()); | |||
941 | } | |||
942 | ||||
943 | std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() { | |||
944 | SmallVector<int64_t, 4> shape; | |||
945 | getIterationBounds(shape); | |||
946 | return shape; | |||
947 | } | |||
948 | ||||
/// Return a fused vector::ContractionOp which represents a patterns such as:
///
/// ```mlir
///   %c0 = vector.constant 0: ...
///   %c = vector.contract %a, %b, %c0: ...
///   %e = add %c, %d: ...
/// ```
///
/// by:
///
/// ```mlir
///   %e = vector.contract %a, %b, %d: ...
/// ```
///
/// Return null if the canonicalization does not apply.
// TODO: This should be a folding of Add into Contract in core but while they
// live in different dialects, it is not possible without unnatural
// dependencies.
template <typename AddOpType>
struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
  using OpRewritePattern<AddOpType>::OpRewritePattern;

  LogicalResult matchAndRewrite(AddOpType addOp,
                                PatternRewriter &rewriter) const override {
    // Try to fuse `otherOperand` into `maybeContraction` as its accumulator.
    // Succeeds only if `maybeContraction` is a vector.contract whose acc is a
    // zero constant; returns the rewritten op, or null if not applicable.
    auto canonicalize = [&](Value maybeContraction,
                            Value otherOperand) -> vector::ContractionOp {
      vector::ContractionOp contractionOp =
          dyn_cast_or_null<vector::ContractionOp>(
              maybeContraction.getDefiningOp());
      if (!contractionOp)
        return vector::ContractionOp();
      if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
              contractionOp.getAcc().getDefiningOp())) {
        if (maybeZero.getValue() ==
            rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
          // Clone the contraction with the add's other operand substituted
          // for the zero accumulator, then replace the add with the clone.
          BlockAndValueMapping bvm;
          bvm.map(contractionOp.getAcc(), otherOperand);
          auto newContraction =
              cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
          rewriter.replaceOp(addOp, newContraction.getResult());
          return newContraction;
        }
      }
      return vector::ContractionOp();
    };

    // Addition is commutative: try both operand orders.
    Value a = addOp->getOperand(0), b = addOp->getOperand(1);
    vector::ContractionOp contract = canonicalize(a, b);
    contract = contract ? contract : canonicalize(b, a);
    return contract ? success() : failure();
  }
};
1001 | ||||
1002 | void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
1003 | MLIRContext *context) { | |||
1004 | results.add<CanonicalizeContractAdd<arith::AddIOp>, | |||
1005 | CanonicalizeContractAdd<arith::AddFOp>>(context); | |||
1006 | } | |||
1007 | ||||
1008 | //===----------------------------------------------------------------------===// | |||
1009 | // ExtractElementOp | |||
1010 | //===----------------------------------------------------------------------===// | |||
1011 | ||||
1012 | void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, | |||
1013 | Value source) { | |||
1014 | result.addOperands({source}); | |||
1015 | result.addTypes(source.getType().cast<VectorType>().getElementType()); | |||
1016 | } | |||
1017 | ||||
1018 | void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, | |||
1019 | Value source, Value position) { | |||
1020 | result.addOperands({source, position}); | |||
1021 | result.addTypes(source.getType().cast<VectorType>().getElementType()); | |||
1022 | } | |||
1023 | ||||
1024 | LogicalResult vector::ExtractElementOp::verify() { | |||
1025 | VectorType vectorType = getVectorType(); | |||
1026 | if (vectorType.getRank() == 0) { | |||
1027 | if (getPosition()) | |||
1028 | return emitOpError("expected position to be empty with 0-D vector"); | |||
1029 | return success(); | |||
1030 | } | |||
1031 | if (vectorType.getRank() != 1) | |||
1032 | return emitOpError("unexpected >1 vector rank"); | |||
1033 | if (!getPosition()) | |||
1034 | return emitOpError("expected position for 1-D vector"); | |||
1035 | return success(); | |||
1036 | } | |||
1037 | ||||
1038 | OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) { | |||
1039 | // Skip the 0-D vector here now. | |||
1040 | if (operands.size() < 2) | |||
1041 | return {}; | |||
1042 | ||||
1043 | Attribute src = operands[0]; | |||
1044 | Attribute pos = operands[1]; | |||
1045 | ||||
1046 | // Fold extractelement (splat X) -> X. | |||
1047 | if (auto splat = getVector().getDefiningOp<vector::SplatOp>()) | |||
1048 | return splat.getInput(); | |||
1049 | ||||
1050 | if (!pos || !src) | |||
1051 | return {}; | |||
1052 | ||||
1053 | auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>(); | |||
1054 | ||||
1055 | auto attr = pos.dyn_cast<IntegerAttr>(); | |||
1056 | uint64_t posIdx = attr.getInt(); | |||
1057 | ||||
1058 | return srcElements[posIdx]; | |||
1059 | } | |||
1060 | ||||
1061 | //===----------------------------------------------------------------------===// | |||
1062 | // ExtractOp | |||
1063 | //===----------------------------------------------------------------------===// | |||
1064 | ||||
1065 | void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, | |||
1066 | Value source, ArrayRef<int64_t> position) { | |||
1067 | build(builder, result, source, getVectorSubscriptAttr(builder, position)); | |||
1068 | } | |||
1069 | ||||
1070 | // Convenience builder which assumes the values are constant indices. | |||
1071 | void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, | |||
1072 | Value source, ValueRange position) { | |||
1073 | SmallVector<int64_t, 4> positionConstants = | |||
1074 | llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { | |||
1075 | return pos.getDefiningOp<arith::ConstantIndexOp>().value(); | |||
1076 | })); | |||
1077 | build(builder, result, source, positionConstants); | |||
1078 | } | |||
1079 | ||||
1080 | LogicalResult | |||
1081 | ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>, | |||
1082 | ValueRange operands, DictionaryAttr attributes, | |||
1083 | RegionRange, | |||
1084 | SmallVectorImpl<Type> &inferredReturnTypes) { | |||
1085 | ExtractOp::Adaptor op(operands, attributes); | |||
1086 | auto vectorType = op.getVector().getType().cast<VectorType>(); | |||
1087 | if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) { | |||
1088 | inferredReturnTypes.push_back(vectorType.getElementType()); | |||
1089 | } else { | |||
1090 | auto n = | |||
1091 | std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1); | |||
1092 | inferredReturnTypes.push_back(VectorType::get( | |||
1093 | vectorType.getShape().drop_front(n), vectorType.getElementType())); | |||
1094 | } | |||
1095 | return success(); | |||
1096 | } | |||
1097 | ||||
1098 | bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) { | |||
1099 | // Allow extracting 1-element vectors instead of scalars. | |||
1100 | auto isCompatible = [](TypeRange l, TypeRange r) { | |||
1101 | auto vectorType = l.front().dyn_cast<VectorType>(); | |||
1102 | return vectorType && vectorType.getShape().equals({1}) && | |||
1103 | vectorType.getElementType() == r.front(); | |||
1104 | }; | |||
1105 | if (l.size() == 1 && r.size() == 1 && | |||
1106 | (isCompatible(l, r) || isCompatible(r, l))) | |||
1107 | return true; | |||
1108 | return l == r; | |||
1109 | } | |||
1110 | ||||
1111 | LogicalResult vector::ExtractOp::verify() { | |||
1112 | auto positionAttr = getPosition().getValue(); | |||
1113 | if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank())) | |||
1114 | return emitOpError( | |||
1115 | "expected position attribute of rank smaller than vector rank"); | |||
1116 | for (const auto &en : llvm::enumerate(positionAttr)) { | |||
1117 | auto attr = en.value().dyn_cast<IntegerAttr>(); | |||
1118 | if (!attr || attr.getInt() < 0 || | |||
1119 | attr.getInt() >= getVectorType().getDimSize(en.index())) | |||
1120 | return emitOpError("expected position attribute #") | |||
1121 | << (en.index() + 1) | |||
1122 | << " to be a non-negative integer smaller than the corresponding " | |||
1123 | "vector dimension"; | |||
1124 | } | |||
1125 | return success(); | |||
1126 | } | |||
1127 | ||||
1128 | template <typename IntType> | |||
1129 | static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) { | |||
1130 | return llvm::to_vector<4>(llvm::map_range( | |||
1131 | arrayAttr.getAsRange<IntegerAttr>(), | |||
1132 | [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); })); | |||
1133 | } | |||
1134 | ||||
1135 | /// Fold the result of chains of ExtractOp in place by simply concatenating the | |||
1136 | /// positions. | |||
1137 | static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) { | |||
1138 | if (!extractOp.getVector().getDefiningOp<ExtractOp>()) | |||
1139 | return failure(); | |||
1140 | ||||
1141 | SmallVector<int64_t, 4> globalPosition; | |||
1142 | ExtractOp currentOp = extractOp; | |||
1143 | auto extrPos = extractVector<int64_t>(currentOp.getPosition()); | |||
1144 | globalPosition.append(extrPos.rbegin(), extrPos.rend()); | |||
1145 | while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) { | |||
1146 | currentOp = nextOp; | |||
1147 | auto extrPos = extractVector<int64_t>(currentOp.getPosition()); | |||
1148 | globalPosition.append(extrPos.rbegin(), extrPos.rend()); | |||
1149 | } | |||
1150 | extractOp.setOperand(currentOp.getVector()); | |||
1151 | // OpBuilder is only used as a helper to build an I64ArrayAttr. | |||
1152 | OpBuilder b(extractOp.getContext()); | |||
1153 | std::reverse(globalPosition.begin(), globalPosition.end()); | |||
1154 | extractOp->setAttr(ExtractOp::getPositionAttrStrName(), | |||
1155 | b.getI64ArrayAttr(globalPosition)); | |||
1156 | return success(); | |||
1157 | } | |||
1158 | ||||
1159 | namespace { | |||
/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
/// Compose TransposeOp permutations as we walk back.
/// This helper class keeps an updated extraction position `extractPosition`
/// with extra trailing sentinels.
/// The sentinels encode the internal transposition status of the result vector.
/// As we iterate, extractPosition is permuted and updated.
class ExtractFromInsertTransposeChainState {
public:
  ExtractFromInsertTransposeChainState(ExtractOp e);

  /// Iterate over producing insert and transpose ops until we find a fold.
  Value fold();

private:
  /// Return true if the vector at position `a` is contained within the vector
  /// at position `b`. Under insert/extract semantics, this is the same as `a`
  /// is a prefix of `b`.
  template <typename ContainerA, typename ContainerB>
  bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
    return a.size() <= b.size() &&
           std::equal(a.begin(), a.begin() + a.size(), b.begin());
  }

  /// Return true if the vector at position `a` intersects the vector at
  /// position `b`. Under insert/extract semantics, this is the same as equality
  /// of all entries of `a` that are >=0 with the corresponding entries of b.
  /// Comparison is on the common prefix (i.e. zip).
  /// Negative entries (sentinels) are treated as wildcards that match
  /// anything.
  template <typename ContainerA, typename ContainerB>
  bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
    for (auto [elemA, elemB] : llvm::zip(a, b)) {
      if (elemA < 0 || elemB < 0)
        continue;
      if (elemA != elemB)
        return false;
    }
    return true;
  }

  /// Folding is only possible in the absence of an internal permutation in the
  /// result vector.
  bool canFold() {
    return (sentinels ==
            makeArrayRef(extractPosition).drop_front(extractedRank));
  }

  // Helper to get the next defining op of interest.
  // Sets exactly one of nextInsertOp/nextTransposeOp (or neither) depending
  // on what produced `v`.
  void updateStateForNextIteration(Value v) {
    nextInsertOp = v.getDefiningOp<vector::InsertOp>();
    nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
  };

  // Case 1. If we hit a transpose, just compose the map and iterate.
  // Invariant: insert + transpose do not change rank, we can always compose.
  LogicalResult handleTransposeOp();

  // Case 2: the insert position matches extractPosition exactly, early return.
  LogicalResult handleInsertOpWithMatchingPos(Value &res);

  /// Case 3: if the insert position is a prefix of extractPosition, extract a
  /// portion of the source of the insert.
  /// Example:
  /// ```
  ///   %ins = vector.insert %source, %vest[1]: vector<3x4> into vector<2x3x4x5>
  ///   // extractPosition == [1, 2, 3]
  ///   %ext = vector.extract %ins[1, 0]: vector<3x4x5>
  ///   // can fold to vector.extract %source[0, 3]
  ///   %ext = vector.extract %source[3]: vector<5x6>
  /// ```
  /// To traverse through %source, we need to set the leading dims to 0 and
  /// drop the extra leading dims.
  /// This method updates the internal state.
  LogicalResult handleInsertOpWithPrefixPos(Value &res);

  /// Try to fold in place to extract(source, extractPosition) and return the
  /// folded result. Return null if folding is not possible (e.g. due to an
  /// internal transposition in the result).
  Value tryToFoldExtractOpInPlace(Value source);

  ExtractOp extractOp;
  // Rank of the source vector of `extractOp`.
  int64_t vectorRank;
  // Number of leading positions actually extracted (may shrink in case 3).
  int64_t extractedRank;

  // The producers examined in the current iteration (at most one non-null).
  InsertOp nextInsertOp;
  TransposeOp nextTransposeOp;

  /// Sentinel values that encode the internal permutation status of the result.
  /// They are set to (-1, ... , -k) at the beginning and appended to
  /// `extractPosition`.
  /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
  /// ensure that there is no internal transposition.
  /// Internal transposition cannot be accounted for with a folding pattern.
  // TODO: We could relax the internal transposition with an extra transposition
  // operation in a future canonicalizer.
  SmallVector<int64_t> sentinels;
  SmallVector<int64_t> extractPosition;
};
1257 | } // namespace | |||
1258 | ||||
/// Initialize the walk state: record ranks, create one negative sentinel per
/// non-extracted trailing dimension, and append the sentinels to the
/// extraction position.
ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
    ExtractOp e)
    : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
      extractedRank(extractOp.getPosition().size()) {
  assert(vectorRank >= extractedRank && "extracted pos overflow")(static_cast <bool> (vectorRank >= extractedRank && "extracted pos overflow") ? void (0) : __assert_fail ("vectorRank >= extractedRank && \"extracted pos overflow\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1263, __extension__ __PRETTY_FUNCTION__));
  // Sentinels are -1, -2, ..., -(vectorRank - extractedRank): unique negative
  // values that cannot collide with real (non-negative) positions.
  sentinels.reserve(vectorRank - extractedRank);
  for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
    sentinels.push_back(-(i + 1));
  extractPosition = extractVector<int64_t>(extractOp.getPosition());
  llvm::append_range(extractPosition, sentinels);
}
1270 | ||||
1271 | // Case 1. If we hit a transpose, just compose the map and iterate. | |||
1272 | // Invariant: insert + transpose do not change rank, we can always compose. | |||
1273 | LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() { | |||
1274 | if (!nextTransposeOp) | |||
1275 | return failure(); | |||
1276 | auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp()); | |||
1277 | AffineMap m = inversePermutation( | |||
1278 | AffineMap::getPermutationMap(permutation, extractOp.getContext())); | |||
1279 | extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition)); | |||
1280 | return success(); | |||
1281 | } | |||
1282 | ||||
1283 | // Case 2: the insert position matches extractPosition exactly, early return. | |||
// Case 2: the insert position matches extractPosition exactly, early return.
// On success, `res` is set to the inserted source value, which the caller can
// return directly as the fold result.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
    Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  // Compare only against the real (non-sentinel) leading entries of
  // extractPosition.
  if (makeArrayRef(insertedPos) !=
      llvm::makeArrayRef(extractPosition).take_front(extractedRank))
    return failure();
  // Case 2.a. early-exit fold.
  res = nextInsertOp.getSource();
  // Case 2.b. if internal transposition is present, canFold will be false.
  return success();
}
1296 | ||||
1297 | /// Case 3: if inserted position is a prefix of extractPosition, | |||
1298 | /// extract a portion of the source of the insertion. | |||
1299 | /// This method updates the internal state. | |||
/// Case 3: if inserted position is a prefix of extractPosition,
/// extract a portion of the source of the insertion.
/// This method updates the internal state: `extractPosition` is rebased onto
/// the insert's source value and `extractedRank` is recomputed from the
/// unchanged sentinel count.
LogicalResult
ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
  auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
  if (!isContainedWithin(insertedPos, extractPosition))
    return failure();
  // Set leading dims to zero.
  std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
  // Drop extra leading dims.
  extractPosition.erase(extractPosition.begin(),
                        extractPosition.begin() + insertedPos.size());
  extractedRank = extractPosition.size() - sentinels.size();
  // Case 3.a. early-exit fold (break and delegate to post-while path).
  res = nextInsertOp.getSource();
  // Case 3.b. if internal transposition is present, canFold will be false.
  return success();
}
1316 | ||||
1317 | /// Try to fold in place to extract(source, extractPosition) and return the | |||
1318 | /// folded result. Return null if folding is not possible (e.g. due to an | |||
1319 | /// internal tranposition in the result). | |||
/// Try to fold in place to extract(source, extractPosition) and return the
/// folded result. Return null if folding is not possible (e.g. due to an
/// internal transposition in the result).
Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
    Value source) {
  // If we can't fold (either internal transposition, or nothing to fold), bail.
  bool nothingToFold = (source == extractOp.getVector());
  if (nothingToFold || !canFold())
    return Value();
  // Otherwise, fold by updating the op inplace and return its result.
  // OpBuilder is only used as a helper to build the I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  // Drop the trailing sentinel entries before writing the position back.
  extractOp->setAttr(
      extractOp.getPositionAttrName(),
      b.getI64ArrayAttr(
          makeArrayRef(extractPosition).take_front(extractedRank)));
  extractOp.getVectorMutable().assign(source);
  return extractOp.getResult();
}
1335 | ||||
1336 | /// Iterate over producing insert and transpose ops until we find a fold. | |||
/// Iterate over producing insert and transpose ops until we find a fold.
/// Returns the folded value (possibly rewriting extractOp in place), or null
/// if no fold applies. Each iteration either composes a transpose into the
/// tracked position, resolves against an insert, or walks through to the
/// insert's destination.
Value ExtractFromInsertTransposeChainState::fold() {
  Value valueToExtractFrom = extractOp.getVector();
  updateStateForNextIteration(valueToExtractFrom);
  while (nextInsertOp || nextTransposeOp) {
    // Case 1. If we hit a transpose, just compose the map and iterate.
    // Invariant: insert + transpose do not change rank, we can always compose.
    if (succeeded(handleTransposeOp())) {
      valueToExtractFrom = nextTransposeOp.getVector();
      updateStateForNextIteration(valueToExtractFrom);
      continue;
    }

    Value result;
    // Case 2: the position match exactly.
    if (succeeded(handleInsertOpWithMatchingPos(result)))
      return result;

    // Case 3: if the inserted position is a prefix of extractPosition, we can
    // just extract a portion of the source of the insert.
    if (succeeded(handleInsertOpWithPrefixPos(result)))
      return tryToFoldExtractOpInPlace(result);

    // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
    // values. This is a more difficult case and we bail.
    auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
    if (isContainedWithin(extractPosition, insertedPos) ||
        intersectsWhereNonNegative(extractPosition, insertedPos))
      return Value();

    // Case 5: No intersection, we forward the extract to insertOp.dest().
    valueToExtractFrom = nextInsertOp.getDest();
    updateStateForNextIteration(valueToExtractFrom);
  }
  // If after all this we can fold, go for it.
  return tryToFoldExtractOpInPlace(valueToExtractFrom);
}
1373 | ||||
1374 | /// Fold extractOp with scalar result coming from BroadcastOp or SplatOp. | |||
/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
/// Rewrites extractOp in place to extract directly from the broadcast source
/// and returns its result, or returns null if no fold applies.
static Value foldExtractFromBroadcast(ExtractOp extractOp) {
  Operation *defOp = extractOp.getVector().getDefiningOp();
  if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
    return Value();
  Value source = defOp->getOperand(0);
  // Extracting the entire broadcast source is a no-op.
  if (extractOp.getType() == source.getType())
    return source;
  auto getRank = [](Type type) {
    return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
  };
  // If splat or broadcast from a scalar, just return the source scalar.
  unsigned broadcastSrcRank = getRank(source.getType());
  if (broadcastSrcRank == 0)
    return source;

  unsigned extractResultRank = getRank(extractOp.getType());
  if (extractResultRank >= broadcastSrcRank)
    return Value();
  // Check that the dimension of the result haven't been broadcasted.
  auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
  auto broadcastVecType = source.getType().dyn_cast<VectorType>();
  if (extractVecType && broadcastVecType &&
      extractVecType.getShape() !=
          broadcastVecType.getShape().take_back(extractResultRank))
    return Value();

  // NOTE(review): this cast assumes only vector.BroadcastOp reaches here —
  // presumably a SplatOp source is always scalar and thus took the
  // broadcastSrcRank == 0 early exit above; confirm against SplatOp's
  // definition.
  auto broadcastOp = cast<vector::BroadcastOp>(defOp);
  int64_t rankDiff = broadcastSrcRank - extractResultRank;
  // Detect all the positions that come from "dim-1" broadcasting.
  // These dimensions correspond to "dim-1" broadcasted dims; set the matching
  // extract position to `0` when extracting from the source operand.
  llvm::SetVector<int64_t> broadcastedUnitDims =
      broadcastOp.computeBroadcastedUnitDims();
  auto extractPos = extractVector<int64_t>(extractOp.getPosition());
  for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
    if (broadcastedUnitDims.contains(i))
      extractPos[i] = 0;
  // `rankDiff` leading dimensions correspond to new broadcasted dims, drop the
  // matching extract position when extracting from the source operand.
  extractPos.erase(extractPos.begin(),
                   std::next(extractPos.begin(), extractPos.size() - rankDiff));
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp.setOperand(source);
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractPos));
  return extractOp.getResult();
}
1423 | ||||
1424 | // Fold extractOp with source coming from ShapeCast op. | |||
// Fold extractOp with source coming from ShapeCast op. The extract position is
// linearized using the shape-cast result strides, then delinearized with the
// shape-cast source strides so the extract can read directly from the
// shape-cast's source vector.
static Value foldExtractFromShapeCast(ExtractOp extractOp) {
  auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
  if (!shapeCastOp)
    return Value();
  // Get the nth dimension size starting from lowest dimension.
  auto getDimReverse = [](VectorType type, int64_t n) {
    return type.getShape().take_back(n + 1).front();
  };
  int64_t destinationRank =
      extractOp.getType().isa<VectorType>()
          ? extractOp.getType().cast<VectorType>().getRank()
          : 0;
  if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
    return Value();
  if (destinationRank > 0) {
    auto destinationType = extractOp.getResult().getType().cast<VectorType>();
    for (int64_t i = 0; i < destinationRank; i++) {
      // The lowest dimension of the destination must match the lowest
      // dimension of the shapecast op source.
      // TODO: This case could be supported in a canonicalization pattern.
      if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
          getDimReverse(destinationType, i))
        return Value();
    }
  }
  // Extract the strides associated with the extract op vector source. Then use
  // this to calculate a linearized position for the extract.
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  std::reverse(extractedPos.begin(), extractedPos.end());
  SmallVector<int64_t, 4> strides;
  int64_t stride = 1;
  for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
    strides.push_back(stride);
    stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
  }

  int64_t position = linearize(extractedPos, strides);
  // Then extract the strides associated to the shapeCast op vector source and
  // delinearize the position using those strides.
  SmallVector<int64_t, 4> newStrides;
  int64_t numDimension =
      shapeCastOp.getSourceVectorType().getRank() - destinationRank;
  stride = 1;
  for (int64_t i = 0; i < numDimension; i++) {
    newStrides.push_back(stride);
    stride *=
        getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
  }
  std::reverse(newStrides.begin(), newStrides.end());
  SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(newPosition));
  extractOp.setOperand(shapeCastOp.getSource());
  return extractOp.getResult();
}
1482 | ||||
1483 | /// Fold an ExtractOp from ExtractStridedSliceOp. | |||
/// Fold an ExtractOp from ExtractStridedSliceOp by composing the slice offsets
/// into the extract position and reading directly from the slice's source.
/// Returns null if the slice has non-unit strides or touches dimensions the
/// extract result must keep intact.
static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
  auto extractStridedSliceOp =
      extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
  if (!extractStridedSliceOp)
    return Value();
  // Return if 'extractStridedSliceOp' has non-unit strides.
  if (extractStridedSliceOp.hasNonUnitStrides())
    return Value();

  // Trim offsets for dimensions fully extracted (zero offset and full size):
  // they contribute nothing to the composed position.
  auto sliceOffsets =
      extractVector<int64_t>(extractStridedSliceOp.getOffsets());
  while (!sliceOffsets.empty()) {
    size_t lastOffset = sliceOffsets.size() - 1;
    if (sliceOffsets.back() != 0 ||
        extractStridedSliceOp.getType().getDimSize(lastOffset) !=
            extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
      break;
    sliceOffsets.pop_back();
  }
  unsigned destinationRank = 0;
  if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
    destinationRank = vecType.getRank();
  // The dimensions of the result need to be untouched by the
  // extractStridedSlice op.
  if (destinationRank >
      extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
    return Value();
  auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
  assert(extractedPos.size() >= sliceOffsets.size())(static_cast <bool> (extractedPos.size() >= sliceOffsets .size()) ? void (0) : __assert_fail ("extractedPos.size() >= sliceOffsets.size()" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1513, __extension__ __PRETTY_FUNCTION__));
  // Compose: extract position in the slice + slice offset = position in the
  // slice's source vector.
  for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
    extractedPos[i] = extractedPos[i] + sliceOffsets[i];
  extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(extractOp.getContext());
  extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
                     b.getI64ArrayAttr(extractedPos));
  return extractOp.getResult();
}
1523 | ||||
1524 | /// Fold extract_op fed from a chain of insertStridedSlice ops. | |||
/// Fold extract_op fed from a chain of insertStridedSlice ops. Walks the chain
/// from the newest insert: if the extracted chunk lies fully inside an
/// inserted chunk, extract from that insert's source; if it is disjoint, keep
/// walking through the insert's destination; on partial overlap, bail.
static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
  int64_t destinationRank = op.getType().isa<VectorType>()
                                ? op.getType().cast<VectorType>().getRank()
                                : 0;
  auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
  while (insertOp) {
    int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
                             insertOp.getSourceVectorType().getRank();
    if (destinationRank > insertOp.getSourceVectorType().getRank())
      return Value();
    auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
    auto extractOffsets = extractVector<int64_t>(op.getPosition());

    // Only unit-stride inserts can be composed with a plain extract.
    if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
          return attr.cast<IntegerAttr>().getInt() != 1;
        }))
      return Value();
    bool disjoint = false;
    SmallVector<int64_t, 4> offsetDiffs;
    for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
      int64_t start = insertOffsets[dim];
      // Leading `insertRankDiff` dims of the dest have no matching source dim;
      // the inserted chunk has extent 1 there.
      int64_t size =
          (dim < insertRankDiff)
              ? 1
              : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
      int64_t end = start + size;
      int64_t offset = extractOffsets[dim];
      // Check if the start of the extract offset is in the interval inserted.
      if (start <= offset && offset < end) {
        if (dim >= insertRankDiff)
          offsetDiffs.push_back(offset - start);
        continue;
      }
      disjoint = true;
      break;
    }
    // The extracted element chunk overlaps with the vector inserted.
    if (!disjoint) {
      // If any of the inner dimensions are only partially inserted we have a
      // partial overlap.
      int64_t srcRankDiff =
          insertOp.getSourceVectorType().getRank() - destinationRank;
      for (int64_t i = 0; i < destinationRank; i++) {
        if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
            insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
                                                    insertRankDiff))
          return Value();
      }
      op.getVectorMutable().assign(insertOp.getSource());
      // OpBuilder is only used as a helper to build an I64ArrayAttr.
      OpBuilder b(op.getContext());
      op->setAttr(ExtractOp::getPositionAttrStrName(),
                  b.getI64ArrayAttr(offsetDiffs));
      return op.getResult();
    }
    // If the chunk extracted is disjoint from the chunk inserted, keep
    // looking in the insert chain.
    insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
  }
  return Value();
}
1586 | ||||
/// Fold vector.extract by trying each specialized folder in turn. The order
/// matters: several folders rewrite the op in place, so the first one that
/// succeeds wins.
OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
  // An empty position extracts the whole vector: fold to the operand.
  if (getPosition().empty())
    return getVector();
  if (succeeded(foldExtractOpFromExtractChain(*this)))
    return getResult();
  if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
    return res;
  if (auto res = foldExtractFromBroadcast(*this))
    return res;
  if (auto res = foldExtractFromShapeCast(*this))
    return res;
  if (auto val = foldExtractFromExtractStrided(*this))
    return val;
  if (auto val = foldExtractStridedOpFromInsertChain(*this))
    return val;
  return OpFoldResult();
}
1604 | ||||
1605 | namespace { | |||
1606 | ||||
1607 | // Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast. | |||
// Pattern to rewrite a ExtractOp(Broadcast) -> Broadcast.
// Handles the case where the broadcast source rank is <= the extract result
// rank, so the extract can be replaced by a (smaller) broadcast of the same
// source. The rank > result case is handled by the folders above.
class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    Operation *defOp = extractOp.getVector().getDefiningOp();
    if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
      return failure();

    Value source = defOp->getOperand(0);
    // Nothing to do if the extract is an identity of the broadcast source.
    if (extractOp.getType() == source.getType())
      return failure();
    auto getRank = [](Type type) {
      return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
    };
    unsigned broadcastSrcRank = getRank(source.getType());
    unsigned extractResultRank = getRank(extractOp.getType());
    // We only consider the case where the rank of the source is less than or
    // equal to the rank of the extract dst. The other cases are handled in the
    // folding patterns.
    if (extractResultRank < broadcastSrcRank)
      return failure();
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        extractOp, extractOp.getType(), source);
    return success();
  }
};
1636 | ||||
1637 | // Pattern to rewrite a ExtractOp(splat ConstantOp) -> ConstantOp. | |||
1638 | class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> { | |||
1639 | public: | |||
1640 | using OpRewritePattern::OpRewritePattern; | |||
1641 | ||||
1642 | LogicalResult matchAndRewrite(ExtractOp extractOp, | |||
1643 | PatternRewriter &rewriter) const override { | |||
1644 | // Return if 'ExtractOp' operand is not defined by a splat vector | |||
1645 | // ConstantOp. | |||
1646 | Value sourceVector = extractOp.getVector(); | |||
1647 | Attribute vectorCst; | |||
1648 | if (!matchPattern(sourceVector, m_Constant(&vectorCst))) | |||
1649 | return failure(); | |||
1650 | auto splat = vectorCst.dyn_cast<SplatElementsAttr>(); | |||
1651 | if (!splat) | |||
1652 | return failure(); | |||
1653 | Attribute newAttr = splat.getSplatValue<Attribute>(); | |||
1654 | if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>()) | |||
1655 | newAttr = DenseElementsAttr::get(vecDstType, newAttr); | |||
1656 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr); | |||
1657 | return success(); | |||
1658 | } | |||
1659 | }; | |||
1660 | ||||
1661 | // Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp. | |||
// Pattern to rewrite a ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
// Materializes the extracted scalar or sub-vector directly as a new constant.
class ExtractOpNonSplatConstantFolder final
    : public OpRewritePattern<ExtractOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractOp' operand is not defined by a compatible vector
    // ConstantOp.
    Value sourceVector = extractOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    // Scalable vectors have no static element count to index into.
    auto vecTy = sourceVector.getType().cast<VectorType>();
    if (vecTy.isScalable())
      return failure();

    // The splat case is handled by `ExtractOpSplatConstantFolder`.
    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
    if (!dense || dense.isSplat())
      return failure();

    // Calculate the linearized position of the continuous chunk of elements to
    // extract. The position is zero-padded to full rank so partial extracts
    // select a contiguous range of elements.
    llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
    copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
    int64_t elemBeginPosition =
        linearize(completePositions, computeStrides(vecTy.getShape()));
    auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;

    // Build either a sub-vector dense attribute or a single scalar attribute.
    Attribute newAttr;
    if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
      SmallVector<Attribute> elementValues(
          denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
      newAttr = DenseElementsAttr::get(resVecTy, elementValues);
    } else {
      newAttr = *denseValuesBegin;
    }

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
    return success();
  }
};
1706 | ||||
1707 | } // namespace | |||
1708 | ||||
/// Register the canonicalization patterns for vector.extract: constant
/// folding (splat and non-splat) and rewriting extract-of-broadcast to a
/// broadcast.
void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
              ExtractOpFromBroadcast>(context);
}
1714 | ||||
1715 | static void populateFromInt64AttrArray(ArrayAttr arrayAttr, | |||
1716 | SmallVectorImpl<int64_t> &results) { | |||
1717 | for (auto attr : arrayAttr) | |||
1718 | results.push_back(attr.cast<IntegerAttr>().getInt()); | |||
1719 | } | |||
1720 | ||||
1721 | //===----------------------------------------------------------------------===// | |||
1722 | // FmaOp | |||
1723 | //===----------------------------------------------------------------------===// | |||
1724 | ||||
1725 | std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() { | |||
1726 | return llvm::to_vector<4>(getVectorType().getShape()); | |||
1727 | } | |||
1728 | ||||
1729 | //===----------------------------------------------------------------------===// | |||
1730 | // BroadcastOp | |||
1731 | //===----------------------------------------------------------------------===// | |||
1732 | ||||
1733 | /// Return the dimensions of the result vector that were formerly ones in the | |||
1734 | /// source tensor and thus correspond to "dim-1" broadcasting. | |||
1735 | static llvm::SetVector<int64_t> | |||
1736 | computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape, | |||
1737 | ArrayRef<int64_t> dstShape) { | |||
1738 | int64_t rankDiff = dstShape.size() - srcShape.size(); | |||
1739 | int64_t dstDim = rankDiff; | |||
1740 | llvm::SetVector<int64_t> res; | |||
1741 | for (auto [s1, s2] : | |||
1742 | llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) { | |||
1743 | if (s1 != s2) { | |||
1744 | assert(s1 == 1 && "expected dim-1 broadcasting")(static_cast <bool> (s1 == 1 && "expected dim-1 broadcasting" ) ? void (0) : __assert_fail ("s1 == 1 && \"expected dim-1 broadcasting\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1744, __extension__ __PRETTY_FUNCTION__)); | |||
1745 | res.insert(dstDim); | |||
1746 | } | |||
1747 | ++dstDim; | |||
1748 | } | |||
1749 | return res; | |||
1750 | } | |||
1751 | ||||
1752 | llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() { | |||
1753 | // Scalar broadcast is without any unit dim broadcast. | |||
1754 | auto srcVectorType = getSourceType().dyn_cast<VectorType>(); | |||
1755 | if (!srcVectorType) | |||
1756 | return {}; | |||
1757 | return ::computeBroadcastedUnitDims(srcVectorType.getShape(), | |||
1758 | getVectorType().getShape()); | |||
1759 | } | |||
1760 | ||||
1761 | /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the | |||
1762 | /// `broadcastedDims` dimensions in the dstShape are broadcasted. | |||
1763 | /// This requires (and asserts) that the broadcast is free of dim-1 | |||
1764 | /// broadcasting. | |||
1765 | /// Since vector.broadcast only allows expanding leading dimensions, an extra | |||
1766 | /// vector.transpose may be inserted to make the broadcast possible. | |||
1767 | /// `value`, `dstShape` and `broadcastedDims` must be properly specified or | |||
1768 | /// the helper will assert. This means: | |||
1769 | /// 1. `dstShape` must not be empty. | |||
1770 | /// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)] | |||
1771 | /// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims` | |||
1772 | // must match the `value` shape. | |||
1773 | Value BroadcastOp::createOrFoldBroadcastOp( | |||
1774 | OpBuilder &b, Value value, ArrayRef<int64_t> dstShape, | |||
1775 | const llvm::SetVector<int64_t> &broadcastedDims) { | |||
1776 | assert(!dstShape.empty() && "unexpected empty dst shape")(static_cast <bool> (!dstShape.empty() && "unexpected empty dst shape" ) ? void (0) : __assert_fail ("!dstShape.empty() && \"unexpected empty dst shape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1776, __extension__ __PRETTY_FUNCTION__)); | |||
1777 | ||||
1778 | // Well-formedness check. | |||
1779 | SmallVector<int64_t> checkShape; | |||
1780 | for (int i = 0, e = dstShape.size(); i < e; ++i) { | |||
1781 | if (broadcastedDims.contains(i)) | |||
1782 | continue; | |||
1783 | checkShape.push_back(dstShape[i]); | |||
1784 | } | |||
1785 | assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&(static_cast <bool> (broadcastedDims.size() == dstShape .size() - checkShape.size() && "ill-formed broadcastedDims contains values not confined to " "destVectorShape") ? void (0) : __assert_fail ("broadcastedDims.size() == dstShape.size() - checkShape.size() && \"ill-formed broadcastedDims contains values not confined to \" \"destVectorShape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1787, __extension__ __PRETTY_FUNCTION__)) | |||
1786 | "ill-formed broadcastedDims contains values not confined to "(static_cast <bool> (broadcastedDims.size() == dstShape .size() - checkShape.size() && "ill-formed broadcastedDims contains values not confined to " "destVectorShape") ? void (0) : __assert_fail ("broadcastedDims.size() == dstShape.size() - checkShape.size() && \"ill-formed broadcastedDims contains values not confined to \" \"destVectorShape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1787, __extension__ __PRETTY_FUNCTION__)) | |||
1787 | "destVectorShape")(static_cast <bool> (broadcastedDims.size() == dstShape .size() - checkShape.size() && "ill-formed broadcastedDims contains values not confined to " "destVectorShape") ? void (0) : __assert_fail ("broadcastedDims.size() == dstShape.size() - checkShape.size() && \"ill-formed broadcastedDims contains values not confined to \" \"destVectorShape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1787, __extension__ __PRETTY_FUNCTION__)); | |||
1788 | ||||
1789 | Location loc = value.getLoc(); | |||
1790 | Type elementType = getElementTypeOrSelf(value.getType()); | |||
1791 | VectorType srcVectorType = value.getType().dyn_cast<VectorType>(); | |||
1792 | VectorType dstVectorType = VectorType::get(dstShape, elementType); | |||
1793 | ||||
1794 | // Step 2. If scalar -> dstShape broadcast, just do it. | |||
1795 | if (!srcVectorType) { | |||
1796 | assert(checkShape.empty() &&(static_cast <bool> (checkShape.empty() && "ill-formed createOrFoldBroadcastOp arguments" ) ? void (0) : __assert_fail ("checkShape.empty() && \"ill-formed createOrFoldBroadcastOp arguments\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1797, __extension__ __PRETTY_FUNCTION__)) | |||
1797 | "ill-formed createOrFoldBroadcastOp arguments")(static_cast <bool> (checkShape.empty() && "ill-formed createOrFoldBroadcastOp arguments" ) ? void (0) : __assert_fail ("checkShape.empty() && \"ill-formed createOrFoldBroadcastOp arguments\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1797, __extension__ __PRETTY_FUNCTION__)); | |||
1798 | return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value); | |||
1799 | } | |||
1800 | ||||
1801 | assert(srcVectorType.getShape().equals(checkShape) &&(static_cast <bool> (srcVectorType.getShape().equals(checkShape ) && "ill-formed createOrFoldBroadcastOp arguments") ? void (0) : __assert_fail ("srcVectorType.getShape().equals(checkShape) && \"ill-formed createOrFoldBroadcastOp arguments\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1802, __extension__ __PRETTY_FUNCTION__)) | |||
1802 | "ill-formed createOrFoldBroadcastOp arguments")(static_cast <bool> (srcVectorType.getShape().equals(checkShape ) && "ill-formed createOrFoldBroadcastOp arguments") ? void (0) : __assert_fail ("srcVectorType.getShape().equals(checkShape) && \"ill-formed createOrFoldBroadcastOp arguments\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1802, __extension__ __PRETTY_FUNCTION__)); | |||
1803 | ||||
1804 | // Step 3. Since vector.broadcast only allows creating leading dims, | |||
1805 | // vector -> dstShape broadcast may require a transpose. | |||
1806 | // Traverse the dims in order and construct: | |||
1807 | // 1. The leading entries of the broadcastShape that is guaranteed to be | |||
1808 | // achievable by a simple broadcast. | |||
1809 | // 2. The induced permutation for the subsequent vector.transpose that will | |||
1810 | // bring us from `broadcastShape` back to he desired `dstShape`. | |||
1811 | // If the induced permutation is not the identity, create a vector.transpose. | |||
1812 | SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1); | |||
1813 | broadcastShape.reserve(dstShape.size()); | |||
1814 | // Consider the example: | |||
1815 | // srcShape = 2x4 | |||
1816 | // dstShape = 1x2x3x4x5 | |||
1817 | // broadcastedDims = [0, 2, 4] | |||
1818 | // | |||
1819 | // We want to build: | |||
1820 | // broadcastShape = 1x3x5x2x4 | |||
1821 | // permutation = [0, 2, 4, 1, 3] | |||
1822 | // ---V--- -----V----- | |||
1823 | // leading broadcast part src shape part | |||
1824 | // | |||
1825 | // Note that the trailing dims of broadcastShape are exactly the srcShape | |||
1826 | // by construction. | |||
1827 | // nextSrcShapeDim is used to keep track of where in the permutation the | |||
1828 | // "src shape part" occurs. | |||
1829 | int64_t nextSrcShapeDim = broadcastedDims.size(); | |||
1830 | for (int64_t i = 0, e = dstShape.size(); i < e; ++i) { | |||
1831 | if (broadcastedDims.contains(i)) { | |||
1832 | // 3.a. For each dim in the dst shape, if it is a broadcasted dim, | |||
1833 | // bring it to the head of the broadcastShape. | |||
1834 | // It will need to be permuted back from `broadcastShape.size() - 1` into | |||
1835 | // position `i`. | |||
1836 | broadcastShape.push_back(dstShape[i]); | |||
1837 | permutation[i] = broadcastShape.size() - 1; | |||
1838 | } else { | |||
1839 | // 3.b. Otherwise, the dim is not broadcasted, it comes from the src | |||
1840 | // shape and needs to be permuted into position `i`. | |||
1841 | // Don't touch `broadcastShape` here, the whole srcShape will be | |||
1842 | // appended after. | |||
1843 | permutation[i] = nextSrcShapeDim++; | |||
1844 | } | |||
1845 | } | |||
1846 | // 3.c. Append the srcShape. | |||
1847 | llvm::append_range(broadcastShape, srcVectorType.getShape()); | |||
1848 | ||||
1849 | // Ensure there are no dim-1 broadcasts. | |||
1850 | assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)(static_cast <bool> (::computeBroadcastedUnitDims(srcVectorType .getShape(), broadcastShape) .empty() && "unexpected dim-1 broadcast" ) ? void (0) : __assert_fail ("::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape) .empty() && \"unexpected dim-1 broadcast\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1852, __extension__ __PRETTY_FUNCTION__)) | |||
1851 | .empty() &&(static_cast <bool> (::computeBroadcastedUnitDims(srcVectorType .getShape(), broadcastShape) .empty() && "unexpected dim-1 broadcast" ) ? void (0) : __assert_fail ("::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape) .empty() && \"unexpected dim-1 broadcast\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1852, __extension__ __PRETTY_FUNCTION__)) | |||
1852 | "unexpected dim-1 broadcast")(static_cast <bool> (::computeBroadcastedUnitDims(srcVectorType .getShape(), broadcastShape) .empty() && "unexpected dim-1 broadcast" ) ? void (0) : __assert_fail ("::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape) .empty() && \"unexpected dim-1 broadcast\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1852, __extension__ __PRETTY_FUNCTION__)); | |||
1853 | ||||
1854 | VectorType broadcastType = VectorType::get(broadcastShape, elementType); | |||
1855 | assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==(static_cast <bool> (vector::isBroadcastableTo(value.getType (), broadcastType) == vector::BroadcastableToResult::Success && "must be broadcastable") ? void (0) : __assert_fail ("vector::isBroadcastableTo(value.getType(), broadcastType) == vector::BroadcastableToResult::Success && \"must be broadcastable\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1857, __extension__ __PRETTY_FUNCTION__)) | |||
1856 | vector::BroadcastableToResult::Success &&(static_cast <bool> (vector::isBroadcastableTo(value.getType (), broadcastType) == vector::BroadcastableToResult::Success && "must be broadcastable") ? void (0) : __assert_fail ("vector::isBroadcastableTo(value.getType(), broadcastType) == vector::BroadcastableToResult::Success && \"must be broadcastable\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1857, __extension__ __PRETTY_FUNCTION__)) | |||
1857 | "must be broadcastable")(static_cast <bool> (vector::isBroadcastableTo(value.getType (), broadcastType) == vector::BroadcastableToResult::Success && "must be broadcastable") ? void (0) : __assert_fail ("vector::isBroadcastableTo(value.getType(), broadcastType) == vector::BroadcastableToResult::Success && \"must be broadcastable\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1857, __extension__ __PRETTY_FUNCTION__)); | |||
1858 | Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value); | |||
1859 | // Step 4. If we find any dimension that indeed needs to be permuted, | |||
1860 | // immediately return a new vector.transpose. | |||
1861 | for (int64_t i = 0, e = permutation.size(); i < e; ++i) | |||
1862 | if (permutation[i] != i) | |||
1863 | return b.createOrFold<vector::TransposeOp>(loc, res, permutation); | |||
1864 | // Otherwise return res. | |||
1865 | return res; | |||
1866 | } | |||
1867 | ||||
1868 | BroadcastableToResult | |||
1869 | mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType, | |||
1870 | std::pair<int, int> *mismatchingDims) { | |||
1871 | // Broadcast scalar to vector of the same element type. | |||
1872 | if (srcType.isIntOrIndexOrFloat() && dstVectorType && | |||
1873 | getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType)) | |||
1874 | return BroadcastableToResult::Success; | |||
1875 | // From now on, only vectors broadcast. | |||
1876 | VectorType srcVectorType = srcType.dyn_cast<VectorType>(); | |||
1877 | if (!srcVectorType) | |||
1878 | return BroadcastableToResult::SourceTypeNotAVector; | |||
1879 | ||||
1880 | int64_t srcRank = srcVectorType.getRank(); | |||
1881 | int64_t dstRank = dstVectorType.getRank(); | |||
1882 | if (srcRank > dstRank) | |||
1883 | return BroadcastableToResult::SourceRankHigher; | |||
1884 | // Source has an exact match or singleton value for all trailing dimensions | |||
1885 | // (all leading dimensions are simply duplicated). | |||
1886 | int64_t lead = dstRank - srcRank; | |||
1887 | for (int64_t r = 0; r < srcRank; ++r) { | |||
1888 | int64_t srcDim = srcVectorType.getDimSize(r); | |||
1889 | int64_t dstDim = dstVectorType.getDimSize(lead + r); | |||
1890 | if (srcDim != 1 && srcDim != dstDim) { | |||
1891 | if (mismatchingDims) { | |||
1892 | mismatchingDims->first = srcDim; | |||
1893 | mismatchingDims->second = dstDim; | |||
1894 | } | |||
1895 | return BroadcastableToResult::DimensionMismatch; | |||
1896 | } | |||
1897 | } | |||
1898 | ||||
1899 | return BroadcastableToResult::Success; | |||
1900 | } | |||
1901 | ||||
1902 | LogicalResult BroadcastOp::verify() { | |||
1903 | std::pair<int, int> mismatchingDims; | |||
1904 | BroadcastableToResult res = | |||
1905 | isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims); | |||
1906 | if (res == BroadcastableToResult::Success) | |||
1907 | return success(); | |||
1908 | if (res == BroadcastableToResult::SourceRankHigher) | |||
1909 | return emitOpError("source rank higher than destination rank"); | |||
1910 | if (res == BroadcastableToResult::DimensionMismatch) | |||
1911 | return emitOpError("dimension mismatch (") | |||
1912 | << mismatchingDims.first << " vs. " << mismatchingDims.second << ")"; | |||
1913 | if (res == BroadcastableToResult::SourceTypeNotAVector) | |||
1914 | return emitOpError("source type is not a vector"); | |||
1915 | llvm_unreachable("unexpected vector.broadcast op error")::llvm::llvm_unreachable_internal("unexpected vector.broadcast op error" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1915); | |||
1916 | } | |||
1917 | ||||
1918 | OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) { | |||
1919 | if (getSourceType() == getVectorType()) | |||
1920 | return getSource(); | |||
1921 | if (!operands[0]) | |||
1922 | return {}; | |||
1923 | auto vectorType = getVectorType(); | |||
1924 | if (operands[0].isa<IntegerAttr, FloatAttr>()) | |||
1925 | return DenseElementsAttr::get(vectorType, operands[0]); | |||
1926 | if (auto attr = operands[0].dyn_cast<SplatElementsAttr>()) | |||
1927 | return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>()); | |||
1928 | return {}; | |||
1929 | } | |||
1930 | ||||
1931 | namespace { | |||
1932 | ||||
1933 | // Fold broadcast1(broadcast2(x)) into broadcast1(x). | |||
1934 | struct BroadcastFolder : public OpRewritePattern<BroadcastOp> { | |||
1935 | using OpRewritePattern::OpRewritePattern; | |||
1936 | ||||
1937 | LogicalResult matchAndRewrite(BroadcastOp broadcastOp, | |||
1938 | PatternRewriter &rewriter) const override { | |||
1939 | auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>(); | |||
1940 | if (!srcBroadcast) | |||
1941 | return failure(); | |||
1942 | rewriter.replaceOpWithNewOp<BroadcastOp>( | |||
1943 | broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource()); | |||
1944 | return success(); | |||
1945 | } | |||
1946 | }; | |||
1947 | } // namespace | |||
1948 | ||||
/// Registers canonicalization patterns for vector.broadcast.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  // BroadcastToShapeCast is not a default canonicalization, it is opt-in by
  // calling `populateCastAwayVectorLeadingOneDimPatterns`.
  results.add<BroadcastFolder>(context);
}
1955 | ||||
1956 | //===----------------------------------------------------------------------===// | |||
1957 | // ShuffleOp | |||
1958 | //===----------------------------------------------------------------------===// | |||
1959 | ||||
1960 | void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1, | |||
1961 | Value v2, ArrayRef<int64_t> mask) { | |||
1962 | build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask)); | |||
1963 | } | |||
1964 | ||||
1965 | LogicalResult ShuffleOp::verify() { | |||
1966 | VectorType resultType = getVectorType(); | |||
1967 | VectorType v1Type = getV1VectorType(); | |||
1968 | VectorType v2Type = getV2VectorType(); | |||
1969 | // Verify ranks. | |||
1970 | int64_t resRank = resultType.getRank(); | |||
1971 | int64_t v1Rank = v1Type.getRank(); | |||
1972 | int64_t v2Rank = v2Type.getRank(); | |||
1973 | bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1; | |||
1974 | bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank; | |||
1975 | if (!wellFormed0DCase && !wellFormedNDCase) | |||
1976 | return emitOpError("rank mismatch"); | |||
1977 | ||||
1978 | // Verify all but leading dimension sizes. | |||
1979 | for (int64_t r = 1; r < v1Rank; ++r) { | |||
1980 | int64_t resDim = resultType.getDimSize(r); | |||
1981 | int64_t v1Dim = v1Type.getDimSize(r); | |||
1982 | int64_t v2Dim = v2Type.getDimSize(r); | |||
1983 | if (resDim != v1Dim || v1Dim != v2Dim) | |||
1984 | return emitOpError("dimension mismatch"); | |||
1985 | } | |||
1986 | // Verify mask length. | |||
1987 | auto maskAttr = getMask().getValue(); | |||
1988 | int64_t maskLength = maskAttr.size(); | |||
1989 | if (maskLength <= 0) | |||
1990 | return emitOpError("invalid mask length"); | |||
1991 | if (maskLength != resultType.getDimSize(0)) | |||
1992 | return emitOpError("mask length mismatch"); | |||
1993 | // Verify all indices. | |||
1994 | int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) + | |||
1995 | (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0)); | |||
1996 | for (const auto &en : llvm::enumerate(maskAttr)) { | |||
1997 | auto attr = en.value().dyn_cast<IntegerAttr>(); | |||
1998 | if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize) | |||
1999 | return emitOpError("mask index #") << (en.index() + 1) << " out of range"; | |||
2000 | } | |||
2001 | return success(); | |||
2002 | } | |||
2003 | ||||
2004 | LogicalResult | |||
2005 | ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>, | |||
2006 | ValueRange operands, DictionaryAttr attributes, | |||
2007 | RegionRange, | |||
2008 | SmallVectorImpl<Type> &inferredReturnTypes) { | |||
2009 | ShuffleOp::Adaptor op(operands, attributes); | |||
2010 | auto v1Type = op.getV1().getType().cast<VectorType>(); | |||
2011 | auto v1Rank = v1Type.getRank(); | |||
2012 | // Construct resulting type: leading dimension matches mask | |||
2013 | // length, all trailing dimensions match the operands. | |||
2014 | SmallVector<int64_t, 4> shape; | |||
2015 | shape.reserve(v1Rank); | |||
2016 | shape.push_back(std::max<size_t>(1, op.getMask().size())); | |||
2017 | // In the 0-D case there is no trailing shape to append. | |||
2018 | if (v1Rank > 0) | |||
2019 | llvm::append_range(shape, v1Type.getShape().drop_front()); | |||
2020 | inferredReturnTypes.push_back( | |||
2021 | VectorType::get(shape, v1Type.getElementType())); | |||
2022 | return success(); | |||
2023 | } | |||
2024 | ||||
2025 | static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) { | |||
2026 | uint64_t expected = begin; | |||
2027 | return idxArr.size() == width && | |||
2028 | llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(), | |||
2029 | [&expected](auto attr) { | |||
2030 | return attr.getZExtValue() == expected++; | |||
2031 | }); | |||
2032 | } | |||
2033 | ||||
OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
  VectorType v1Type = getV1VectorType();
  // For consistency: 0-D shuffle return type is 1-D, this cannot be a folding
  // but must be a canonicalization into a vector.broadcast.
  if (v1Type.getRank() == 0)
    return {};

  // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
  // (identity mask over V1; disabled for scalable vectors whose leading
  // dimension is not a fixed size).
  if (!v1Type.isScalable() &&
      isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
    return getV1();
  // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
  // (mask selects exactly all of V2, whose indices start after V1's).
  if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
      isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
                       getV2VectorType().getDimSize(0)))
    return getV2();

  // Constant folding below requires both operands to be constants.
  Attribute lhs = operands.front(), rhs = operands.back();
  if (!lhs || !rhs)
    return {};

  auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
  // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
  // manipulation.
  if (lhsType.getRank() != 1)
    return {};
  int64_t lhsSize = lhsType.getDimSize(0);

  // Gather elements: mask index i picks lhs[i] when i < lhsSize, otherwise
  // rhs[i - lhsSize] (indices are into the concatenation of lhs and rhs).
  SmallVector<Attribute> results;
  auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
  auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
  for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
    int64_t i = index.getZExtValue();
    if (i >= lhsSize) {
      results.push_back(rhsElements[i - lhsSize]);
    } else {
      results.push_back(lhsElements[i]);
    }
  }

  return DenseElementsAttr::get(getVectorType(), results);
}
2076 | ||||
2077 | namespace { | |||
2078 | ||||
2079 | // Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector | |||
2080 | // to a broadcast. | |||
2081 | struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> { | |||
2082 | using OpRewritePattern::OpRewritePattern; | |||
2083 | ||||
2084 | LogicalResult matchAndRewrite(ShuffleOp shuffleOp, | |||
2085 | PatternRewriter &rewriter) const override { | |||
2086 | VectorType v1VectorType = shuffleOp.getV1VectorType(); | |||
2087 | ArrayAttr mask = shuffleOp.getMask(); | |||
2088 | if (v1VectorType.getRank() > 0) | |||
2089 | return failure(); | |||
2090 | if (mask.size() != 1) | |||
2091 | return failure(); | |||
2092 | Type resType = VectorType::Builder(v1VectorType).setShape({1}); | |||
2093 | if (mask[0].cast<IntegerAttr>().getInt() == 0) | |||
2094 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType, | |||
2095 | shuffleOp.getV1()); | |||
2096 | else | |||
2097 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType, | |||
2098 | shuffleOp.getV2()); | |||
2099 | return success(); | |||
2100 | } | |||
2101 | }; | |||
2102 | ||||
2103 | /// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp. | |||
2104 | class ShuffleSplat final : public OpRewritePattern<ShuffleOp> { | |||
2105 | public: | |||
2106 | using OpRewritePattern::OpRewritePattern; | |||
2107 | ||||
2108 | LogicalResult matchAndRewrite(ShuffleOp op, | |||
2109 | PatternRewriter &rewriter) const override { | |||
2110 | auto v1Splat = op.getV1().getDefiningOp<SplatOp>(); | |||
2111 | auto v2Splat = op.getV2().getDefiningOp<SplatOp>(); | |||
2112 | ||||
2113 | if (!v1Splat || !v2Splat) | |||
2114 | return failure(); | |||
2115 | ||||
2116 | if (v1Splat.getInput() != v2Splat.getInput()) | |||
2117 | return failure(); | |||
2118 | ||||
2119 | rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput()); | |||
2120 | return success(); | |||
2121 | } | |||
2122 | }; | |||
2123 | ||||
2124 | } // namespace | |||
2125 | ||||
/// Registers canonicalization patterns for vector.shuffle.
void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
}
2130 | ||||
2131 | //===----------------------------------------------------------------------===// | |||
2132 | // InsertElementOp | |||
2133 | //===----------------------------------------------------------------------===// | |||
2134 | ||||
/// Builder for the 0-D destination form: no position operand is attached
/// (the verifier requires the position to be absent for 0-D vectors).
void InsertElementOp::build(OpBuilder &builder, OperationState &result,
                            Value source, Value dest) {
  build(builder, result, source, dest, {});
}
2139 | ||||
2140 | LogicalResult InsertElementOp::verify() { | |||
2141 | auto dstVectorType = getDestVectorType(); | |||
2142 | if (dstVectorType.getRank() == 0) { | |||
2143 | if (getPosition()) | |||
2144 | return emitOpError("expected position to be empty with 0-D vector"); | |||
2145 | return success(); | |||
2146 | } | |||
2147 | if (dstVectorType.getRank() != 1) | |||
2148 | return emitOpError("unexpected >1 vector rank"); | |||
2149 | if (!getPosition()) | |||
2150 | return emitOpError("expected position for 1-D vector"); | |||
2151 | return success(); | |||
2152 | } | |||
2153 | ||||
2154 | OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) { | |||
2155 | // Skip the 0-D vector here. | |||
2156 | if (operands.size() < 3) | |||
2157 | return {}; | |||
2158 | ||||
2159 | Attribute src = operands[0]; | |||
2160 | Attribute dst = operands[1]; | |||
2161 | Attribute pos = operands[2]; | |||
2162 | if (!src || !dst || !pos) | |||
2163 | return {}; | |||
2164 | ||||
2165 | auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>(); | |||
2166 | ||||
2167 | SmallVector<Attribute> results(dstElements); | |||
2168 | ||||
2169 | auto attr = pos.dyn_cast<IntegerAttr>(); | |||
2170 | uint64_t posIdx = attr.getInt(); | |||
2171 | ||||
2172 | results[posIdx] = src; | |||
2173 | ||||
2174 | return DenseElementsAttr::get(getDestVectorType(), results); | |||
2175 | } | |||
2176 | ||||
2177 | //===----------------------------------------------------------------------===// | |||
2178 | // InsertOp | |||
2179 | //===----------------------------------------------------------------------===// | |||
2180 | ||||
2181 | void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, | |||
2182 | Value dest, ArrayRef<int64_t> position) { | |||
2183 | result.addOperands({source, dest}); | |||
2184 | auto positionAttr = getVectorSubscriptAttr(builder, position); | |||
2185 | result.addTypes(dest.getType()); | |||
2186 | result.addAttribute(getPositionAttrStrName(), positionAttr); | |||
2187 | } | |||
2188 | ||||
2189 | // Convenience builder which assumes the values are constant indices. | |||
2190 | void InsertOp::build(OpBuilder &builder, OperationState &result, Value source, | |||
2191 | Value dest, ValueRange position) { | |||
2192 | SmallVector<int64_t, 4> positionConstants = | |||
2193 | llvm::to_vector<4>(llvm::map_range(position, [](Value pos) { | |||
2194 | return pos.getDefiningOp<arith::ConstantIndexOp>().value(); | |||
2195 | })); | |||
2196 | build(builder, result, source, dest, positionConstants); | |||
2197 | } | |||
2198 | ||||
LogicalResult InsertOp::verify() {
  // Checks, in order: the position is no longer than the destination rank;
  // rank(position) + rank(source) == rank(dest) (with a scalar source
  // counting as rank 0, i.e. a fully-specified position); and every position
  // entry indexes inside the corresponding destination dimension.
  auto positionAttr = getPosition().getValue();
  auto destVectorType = getDestVectorType();
  if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected position attribute of rank smaller than dest vector rank");
  // A vector source inserts a sub-vector; a non-vector source inserts a
  // scalar and must therefore be indexed down to a single element.
  auto srcVectorType = getSourceType().dyn_cast<VectorType>();
  if (srcVectorType &&
      (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
       static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError("expected position attribute rank + source rank to "
                       "match dest vector rank");
  if (!srcVectorType &&
      (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
    return emitOpError(
        "expected position attribute rank to match the dest vector rank");
  // Each entry must be a non-negative integer strictly below the matching
  // destination dimension size.
  for (const auto &en : llvm::enumerate(positionAttr)) {
    auto attr = en.value().dyn_cast<IntegerAttr>();
    if (!attr || attr.getInt() < 0 ||
        attr.getInt() >= destVectorType.getDimSize(en.index()))
      return emitOpError("expected position attribute #")
             << (en.index() + 1)
             << " to be a non-negative integer smaller than the corresponding "
                "dest vector dimension";
  }
  return success();
}
2226 | ||||
2227 | namespace { | |||
2228 | ||||
2229 | // If insertOp is only inserting unit dimensions it can be transformed to a | |||
2230 | // broadcast. | |||
2231 | class InsertToBroadcast final : public OpRewritePattern<InsertOp> { | |||
2232 | public: | |||
2233 | using OpRewritePattern::OpRewritePattern; | |||
2234 | ||||
2235 | LogicalResult matchAndRewrite(InsertOp insertOp, | |||
2236 | PatternRewriter &rewriter) const override { | |||
2237 | auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>(); | |||
2238 | if (!srcVecType || insertOp.getDestVectorType().getNumElements() != | |||
2239 | srcVecType.getNumElements()) | |||
2240 | return failure(); | |||
2241 | rewriter.replaceOpWithNewOp<BroadcastOp>( | |||
2242 | insertOp, insertOp.getDestVectorType(), insertOp.getSource()); | |||
2243 | return success(); | |||
2244 | } | |||
2245 | }; | |||
2246 | ||||
2247 | /// Pattern to rewrite a InsertOp(SplatOp, SplatOp) to SplatOp. | |||
2248 | class InsertSplatToSplat final : public OpRewritePattern<InsertOp> { | |||
2249 | public: | |||
2250 | using OpRewritePattern::OpRewritePattern; | |||
2251 | ||||
2252 | LogicalResult matchAndRewrite(InsertOp op, | |||
2253 | PatternRewriter &rewriter) const override { | |||
2254 | auto srcSplat = op.getSource().getDefiningOp<SplatOp>(); | |||
2255 | auto dstSplat = op.getDest().getDefiningOp<SplatOp>(); | |||
2256 | ||||
2257 | if (!srcSplat || !dstSplat) | |||
2258 | return failure(); | |||
2259 | ||||
2260 | if (srcSplat.getInput() != dstSplat.getInput()) | |||
2261 | return failure(); | |||
2262 | ||||
2263 | rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput()); | |||
2264 | return success(); | |||
2265 | } | |||
2266 | }; | |||
2267 | ||||
// Pattern to rewrite a InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreashold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  LogicalResult matchAndRewrite(InsertOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    // Scalable vectors have no static element count to linearize over.
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    auto denseDest = vectorDestCst.cast<DenseElementsAttr>();

    // The inserted value must be constant too (scalar or vector).
    Value sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // Calculate the linearized position of the continuous chunk of elements to
    // insert. Trailing dims not covered by the position attr stay 0, so the
    // chunk starts at the first element of the addressed sub-vector.
    llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
    copy(getI64SubArray(op.getPosition()), completePositions.begin());
    int64_t insertBeginPosition =
        linearize(completePositions, computeStrides(destTy.getShape()));

    // A dense source contributes all its elements; a scalar contributes one.
    SmallVector<Attribute> insertedValues;
    if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
      llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
    else
      insertedValues.push_back(sourceCst);

    // Overwrite the destination's elements with the chunk and rebuild the
    // constant in place of the insert.
    auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
    copy(insertedValues, allValues.begin() + insertBeginPosition);
    auto newAttr = DenseElementsAttr::get(destTy, allValues);

    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
2323 | ||||
2324 | } // namespace | |||
2325 | ||||
/// Registers canonicalization patterns for vector.insert.
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
              InsertOpConstantFolder>(context);
}
2331 | ||||
2332 | // Eliminates insert operations that produce values identical to their source | |||
2333 | // value. This happens when the source and destination vectors have identical | |||
2334 | // sizes. | |||
2335 | OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) { | |||
2336 | if (getPosition().empty()) | |||
2337 | return getSource(); | |||
2338 | return {}; | |||
2339 | } | |||
2340 | ||||
2341 | //===----------------------------------------------------------------------===// | |||
2342 | // InsertStridedSliceOp | |||
2343 | //===----------------------------------------------------------------------===// | |||
2344 | ||||
2345 | void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result, | |||
2346 | Value source, Value dest, | |||
2347 | ArrayRef<int64_t> offsets, | |||
2348 | ArrayRef<int64_t> strides) { | |||
2349 | result.addOperands({source, dest}); | |||
2350 | auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); | |||
2351 | auto stridesAttr = getVectorSubscriptAttr(builder, strides); | |||
2352 | result.addTypes(dest.getType()); | |||
2353 | result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); | |||
2354 | result.addAttribute(getStridesAttrStrName(), stridesAttr); | |||
2355 | } | |||
2356 | ||||
2357 | // TODO: Should be moved to Tablegen ConfinedAttr attributes. | |||
2358 | template <typename OpType> | |||
2359 | static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op, | |||
2360 | ArrayAttr arrayAttr, | |||
2361 | ArrayRef<int64_t> shape, | |||
2362 | StringRef attrName) { | |||
2363 | if (arrayAttr.size() > shape.size()) | |||
2364 | return op.emitOpError("expected ") | |||
2365 | << attrName << " attribute of rank smaller than vector rank"; | |||
2366 | return success(); | |||
2367 | } | |||
2368 | ||||
// Returns success if all integers in `arrayAttr` lie in the admissible
// interval, failure (with a diagnostic on `op`) otherwise. If `halfOpen` is
// true the admissible interval is [min, max); otherwise it is [min, max]
// (implemented by bumping the exclusive upper bound by one).
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
                                  int64_t max, StringRef attrName,
                                  bool halfOpen = true) {
  for (auto attr : arrayAttr) {
    auto val = attr.cast<IntegerAttr>().getInt();
    auto upper = max;
    if (!halfOpen)
      upper += 1;
    if (val < min || val >= upper)
      return op.emitOpError("expected ") << attrName << " to be confined to ["
                                         << min << ", " << upper << ")";
  }
  return success();
}
2388 | ||||
// Returns success if each integer in `arrayAttr` lies in the admissible
// interval for the corresponding dimension of `shape`, failure (with a
// diagnostic on `op`) otherwise. If `halfOpen` is true the admissible
// interval is [min, dim); otherwise it is [min, dim] (implemented by
// bumping the exclusive upper bound by one). Extra shape dimensions beyond
// the attribute's length are not checked (zip_first).
template <typename OpType>
static LogicalResult
isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
                                  ArrayRef<int64_t> shape, StringRef attrName,
                                  bool halfOpen = true, int64_t min = 0) {
  for (auto [index, attrDimPair] :
       llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
    int64_t val =
        std::get<0>(attrDimPair).template cast<IntegerAttr>().getInt();
    int64_t max = std::get<1>(attrDimPair);
    if (!halfOpen)
      max += 1;
    if (val < min || val >= max)
      return op.emitOpError("expected ")
             << attrName << " dimension " << index << " to be confined to ["
             << min << ", " << max << ")";
  }
  return success();
}
2411 | ||||
2412 | // Returns true if all integers in `arrayAttr` are in the interval [min, max}. | |||
2413 | // interval. If `halfOpen` is true then the admissible interval is [min, max). | |||
2414 | // Otherwise, the admissible interval is [min, max]. | |||
2415 | template <typename OpType> | |||
2416 | static LogicalResult isSumOfIntegerArrayAttrConfinedToShape( | |||
2417 | OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2, | |||
2418 | ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2, | |||
2419 | bool halfOpen = true, int64_t min = 1) { | |||
2420 | assert(arrayAttr1.size() <= shape.size())(static_cast <bool> (arrayAttr1.size() <= shape.size ()) ? void (0) : __assert_fail ("arrayAttr1.size() <= shape.size()" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2420, __extension__ __PRETTY_FUNCTION__)); | |||
2421 | assert(arrayAttr2.size() <= shape.size())(static_cast <bool> (arrayAttr2.size() <= shape.size ()) ? void (0) : __assert_fail ("arrayAttr2.size() <= shape.size()" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2421, __extension__ __PRETTY_FUNCTION__)); | |||
2422 | for (auto [index, it] : | |||
2423 | llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) { | |||
2424 | auto val1 = std::get<0>(it).template cast<IntegerAttr>().getInt(); | |||
2425 | auto val2 = std::get<1>(it).template cast<IntegerAttr>().getInt(); | |||
2426 | int64_t max = std::get<2>(it); | |||
2427 | if (!halfOpen) | |||
2428 | max += 1; | |||
2429 | if (val1 + val2 < 0 || val1 + val2 >= max) | |||
2430 | return op.emitOpError("expected sum(") | |||
2431 | << attrName1 << ", " << attrName2 << ") dimension " << index | |||
2432 | << " to be confined to [" << min << ", " << max << ")"; | |||
2433 | } | |||
2434 | return success(); | |||
2435 | } | |||
2436 | ||||
2437 | static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values, | |||
2438 | MLIRContext *context) { | |||
2439 | auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute { | |||
2440 | return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v)); | |||
2441 | }); | |||
2442 | return ArrayAttr::get(context, llvm::to_vector<8>(attrs)); | |||
2443 | } | |||
2444 | ||||
LogicalResult InsertStridedSliceOp::verify() {
  // Structural checks: offsets are per destination dimension, strides are
  // per source dimension, and the source may not out-rank the destination.
  auto sourceVectorType = getSourceVectorType();
  auto destVectorType = getDestVectorType();
  auto offsets = getOffsetsAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
    return emitOpError(
        "expected offsets of same size as destination vector rank");
  if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
    return emitOpError("expected strides of same size as source vector rank");
  if (sourceVectorType.getRank() > destVectorType.getRank())
    return emitOpError(
        "expected source rank to be smaller than destination rank");

  // Pad the source shape with leading zeros up to the destination rank so
  // offset + sourceDim can be checked against each destination dimension.
  auto sourceShape = sourceVectorType.getShape();
  auto destShape = destVectorType.getShape();
  SmallVector<int64_t, 4> sourceShapeAsDestShape(
      destShape.size() - sourceShape.size(), 0);
  sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
  auto offName = InsertStridedSliceOp::getOffsetsAttrName();
  auto stridesName = InsertStridedSliceOp::getStridesAttrName();
  // Value checks: offsets lie inside the destination shape, strides are
  // exactly 1 (the only supported value), and each offset plus the source
  // extent stays inside the destination (closed interval).
  if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
                                               offName)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
                                               stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(
          *this, offsets,
          makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
          offName, "source vector shape",
          /*halfOpen=*/false, /*min=*/1)))
    return failure();

  return success();
}
2480 | ||||
2481 | namespace { | |||
2482 | /// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type, | |||
2483 | /// SplatOp(X):dst_type) to SplatOp(X):dst_type. | |||
2484 | class FoldInsertStridedSliceSplat final | |||
2485 | : public OpRewritePattern<InsertStridedSliceOp> { | |||
2486 | public: | |||
2487 | using OpRewritePattern::OpRewritePattern; | |||
2488 | ||||
2489 | LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, | |||
2490 | PatternRewriter &rewriter) const override { | |||
2491 | auto srcSplatOp = | |||
2492 | insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>(); | |||
2493 | auto destSplatOp = | |||
2494 | insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>(); | |||
2495 | ||||
2496 | if (!srcSplatOp || !destSplatOp) | |||
2497 | return failure(); | |||
2498 | ||||
2499 | if (srcSplatOp.getInput() != destSplatOp.getInput()) | |||
2500 | return failure(); | |||
2501 | ||||
2502 | rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); | |||
2503 | return success(); | |||
2504 | } | |||
2505 | }; | |||
2506 | ||||
2507 | /// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst) | |||
2508 | /// to dst. | |||
2509 | class FoldInsertStridedSliceOfExtract final | |||
2510 | : public OpRewritePattern<InsertStridedSliceOp> { | |||
2511 | public: | |||
2512 | using OpRewritePattern::OpRewritePattern; | |||
2513 | ||||
2514 | LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp, | |||
2515 | PatternRewriter &rewriter) const override { | |||
2516 | auto extractStridedSliceOp = | |||
2517 | insertStridedSliceOp.getSource() | |||
2518 | .getDefiningOp<vector::ExtractStridedSliceOp>(); | |||
2519 | ||||
2520 | if (!extractStridedSliceOp) | |||
2521 | return failure(); | |||
2522 | ||||
2523 | if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest()) | |||
2524 | return failure(); | |||
2525 | ||||
2526 | // Check if have the same strides and offsets. | |||
2527 | if (extractStridedSliceOp.getStrides() != | |||
2528 | insertStridedSliceOp.getStrides() || | |||
2529 | extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets()) | |||
2530 | return failure(); | |||
2531 | ||||
2532 | rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest()); | |||
2533 | return success(); | |||
2534 | } | |||
2535 | }; | |||
2536 | ||||
// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
// ConstantOp.
class InsertStridedSliceConstantFolder final
    : public OpRewritePattern<InsertStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Do not create constants with more than `vectorSizeFoldThreshold` elements,
  // unless the source vector constant has a single use.
  static constexpr int64_t vectorSizeFoldThreshold = 256;

  // Folds the insertion of a constant slice into a constant destination by
  // materializing a new dense constant with the slice elements copied in.
  LogicalResult matchAndRewrite(InsertStridedSliceOp op,
                                PatternRewriter &rewriter) const override {
    // Return if 'InsertOp' operand is not defined by a compatible vector
    // ConstantOp.
    TypedValue<VectorType> destVector = op.getDest();
    Attribute vectorDestCst;
    if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
      return failure();

    // Scalable vectors have no static element count to enumerate over.
    VectorType destTy = destVector.getType();
    if (destTy.isScalable())
      return failure();

    // Make sure we do not create too many large constants.
    if (destTy.getNumElements() > vectorSizeFoldThreshold &&
        !destVector.hasOneUse())
      return failure();

    auto denseDest = vectorDestCst.cast<DenseElementsAttr>();

    TypedValue<VectorType> sourceValue = op.getSource();
    Attribute sourceCst;
    if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (op.hasNonUnitStrides())
      return failure();

    VectorType sliceVecTy = sourceValue.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
    SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
    SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());

    // Calculate the destination element indices by enumerating all slice
    // positions within the destination and linearizing them. The enumeration
    // order is lexicographic which yields a sequence of monotonically
    // increasing linearized position indices.
    // Because the destination may have higher dimensionality than the slice,
    // we keep track of two overlapping sets of positions and offsets.
    auto denseSlice = sourceCst.cast<DenseElementsAttr>();
    auto sliceValuesIt = denseSlice.value_begin<Attribute>();
    auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
    // `currSlicePosition` aliases the trailing dims of `currDestPosition`, so
    // advancing the slice position advances the destination position too.
    SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
    MutableArrayRef<int64_t> currSlicePosition(
        currDestPosition.begin() + rankDifference, currDestPosition.end());
    ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
                                   offsets.end());
    do {
      int64_t linearizedPosition = linearize(currDestPosition, destStrides);
      assert(linearizedPosition < destTy.getNumElements() && "Invalid index")(static_cast <bool> (linearizedPosition < destTy.getNumElements () && "Invalid index") ? void (0) : __assert_fail ("linearizedPosition < destTy.getNumElements() && \"Invalid index\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2599, __extension__ __PRETTY_FUNCTION__));
      assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&(static_cast <bool> (sliceValuesIt != denseSlice.value_end <Attribute>() && "Invalid slice element") ? void (0) : __assert_fail ("sliceValuesIt != denseSlice.value_end<Attribute>() && \"Invalid slice element\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2601, __extension__ __PRETTY_FUNCTION__))
             "Invalid slice element")(static_cast <bool> (sliceValuesIt != denseSlice.value_end <Attribute>() && "Invalid slice element") ? void (0) : __assert_fail ("sliceValuesIt != denseSlice.value_end<Attribute>() && \"Invalid slice element\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2601, __extension__ __PRETTY_FUNCTION__));
      newValues[linearizedPosition] = *sliceValuesIt;
      ++sliceValuesIt;
    } while (succeeded(
        incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));

    auto newAttr = DenseElementsAttr::get(destTy, newValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
    return success();
  }
};
2612 | ||||
2613 | } // namespace | |||
2614 | ||||
/// Registers canonicalization patterns that fold insert_strided_slice into
/// its destination or into a new constant.
void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
              InsertStridedSliceConstantFolder>(context);
}
2620 | ||||
2621 | OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) { | |||
2622 | if (getSourceVectorType() == getDestVectorType()) | |||
2623 | return getSource(); | |||
2624 | return {}; | |||
2625 | } | |||
2626 | ||||
2627 | //===----------------------------------------------------------------------===// | |||
2628 | // OuterProductOp | |||
2629 | //===----------------------------------------------------------------------===// | |||
2630 | ||||
/// Build an op without mask, use the type of `acc` as the return type.
void OuterProductOp::build(OpBuilder &builder, OperationState &result,
                           Value lhs, Value rhs, Value acc) {
  // `acc` is both an operand and the template for the result type.
  result.addOperands({lhs, rhs, acc});
  result.addTypes(acc.getType());
}
2637 | ||||
2638 | void OuterProductOp::print(OpAsmPrinter &p) { | |||
2639 | p << " " << getLhs() << ", " << getRhs(); | |||
2640 | if (!getAcc().empty()) { | |||
2641 | p << ", " << getAcc(); | |||
2642 | p.printOptionalAttrDict((*this)->getAttrs()); | |||
2643 | } | |||
2644 | p << " : " << getLhs().getType() << ", " << getRhs().getType(); | |||
2645 | } | |||
2646 | ||||
/// Parses the custom form: `$lhs, $rhs[, $acc] attr-dict : type($lhs),
/// type($rhs)`. A vector RHS type yields a 2-d outer-product result type;
/// a scalar RHS type yields a 1-d AXPY result type.
ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
  Type tLHS, tRHS;
  if (parser.parseOperandList(operandsInfo) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(tLHS) || parser.parseComma() ||
      parser.parseType(tRHS))
    return failure();
  if (operandsInfo.size() < 2)
    return parser.emitError(parser.getNameLoc(),
                            "expected at least 2 operands");
  VectorType vLHS = tLHS.dyn_cast<VectorType>();
  VectorType vRHS = tRHS.dyn_cast<VectorType>();
  if (!vLHS)
    return parser.emitError(parser.getNameLoc(),
                            "expected vector type for operand #1");
  // Outer product: dims from both operands; AXPY (scalar rhs): lhs dim only.
  VectorType resType =
      vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
                             vLHS.getElementType())
           : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());

  // Install the default combining kind when the attr dict did not supply one.
  if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
    result.attributes.append(
        OuterProductOp::getKindAttrStrName(),
        CombiningKindAttr::get(result.getContext(),
                               OuterProductOp::getDefaultKind()));
  }

  // The optional third operand (accumulator) has the same type as the result.
  return failure(
      parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
      parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
      (operandsInfo.size() > 2 &&
       parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
      parser.addTypeToList(resType, result.types));
}
2682 | ||||
2683 | LogicalResult OuterProductOp::verify() { | |||
2684 | Type tRHS = getOperandTypeRHS(); | |||
2685 | VectorType vLHS = getOperandVectorTypeLHS(), | |||
2686 | vRHS = tRHS.dyn_cast<VectorType>(), | |||
2687 | vACC = getOperandVectorTypeACC(), vRES = getVectorType(); | |||
2688 | ||||
2689 | if (vLHS.getRank() != 1) | |||
2690 | return emitOpError("expected 1-d vector for operand #1"); | |||
2691 | ||||
2692 | if (vRHS) { | |||
2693 | // Proper OUTER operation. | |||
2694 | if (vRHS.getRank() != 1) | |||
2695 | return emitOpError("expected 1-d vector for operand #2"); | |||
2696 | if (vRES.getRank() != 2) | |||
2697 | return emitOpError("expected 2-d vector result"); | |||
2698 | if (vLHS.getDimSize(0) != vRES.getDimSize(0)) | |||
2699 | return emitOpError("expected #1 operand dim to match result dim #1"); | |||
2700 | if (vRHS.getDimSize(0) != vRES.getDimSize(1)) | |||
2701 | return emitOpError("expected #2 operand dim to match result dim #2"); | |||
2702 | } else { | |||
2703 | // An AXPY operation. | |||
2704 | if (vRES.getRank() != 1) | |||
2705 | return emitOpError("expected 1-d vector result"); | |||
2706 | if (vLHS.getDimSize(0) != vRES.getDimSize(0)) | |||
2707 | return emitOpError("expected #1 operand dim to match result dim #1"); | |||
2708 | } | |||
2709 | ||||
2710 | if (vACC && vACC != vRES) | |||
2711 | return emitOpError("expected operand #3 of same type as result type"); | |||
2712 | ||||
2713 | // Verify supported combining kind. | |||
2714 | if (!isSupportedCombiningKind(getKind(), vRES.getElementType())) | |||
2715 | return emitOpError("unsupported outerproduct type"); | |||
2716 | ||||
2717 | return success(); | |||
2718 | } | |||
2719 | ||||
2720 | //===----------------------------------------------------------------------===// | |||
2721 | // ReshapeOp | |||
2722 | //===----------------------------------------------------------------------===// | |||
2723 | ||||
2724 | LogicalResult ReshapeOp::verify() { | |||
2725 | // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank. | |||
2726 | auto inputVectorType = getInputVectorType(); | |||
2727 | auto outputVectorType = getOutputVectorType(); | |||
2728 | int64_t inputShapeRank = getNumInputShapeSizes(); | |||
2729 | int64_t outputShapeRank = getNumOutputShapeSizes(); | |||
2730 | SmallVector<int64_t, 4> fixedVectorSizes; | |||
2731 | getFixedVectorSizes(fixedVectorSizes); | |||
2732 | int64_t numFixedVectorSizes = fixedVectorSizes.size(); | |||
2733 | ||||
2734 | if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes) | |||
2735 | return emitError("invalid input shape for vector type ") << inputVectorType; | |||
2736 | ||||
2737 | if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes) | |||
2738 | return emitError("invalid output shape for vector type ") | |||
2739 | << outputVectorType; | |||
2740 | ||||
2741 | // Verify that the 'fixedVectorSizes' match an input/output vector shape | |||
2742 | // suffix. | |||
2743 | unsigned inputVectorRank = inputVectorType.getRank(); | |||
2744 | for (unsigned i = 0; i < numFixedVectorSizes; ++i) { | |||
2745 | unsigned index = inputVectorRank - numFixedVectorSizes - i; | |||
2746 | if (fixedVectorSizes[i] != inputVectorType.getShape()[index]) | |||
2747 | return emitError("fixed vector size must match input vector for dim ") | |||
2748 | << i; | |||
2749 | } | |||
2750 | ||||
2751 | unsigned outputVectorRank = outputVectorType.getRank(); | |||
2752 | for (unsigned i = 0; i < numFixedVectorSizes; ++i) { | |||
2753 | unsigned index = outputVectorRank - numFixedVectorSizes - i; | |||
2754 | if (fixedVectorSizes[i] != outputVectorType.getShape()[index]) | |||
2755 | return emitError("fixed vector size must match output vector for dim ") | |||
2756 | << i; | |||
2757 | } | |||
2758 | ||||
2759 | // If all shape operands are produced by constant ops, verify that product | |||
2760 | // of dimensions for input/output shape match. | |||
2761 | auto isDefByConstant = [](Value operand) { | |||
2762 | return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp()); | |||
2763 | }; | |||
2764 | if (llvm::all_of(getInputShape(), isDefByConstant) && | |||
2765 | llvm::all_of(getOutputShape(), isDefByConstant)) { | |||
2766 | int64_t numInputElements = 1; | |||
2767 | for (auto operand : getInputShape()) | |||
2768 | numInputElements *= | |||
2769 | cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); | |||
2770 | int64_t numOutputElements = 1; | |||
2771 | for (auto operand : getOutputShape()) | |||
2772 | numOutputElements *= | |||
2773 | cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value(); | |||
2774 | if (numInputElements != numOutputElements) | |||
2775 | return emitError("product of input and output shape sizes must match"); | |||
2776 | } | |||
2777 | return success(); | |||
2778 | } | |||
2779 | ||||
/// Populates `results` with the trailing fixed vector sizes stored in the
/// op's I64 array attribute.
void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getFixedVectorSizes(), results);
}
2783 | ||||
2784 | //===----------------------------------------------------------------------===// | |||
2785 | // ExtractStridedSliceOp | |||
2786 | //===----------------------------------------------------------------------===// | |||
2787 | ||||
2788 | // Inference works as follows: | |||
2789 | // 1. Add 'sizes' from prefix of dims in 'offsets'. | |||
2790 | // 2. Add sizes from 'vectorType' for remaining dims. | |||
2791 | static Type inferStridedSliceOpResultType(VectorType vectorType, | |||
2792 | ArrayAttr offsets, ArrayAttr sizes, | |||
2793 | ArrayAttr strides) { | |||
2794 | assert(offsets.size() == sizes.size() && offsets.size() == strides.size())(static_cast <bool> (offsets.size() == sizes.size() && offsets.size() == strides.size()) ? void (0) : __assert_fail ("offsets.size() == sizes.size() && offsets.size() == strides.size()" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 2794, __extension__ __PRETTY_FUNCTION__)); | |||
2795 | SmallVector<int64_t, 4> shape; | |||
2796 | shape.reserve(vectorType.getRank()); | |||
2797 | unsigned idx = 0; | |||
2798 | for (unsigned e = offsets.size(); idx < e; ++idx) | |||
2799 | shape.push_back(sizes[idx].cast<IntegerAttr>().getInt()); | |||
2800 | for (unsigned e = vectorType.getShape().size(); idx < e; ++idx) | |||
2801 | shape.push_back(vectorType.getShape()[idx]); | |||
2802 | ||||
2803 | return VectorType::get(shape, vectorType.getElementType()); | |||
2804 | } | |||
2805 | ||||
2806 | void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result, | |||
2807 | Value source, ArrayRef<int64_t> offsets, | |||
2808 | ArrayRef<int64_t> sizes, | |||
2809 | ArrayRef<int64_t> strides) { | |||
2810 | result.addOperands(source); | |||
2811 | auto offsetsAttr = getVectorSubscriptAttr(builder, offsets); | |||
2812 | auto sizesAttr = getVectorSubscriptAttr(builder, sizes); | |||
2813 | auto stridesAttr = getVectorSubscriptAttr(builder, strides); | |||
2814 | result.addTypes( | |||
2815 | inferStridedSliceOpResultType(source.getType().cast<VectorType>(), | |||
2816 | offsetsAttr, sizesAttr, stridesAttr)); | |||
2817 | result.addAttribute(getOffsetsAttrStrName(), offsetsAttr); | |||
2818 | result.addAttribute(getSizesAttrStrName(), sizesAttr); | |||
2819 | result.addAttribute(getStridesAttrStrName(), stridesAttr); | |||
2820 | } | |||
2821 | ||||
/// Verifies that offsets/sizes/strides agree in length, fit within the
/// source vector shape, and that the declared result type matches the
/// inferred slice type.
LogicalResult ExtractStridedSliceOp::verify() {
  auto type = getVectorType();
  auto offsets = getOffsetsAttr();
  auto sizes = getSizesAttr();
  auto strides = getStridesAttr();
  if (offsets.size() != sizes.size() || offsets.size() != strides.size())
    return emitOpError(
        "expected offsets, sizes and strides attributes of same size");

  auto shape = type.getShape();
  auto offName = getOffsetsAttrName();
  auto sizesName = getSizesAttrName();
  auto stridesName = getStridesAttrName();
  // The attribute arrays may be shorter than the rank (a leading-dims
  // slice); each entry must stay within its dimension, sizes must be >= 1,
  // strides must be exactly 1, and offset + size must fit in the dimension.
  if (failed(
          isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
      failed(
          isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
      failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
                                                stridesName)) ||
      failed(
          isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
      failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
                                               /*halfOpen=*/false,
                                               /*min=*/1)) ||
      failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
                                               stridesName,
                                               /*halfOpen=*/false)) ||
      failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
                                                    shape, offName, sizesName,
                                                    /*halfOpen=*/false)))
    return failure();

  // The declared result type must be exactly the inferred slice type.
  auto resultType =
      inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
  if (getResult().getType() != resultType)
    return emitOpError("expected result type to be ") << resultType;

  return success();
}
2861 | ||||
2862 | // When the source of ExtractStrided comes from a chain of InsertStrided ops try | |||
2863 | // to use the source of the InsertStrided ops if we can detect that the | |||
2864 | // extracted vector is a subset of one of the vector inserted. | |||
2865 | static LogicalResult | |||
2866 | foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) { | |||
2867 | // Helper to extract integer out of ArrayAttr. | |||
2868 | auto getElement = [](ArrayAttr array, int idx) { | |||
2869 | return array[idx].cast<IntegerAttr>().getInt(); | |||
2870 | }; | |||
2871 | ArrayAttr extractOffsets = op.getOffsets(); | |||
2872 | ArrayAttr extractStrides = op.getStrides(); | |||
2873 | ArrayAttr extractSizes = op.getSizes(); | |||
2874 | auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>(); | |||
2875 | while (insertOp) { | |||
2876 | if (op.getVectorType().getRank() != | |||
2877 | insertOp.getSourceVectorType().getRank()) | |||
2878 | return failure(); | |||
2879 | ArrayAttr insertOffsets = insertOp.getOffsets(); | |||
2880 | ArrayAttr insertStrides = insertOp.getStrides(); | |||
2881 | // If the rank of extract is greater than the rank of insert, we are likely | |||
2882 | // extracting a partial chunk of the vector inserted. | |||
2883 | if (extractOffsets.size() > insertOffsets.size()) | |||
2884 | return failure(); | |||
2885 | bool patialoverlap = false; | |||
2886 | bool disjoint = false; | |||
2887 | SmallVector<int64_t, 4> offsetDiffs; | |||
2888 | for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) { | |||
2889 | if (getElement(extractStrides, dim) != getElement(insertStrides, dim)) | |||
2890 | return failure(); | |||
2891 | int64_t start = getElement(insertOffsets, dim); | |||
2892 | int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim); | |||
2893 | int64_t offset = getElement(extractOffsets, dim); | |||
2894 | int64_t size = getElement(extractSizes, dim); | |||
2895 | // Check if the start of the extract offset is in the interval inserted. | |||
2896 | if (start <= offset && offset < end) { | |||
2897 | // If the extract interval overlaps but is not fully included we may | |||
2898 | // have a partial overlap that will prevent any folding. | |||
2899 | if (offset + size > end) | |||
2900 | patialoverlap = true; | |||
2901 | offsetDiffs.push_back(offset - start); | |||
2902 | continue; | |||
2903 | } | |||
2904 | disjoint = true; | |||
2905 | break; | |||
2906 | } | |||
2907 | // The extract element chunk is a subset of the insert element. | |||
2908 | if (!disjoint && !patialoverlap) { | |||
2909 | op.setOperand(insertOp.getSource()); | |||
2910 | // OpBuilder is only used as a helper to build an I64ArrayAttr. | |||
2911 | OpBuilder b(op.getContext()); | |||
2912 | op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(), | |||
2913 | b.getI64ArrayAttr(offsetDiffs)); | |||
2914 | return success(); | |||
2915 | } | |||
2916 | // If the chunk extracted is disjoint from the chunk inserted, keep looking | |||
2917 | // in the insert chain. | |||
2918 | if (disjoint) | |||
2919 | insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>(); | |||
2920 | else { | |||
2921 | // The extracted vector partially overlap the inserted vector, we cannot | |||
2922 | // fold. | |||
2923 | return failure(); | |||
2924 | } | |||
2925 | } | |||
2926 | return failure(); | |||
2927 | } | |||
2928 | ||||
2929 | OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) { | |||
2930 | if (getVectorType() == getResult().getType()) | |||
2931 | return getVector(); | |||
2932 | if (succeeded(foldExtractStridedOpFromInsertChain(*this))) | |||
2933 | return getResult(); | |||
2934 | return {}; | |||
2935 | } | |||
2936 | ||||
/// Populates `results` with the op's 'offsets' I64 array attribute values.
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getOffsets(), results);
}
2940 | ||||
2941 | namespace { | |||
2942 | ||||
// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
// ConstantMaskOp.
class StridedSliceConstantMaskFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Folds the slice of a constant mask into a smaller constant mask by
  // intersecting the slice window with the masked region, per dimension.
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'extractStridedSliceOp' operand is not defined by a
    // ConstantMaskOp.
    auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
    auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
    if (!constantMaskOp)
      return failure();
    // Return if 'extractStridedSliceOp' has non-unit strides.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();
    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
    // Gather strided slice offsets and sizes.
    SmallVector<int64_t, 4> sliceOffsets;
    populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
                               sliceOffsets);
    SmallVector<int64_t, 4> sliceSizes;
    populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);

    // Compute slice of vector mask region.
    // Per dimension, the new mask extent is the size of the overlap of
    // [sliceOffset, sliceOffset + sliceSize) with [0, maskDimSize), clamped
    // at zero when the slice starts past the masked region.
    SmallVector<int64_t, 4> sliceMaskDimSizes;
    sliceMaskDimSizes.reserve(maskDimSizes.size());
    for (auto [maskDimSize, sliceOffset, sliceSize] :
         llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
      int64_t sliceMaskDimSize = std::max(
          static_cast<int64_t>(0),
          std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
      sliceMaskDimSizes.push_back(sliceMaskDimSize);
    }
    // Add unchanged dimensions (the slice attributes may cover only a
    // prefix of the mask dimensions).
    if (sliceMaskDimSizes.size() < maskDimSizes.size())
      for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
        sliceMaskDimSizes.push_back(maskDimSizes[i]);
    // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
    // region is a conjunction of mask dim intervals).
    if (llvm::is_contained(sliceMaskDimSizes, 0))
      sliceMaskDimSizes.assign(maskDimSizes.size(), 0);

    // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
    // region.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
    return success();
  }
};
2998 | ||||
2999 | // Pattern to rewrite a ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp. | |||
3000 | class StridedSliceSplatConstantFolder final | |||
3001 | : public OpRewritePattern<ExtractStridedSliceOp> { | |||
3002 | public: | |||
3003 | using OpRewritePattern::OpRewritePattern; | |||
3004 | ||||
3005 | LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp, | |||
3006 | PatternRewriter &rewriter) const override { | |||
3007 | // Return if 'ExtractStridedSliceOp' operand is not defined by a splat | |||
3008 | // ConstantOp. | |||
3009 | Value sourceVector = extractStridedSliceOp.getVector(); | |||
3010 | Attribute vectorCst; | |||
3011 | if (!matchPattern(sourceVector, m_Constant(&vectorCst))) | |||
3012 | return failure(); | |||
3013 | ||||
3014 | auto splat = vectorCst.dyn_cast<SplatElementsAttr>(); | |||
3015 | if (!splat) | |||
3016 | return failure(); | |||
3017 | ||||
3018 | auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(), | |||
3019 | splat.getSplatValue<Attribute>()); | |||
3020 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp, | |||
3021 | newAttr); | |||
3022 | return success(); | |||
3023 | } | |||
3024 | }; | |||
3025 | ||||
// Pattern to rewrite a ExtractStridedSliceOp(non-splat ConstantOp) ->
// ConstantOp.
class StridedSliceNonSplatConstantFolder final
    : public OpRewritePattern<ExtractStridedSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Materializes the sliced elements of a non-splat dense constant as a new
  // dense constant of the slice type.
  LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
                                PatternRewriter &rewriter) const override {
    // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
    // ConstantOp.
    Value sourceVector = extractStridedSliceOp.getVector();
    Attribute vectorCst;
    if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
      return failure();

    // The splat case is handled by `StridedSliceSplatConstantFolder`.
    auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
    if (!dense || dense.isSplat())
      return failure();

    // TODO: Handle non-unit strides when they become available.
    if (extractStridedSliceOp.hasNonUnitStrides())
      return failure();

    auto sourceVecTy = sourceVector.getType().cast<VectorType>();
    ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
    SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);

    VectorType sliceVecTy = extractStridedSliceOp.getType();
    ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
    int64_t sliceRank = sliceVecTy.getRank();

    // Expand offsets and sizes to match the vector rank.
    // The attribute arrays may cover only the leading dims; trailing dims
    // default to offset 0 and the full source extent.
    SmallVector<int64_t, 4> offsets(sliceRank, 0);
    copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());

    SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
    copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());

    // Calculate the slice elements by enumerating all slice positions and
    // linearizing them. The enumeration order is lexicographic which yields a
    // sequence of monotonically increasing linearized position indices.
    auto denseValuesBegin = dense.value_begin<Attribute>();
    SmallVector<Attribute> sliceValues;
    sliceValues.reserve(sliceVecTy.getNumElements());
    SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
    do {
      int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
      assert(linearizedPosition < sourceVecTy.getNumElements() &&(static_cast <bool> (linearizedPosition < sourceVecTy .getNumElements() && "Invalid index") ? void (0) : __assert_fail ("linearizedPosition < sourceVecTy.getNumElements() && \"Invalid index\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3076, __extension__ __PRETTY_FUNCTION__))
             "Invalid index")(static_cast <bool> (linearizedPosition < sourceVecTy .getNumElements() && "Invalid index") ? void (0) : __assert_fail ("linearizedPosition < sourceVecTy.getNumElements() && \"Invalid index\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3076, __extension__ __PRETTY_FUNCTION__));
      sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
    } while (
        succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));

    assert(static_cast<int64_t>(sliceValues.size()) ==(static_cast <bool> (static_cast<int64_t>(sliceValues .size()) == sliceVecTy.getNumElements() && "Invalid number of slice elements" ) ? void (0) : __assert_fail ("static_cast<int64_t>(sliceValues.size()) == sliceVecTy.getNumElements() && \"Invalid number of slice elements\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3083, __extension__ __PRETTY_FUNCTION__))
           sliceVecTy.getNumElements() &&(static_cast <bool> (static_cast<int64_t>(sliceValues .size()) == sliceVecTy.getNumElements() && "Invalid number of slice elements" ) ? void (0) : __assert_fail ("static_cast<int64_t>(sliceValues.size()) == sliceVecTy.getNumElements() && \"Invalid number of slice elements\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3083, __extension__ __PRETTY_FUNCTION__))
           "Invalid number of slice elements")(static_cast <bool> (static_cast<int64_t>(sliceValues .size()) == sliceVecTy.getNumElements() && "Invalid number of slice elements" ) ? void (0) : __assert_fail ("static_cast<int64_t>(sliceValues.size()) == sliceVecTy.getNumElements() && \"Invalid number of slice elements\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3083, __extension__ __PRETTY_FUNCTION__));
    auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
                                                   newAttr);
    return success();
  }
};
3090 | ||||
3091 | // Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to | |||
3092 | // BroadcastOp(ExtractStrideSliceOp). | |||
3093 | class StridedSliceBroadcast final | |||
3094 | : public OpRewritePattern<ExtractStridedSliceOp> { | |||
3095 | public: | |||
3096 | using OpRewritePattern::OpRewritePattern; | |||
3097 | ||||
3098 | LogicalResult matchAndRewrite(ExtractStridedSliceOp op, | |||
3099 | PatternRewriter &rewriter) const override { | |||
3100 | auto broadcast = op.getVector().getDefiningOp<BroadcastOp>(); | |||
3101 | if (!broadcast) | |||
3102 | return failure(); | |||
3103 | auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>(); | |||
3104 | unsigned srcRank = srcVecType ? srcVecType.getRank() : 0; | |||
3105 | auto dstVecType = op.getType().cast<VectorType>(); | |||
3106 | unsigned dstRank = dstVecType.getRank(); | |||
3107 | unsigned rankDiff = dstRank - srcRank; | |||
3108 | // Check if the most inner dimensions of the source of the broadcast are the | |||
3109 | // same as the destination of the extract. If this is the case we can just | |||
3110 | // use a broadcast as the original dimensions are untouched. | |||
3111 | bool lowerDimMatch = true; | |||
3112 | for (unsigned i = 0; i < srcRank; i++) { | |||
3113 | if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) { | |||
3114 | lowerDimMatch = false; | |||
3115 | break; | |||
3116 | } | |||
3117 | } | |||
3118 | Value source = broadcast.getSource(); | |||
3119 | // If the inner dimensions don't match, it means we need to extract from the | |||
3120 | // source of the orignal broadcast and then broadcast the extracted value. | |||
3121 | // We also need to handle degenerated cases where the source is effectively | |||
3122 | // just a single scalar. | |||
3123 | bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1); | |||
3124 | if (!lowerDimMatch && !isScalarSrc) { | |||
3125 | source = rewriter.create<ExtractStridedSliceOp>( | |||
3126 | op->getLoc(), source, | |||
3127 | getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff), | |||
3128 | getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff), | |||
3129 | getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff)); | |||
3130 | } | |||
3131 | rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source); | |||
3132 | return success(); | |||
3133 | } | |||
3134 | }; | |||
3135 | ||||
3136 | /// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp. | |||
3137 | class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> { | |||
3138 | public: | |||
3139 | using OpRewritePattern::OpRewritePattern; | |||
3140 | ||||
3141 | LogicalResult matchAndRewrite(ExtractStridedSliceOp op, | |||
3142 | PatternRewriter &rewriter) const override { | |||
3143 | auto splat = op.getVector().getDefiningOp<SplatOp>(); | |||
3144 | if (!splat) | |||
3145 | return failure(); | |||
3146 | rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput()); | |||
3147 | return success(); | |||
3148 | } | |||
3149 | }; | |||
3150 | ||||
3151 | } // namespace | |||
3152 | ||||
3153 | void ExtractStridedSliceOp::getCanonicalizationPatterns( | |||
3154 | RewritePatternSet &results, MLIRContext *context) { | |||
3155 | // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) -> | |||
3156 | // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp. | |||
3157 | results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder, | |||
3158 | StridedSliceNonSplatConstantFolder, StridedSliceBroadcast, | |||
3159 | StridedSliceSplat>(context); | |||
3160 | } | |||
3161 | ||||
3162 | //===----------------------------------------------------------------------===// | |||
3163 | // TransferReadOp | |||
3164 | //===----------------------------------------------------------------------===// | |||
3165 | ||||
3166 | /// 1. Builder that sets padding to zero and an empty mask (variant with attrs). | |||
3167 | void TransferReadOp::build(OpBuilder &builder, OperationState &result, | |||
3168 | VectorType vectorType, Value source, | |||
3169 | ValueRange indices, AffineMapAttr permutationMapAttr, | |||
3170 | /*optional*/ ArrayAttr inBoundsAttr) { | |||
3171 | Type elemType = source.getType().cast<ShapedType>().getElementType(); | |||
3172 | Value padding = builder.create<arith::ConstantOp>( | |||
3173 | result.location, elemType, builder.getZeroAttr(elemType)); | |||
3174 | build(builder, result, vectorType, source, indices, permutationMapAttr, | |||
3175 | padding, /*mask=*/Value(), inBoundsAttr); | |||
3176 | } | |||
3177 | ||||
/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3179 | void TransferReadOp::build(OpBuilder &builder, OperationState &result, | |||
3180 | VectorType vectorType, Value source, | |||
3181 | ValueRange indices, AffineMap permutationMap, | |||
3182 | std::optional<ArrayRef<bool>> inBounds) { | |||
3183 | auto permutationMapAttr = AffineMapAttr::get(permutationMap); | |||
3184 | auto inBoundsAttr = (inBounds && !inBounds.value().empty()) | |||
3185 | ? builder.getBoolArrayAttr(inBounds.value()) | |||
3186 | : ArrayAttr(); | |||
3187 | build(builder, result, vectorType, source, indices, permutationMapAttr, | |||
3188 | inBoundsAttr); | |||
3189 | } | |||
3190 | ||||
3191 | /// 3. Builder that sets permutation map to 'getMinorIdentityMap'. | |||
3192 | void TransferReadOp::build(OpBuilder &builder, OperationState &result, | |||
3193 | VectorType vectorType, Value source, | |||
3194 | ValueRange indices, Value padding, | |||
3195 | std::optional<ArrayRef<bool>> inBounds) { | |||
3196 | AffineMap permutationMap = getTransferMinorIdentityMap( | |||
3197 | source.getType().cast<ShapedType>(), vectorType); | |||
3198 | auto permutationMapAttr = AffineMapAttr::get(permutationMap); | |||
3199 | auto inBoundsAttr = (inBounds && !inBounds.value().empty()) | |||
3200 | ? builder.getBoolArrayAttr(inBounds.value()) | |||
3201 | : ArrayAttr(); | |||
3202 | build(builder, result, vectorType, source, indices, permutationMapAttr, | |||
3203 | padding, | |||
3204 | /*mask=*/Value(), inBoundsAttr); | |||
3205 | } | |||
3206 | ||||
3207 | /// 4. Builder that sets padding to zero and permutation map to | |||
3208 | /// 'getMinorIdentityMap'. | |||
3209 | void TransferReadOp::build(OpBuilder &builder, OperationState &result, | |||
3210 | VectorType vectorType, Value source, | |||
3211 | ValueRange indices, | |||
3212 | std::optional<ArrayRef<bool>> inBounds) { | |||
3213 | Type elemType = source.getType().cast<ShapedType>().getElementType(); | |||
3214 | Value padding = builder.create<arith::ConstantOp>( | |||
3215 | result.location, elemType, builder.getZeroAttr(elemType)); | |||
3216 | build(builder, result, vectorType, source, indices, padding, inBounds); | |||
3217 | } | |||
3218 | ||||
3219 | template <typename EmitFun> | |||
3220 | static LogicalResult verifyPermutationMap(AffineMap permutationMap, | |||
3221 | EmitFun emitOpError) { | |||
3222 | SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false); | |||
3223 | for (auto expr : permutationMap.getResults()) { | |||
3224 | auto dim = expr.dyn_cast<AffineDimExpr>(); | |||
3225 | auto zero = expr.dyn_cast<AffineConstantExpr>(); | |||
3226 | if (zero) { | |||
3227 | if (zero.getValue() != 0) { | |||
3228 | return emitOpError( | |||
3229 | "requires a projected permutation_map (at most one dim or the zero " | |||
3230 | "constant can appear in each result)"); | |||
3231 | } | |||
3232 | continue; | |||
3233 | } | |||
3234 | if (!dim) { | |||
3235 | return emitOpError("requires a projected permutation_map (at most one " | |||
3236 | "dim or the zero constant can appear in each result)"); | |||
3237 | } | |||
3238 | if (seen[dim.getPosition()]) { | |||
3239 | return emitOpError( | |||
3240 | "requires a permutation_map that is a permutation (found one dim " | |||
3241 | "used more than once)"); | |||
3242 | } | |||
3243 | seen[dim.getPosition()] = true; | |||
3244 | } | |||
3245 | return success(); | |||
3246 | } | |||
3247 | ||||
/// Verifies invariants shared by vector.transfer_read and
/// vector.transfer_write: source type, element-type bitwidth compatibility,
/// permutation-map shape, mask type, and the in_bounds attribute.
static LogicalResult
verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
                 VectorType vectorType, VectorType maskType,
                 VectorType inferredMaskType, AffineMap permutationMap,
                 ArrayAttr inBounds) {
  // Give a targeted error for the removed `masked` attribute.
  if (op->hasAttr("masked")) {
    return op->emitOpError("masked attribute has been removed. "
                           "Use in_bounds instead.");
  }

  if (!shapedType.isa<MemRefType, RankedTensorType>())
    return op->emitOpError(
        "requires source to be a memref or ranked tensor type");

  auto elementType = shapedType.getElementType();
  DataLayout dataLayout = DataLayout::closest(op);
  if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
    // Memref or tensor has vector element type.
    // The minor 1-D vector of the transfer must span a whole number of the
    // source's element vectors; compare sizes in bits via the data layout.
    unsigned sourceVecSize =
        dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
        vectorElementType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
        vectorType.getShape().back();
    if (resultVecSize % sourceVecSize != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the minor 1-D vector of the source");

    unsigned sourceVecEltRank = vectorElementType.getRank();
    unsigned resultVecRank = vectorType.getRank();
    if (sourceVecEltRank > resultVecRank)
      return op->emitOpError(
          "requires source vector element and vector result ranks to match.");
    // Only the leading (non-element) dims are indexed by the map.
    unsigned rankOffset = resultVecRank - sourceVecEltRank;
    // Check that permutation map results match 'rankOffset' of vector type.
    if (permutationMap.getNumResults() != rankOffset)
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");

    if (maskType)
      return op->emitOpError("does not support masks with vector element type");
  } else {
    // Memref or tensor has scalar element type.
    // Rank-0 transfers move a single scalar; treat the minor size as 1.
    unsigned minorSize =
        vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
    unsigned resultVecSize =
        dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
    if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
      return op->emitOpError(
          "requires the bitwidth of the minor 1-D vector to be an integral "
          "multiple of the bitwidth of the source element type");

    // Check that permutation map results match rank of vector type.
    if (permutationMap.getNumResults() != vectorType.getRank())
      return op->emitOpError("requires a permutation_map with result dims of "
                             "the same rank as the vector type");
  }

  if (permutationMap.getNumSymbols() != 0)
    return op->emitOpError("requires permutation_map without symbols");

  if (permutationMap.getNumInputs() != shapedType.getRank())
    return op->emitOpError("requires a permutation_map with input dims of the "
                           "same rank as the source type");

  // A mask, when present, must match the type inferred from the vector type
  // and the permutation map (`inferredMaskType` is precomputed by callers).
  if (maskType && maskType != inferredMaskType)
    return op->emitOpError("inferred mask type (")
           << inferredMaskType << ") and mask operand type (" << maskType
           << ") don't match";

  if (inBounds) {
    if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
      return op->emitOpError("expects the optional in_bounds attr of same rank "
                             "as permutation_map results: ")
             << AffineMapAttr::get(permutationMap)
             << " vs inBounds of size: " << inBounds.size();
    // Broadcast dims (constant-0 map results) index nothing in the source and
    // therefore must be flagged in-bounds.
    for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
      if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
          !inBounds.getValue()[i].cast<BoolAttr>().getValue())
        return op->emitOpError("requires broadcast dimensions to be in-bounds");
  }

  return success();
}
3333 | ||||
3334 | static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) { | |||
3335 | SmallVector<StringRef, 3> elidedAttrs; | |||
3336 | elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr()); | |||
3337 | if (op.permutation_map().isMinorIdentity()) | |||
3338 | elidedAttrs.push_back(op.getPermutationMapAttrStrName()); | |||
3339 | bool elideInBounds = true; | |||
3340 | if (auto inBounds = op.in_bounds()) { | |||
3341 | for (auto attr : *inBounds) { | |||
3342 | if (attr.template cast<BoolAttr>().getValue()) { | |||
3343 | elideInBounds = false; | |||
3344 | break; | |||
3345 | } | |||
3346 | } | |||
3347 | } | |||
3348 | if (elideInBounds) | |||
3349 | elidedAttrs.push_back(op.getInBoundsAttrStrName()); | |||
3350 | p.printOptionalAttrDict(op->getAttrs(), elidedAttrs); | |||
3351 | } | |||
3352 | ||||
3353 | void TransferReadOp::print(OpAsmPrinter &p) { | |||
3354 | p << " " << getSource() << "[" << getIndices() << "], " << getPadding(); | |||
3355 | if (getMask()) | |||
3356 | p << ", " << getMask(); | |||
3357 | printTransferAttrs(p, *this); | |||
3358 | p << " : " << getShapedType() << ", " << getVectorType(); | |||
3359 | } | |||
3360 | ||||
3361 | /// Infers the mask type for a transfer read given its vector type and | |||
3362 | /// permutation map. The mask in a transfer read operation applies to the | |||
3363 | /// tensor/buffer reading part of it and its type should match the shape read | |||
3364 | /// *before* any permutation or broadcasting. | |||
3365 | static VectorType inferTransferReadMaskType(VectorType vecType, | |||
3366 | AffineMap permMap) { | |||
3367 | auto i1Type = IntegerType::get(permMap.getContext(), 1); | |||
3368 | AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap)); | |||
3369 | assert(invPermMap && "Inversed permutation map couldn't be computed")(static_cast <bool> (invPermMap && "Inversed permutation map couldn't be computed" ) ? void (0) : __assert_fail ("invPermMap && \"Inversed permutation map couldn't be computed\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 3369, __extension__ __PRETTY_FUNCTION__)); | |||
3370 | SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape()); | |||
3371 | return VectorType::get(maskShape, i1Type); | |||
3372 | } | |||
3373 | ||||
/// Parses `source[indices], padding (, mask)? attr-dict : shaped-ty, vec-ty`.
ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  OpAsmParser::UnresolvedOperand paddingInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parsing with support for paddingValue.
  if (parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(paddingInfo))
    return failure();
  // An optional trailing `, %mask` operand distinguishes the masked form.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded()) {
    if (parser.parseOperand(maskInfo))
      return failure();
  }
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the shaped source and the result vector.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  auto shapedType = types[0].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  VectorType vectorType = types[1].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  // An elided permutation map defaults to the minor identity map and is
  // materialized on the op's attribute dictionary.
  auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
  Attribute permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = permMapAttr.cast<AffineMapAttr>().getValue();
  }
  // Resolve operands in op order: source, indices, padding.
  if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands) ||
      parser.resolveOperand(paddingInfo, shapedType.getElementType(),
                            result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    // Instead of adding the mask type as an op type, compute it based on the
    // vector type and the permutation map (to keep the type signature small).
    auto maskType = inferTransferReadMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Record operand segments: source, indices, padding, optional mask (0/1).
  result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, static_cast<int32_t>(indexInfo.size()), 1,
                           static_cast<int32_t>(hasMask.succeeded())}));
  return parser.addTypeToList(vectorType, result.types);
}
3434 | ||||
LogicalResult TransferReadOp::verify() {
  // Consistency of elemental types in source and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto paddingType = getPadding().getType();
  auto permutationMap = getPermutationMap();
  // Only infer a mask type when a mask operand is actually present.
  VectorType inferredMaskType =
      maskType ? inferTransferReadMaskType(vectorType, permutationMap)
               : VectorType();
  auto sourceElementType = shapedType.getElementType();

  // One index operand per source dimension.
  if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // Invariants shared with transfer_write.
  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap,
                              getInBounds() ? *getInBounds() : ArrayAttr())))
    return failure();

  if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
    // Source has vector element type.
    // Check that 'sourceVectorElementType' and 'paddingType' types match.
    if (sourceVectorElementType != paddingType)
      return emitOpError(
          "requires source element type and padding type to match.");

  } else {
    // Check that 'paddingType' is valid to store in a vector type.
    if (!VectorType::isValidElementType(paddingType))
      return emitOpError("requires valid padding vector elemental type");

    // Check that padding type and vector element types match.
    if (paddingType != sourceElementType)
      return emitOpError(
          "requires formal padding and source of the same elemental type");
  }

  // Finally validate the permutation map itself.
  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
3477 | ||||
3478 | // MaskableOpInterface methods. | |||
3479 | ||||
/// Returns the mask type expected by this operation. Mostly used for
/// verification purposes. It requires the operation to be vectorized.
3482 | Type TransferReadOp::getExpectedMaskType() { | |||
3483 | return inferTransferReadMaskType(getVectorType(), getPermutationMap()); | |||
3484 | } | |||
3485 | ||||
3486 | template <typename TransferOp> | |||
3487 | static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) { | |||
3488 | // TODO: support more aggressive createOrFold on: | |||
3489 | // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)` | |||
3490 | if (op.getShapedType().isDynamicDim(indicesIdx)) | |||
3491 | return false; | |||
3492 | Value index = op.getIndices()[indicesIdx]; | |||
3493 | auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>(); | |||
3494 | if (!cstOp) | |||
3495 | return false; | |||
3496 | ||||
3497 | int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx); | |||
3498 | int64_t vectorSize = op.getVectorType().getDimSize(resultIdx); | |||
3499 | ||||
3500 | return cstOp.value() + vectorSize <= sourceSize; | |||
3501 | } | |||
3502 | ||||
/// Folds the in_bounds attribute of a transfer op: any dimension whose
/// constant index plus vector size provably fits in the static source
/// dimension is promoted to in-bounds. Returns success iff the attribute was
/// tightened (out-of-bounds dims are never introduced).
template <typename TransferOp>
static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
  // TODO: support 0-d corner case.
  // TODO: Be less conservative.
  if (op.getTransferRank() == 0)
    return failure();
  AffineMap permutationMap = op.getPermutationMap();
  bool changed = false;
  SmallVector<bool, 4> newInBounds;
  newInBounds.reserve(op.getTransferRank());
  for (unsigned i = 0; i < op.getTransferRank(); ++i) {
    // Already marked as in-bounds, nothing to see here.
    if (op.isDimInBounds(i)) {
      newInBounds.push_back(true);
      continue;
    }
    // Currently out-of-bounds, check whether we can statically determine it is
    // inBounds.
    // The verifier requires broadcast dims to be in-bounds, so an
    // out-of-bounds dim must map to a real source dimension.
    auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
    assert(dimExpr && "Broadcast dims must be in-bounds");
    auto inBounds =
        isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
    newInBounds.push_back(inBounds);
    // We commit the pattern if it is "more inbounds".
    changed |= inBounds;
  }
  if (!changed)
    return failure();
  // OpBuilder is only used as a helper to build an I64ArrayAttr.
  OpBuilder b(op.getContext());
  op->setAttr(TransferOp::getInBoundsAttrStrName(),
              b.getBoolArrayAttr(newInBounds));
  return success();
}
3537 | ||||
3538 | /// ``` | |||
3539 | /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} | |||
3540 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
3541 | /// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]} | |||
3542 | /// : tensor<4x4xf32>, vector<1x4xf32> | |||
3543 | /// ``` | |||
3544 | /// -> Folds into | |||
3545 | /// ``` | |||
3546 | /// %v0 | |||
3547 | /// ``` | |||
3548 | static Value foldRAW(TransferReadOp readOp) { | |||
3549 | if (!readOp.getShapedType().isa<RankedTensorType>()) | |||
3550 | return {}; | |||
3551 | auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>(); | |||
3552 | while (defWrite) { | |||
3553 | if (checkSameValueRAW(defWrite, readOp)) | |||
3554 | return defWrite.getVector(); | |||
3555 | if (!isDisjointTransferIndices( | |||
3556 | cast<VectorTransferOpInterface>(defWrite.getOperation()), | |||
3557 | cast<VectorTransferOpInterface>(readOp.getOperation()))) | |||
3558 | break; | |||
3559 | defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>(); | |||
3560 | } | |||
3561 | return {}; | |||
3562 | } | |||
3563 | ||||
OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
  // Attempts are ordered: value forwarding first (the whole read folds to an
  // existing value), then in-place simplifications that keep the op.
  if (Value vec = foldRAW(*this))
    return vec;
  // Tighten the in_bounds attribute when statically provable.
  if (succeeded(foldTransferInBoundsAttribute(*this)))
    return getResult();
  /// transfer_read(memrefcast) -> transfer_read
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  // Fold a tensor.cast on the source analogously.
  if (succeeded(tensor::foldTensorCast(*this)))
    return getResult();
  return OpFoldResult();
}
3576 | ||||
3577 | std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() { | |||
3578 | return llvm::to_vector<4>(getVectorType().getShape()); | |||
3579 | } | |||
3580 | ||||
3581 | void TransferReadOp::getEffects( | |||
3582 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> | |||
3583 | &effects) { | |||
3584 | if (getShapedType().isa<MemRefType>()) | |||
3585 | effects.emplace_back(MemoryEffects::Read::get(), getSource(), | |||
3586 | SideEffects::DefaultResource::get()); | |||
3587 | } | |||
3588 | ||||
3589 | /// Returns true if all rank reduced in the given `extractOp` happen in leading | |||
3590 | /// dimensions earlier than last `trailingRank` dimensions. | |||
3591 | static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp, | |||
3592 | unsigned trailingRank) { | |||
3593 | // If no ranks are reduced at all, it's a degenerated case; always true. | |||
3594 | if (extractOp.getSourceType().getRank() == extractOp.getType().getRank()) | |||
3595 | return true; | |||
3596 | ||||
3597 | RankedTensorType inferredType = extractOp.inferResultType( | |||
3598 | extractOp.getSourceType(), extractOp.getMixedOffsets(), | |||
3599 | extractOp.getMixedSizes(), extractOp.getMixedStrides()); | |||
3600 | return extractOp.getType().getShape().take_back(trailingRank) == | |||
3601 | inferredType.getShape().take_back(trailingRank); | |||
3602 | } | |||
3603 | ||||
3604 | namespace { | |||
3605 | /// Fold transfer_reads of a tensor.extract_slice op. E.g.: | |||
3606 | /// | |||
3607 | /// ``` | |||
3608 | /// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1] | |||
3609 | /// : tensor<?x?xf32> to tensor<?x?xf32> | |||
3610 | /// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]} | |||
3611 | /// : tensor<?x?xf32>, vector<4x5xf32> | |||
3612 | /// ``` | |||
3613 | /// is rewritten to: | |||
3614 | /// ``` | |||
3615 | /// %p0 = arith.addi %a, %e : index | |||
3616 | /// %p1 = arith.addi %b, %f : index | |||
3617 | /// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]} | |||
3618 | /// : tensor<?x?xf32>, vector<4x5xf32> | |||
3619 | /// ``` | |||
struct FoldExtractSliceIntoTransferRead
    : public OpRewritePattern<TransferReadOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp xferOp,
                                PatternRewriter &rewriter) const override {
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();
    // Only the simple form folds: fully in-bounds, minor identity map, no
    // mask, and a unit-stride extract_slice producer.
    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (!xferOp.getPermutationMap().isMinorIdentity())
      return failure();
    if (xferOp.getMask())
      return failure();
    auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
    if (!extractOp)
      return failure();
    if (!extractOp.hasUnitStride())
      return failure();

    // Bail on illegal rank-reduction: we need to check that the rank-reduced
    // dims are exactly the leading dims. I.e. the following is illegal:
    // ```
    //    %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
    //      tensor<2x1x4xf32> to tensor<2x4xf32>
    //    %1 = vector.transfer_read %0[0,0], %cst :
    //      tensor<2x4xf32>, vector<2x4xf32>
    // ```
    //
    // Cannot fold into:
    // ```
    //    %0 = vector.transfer_read %t[0,0,0], %cst :
    //      tensor<2x1x4xf32>, vector<2x4xf32>
    // ```
    // For this, check the trailing `vectorRank` dims of the extract_slice
    // result tensor match the trailing dims of the inferred result tensor.
    if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
      return failure();

    // Number of leading dims dropped by the (rank-reducing) extract_slice.
    int64_t rankReduced =
        extractOp.getSourceType().getRank() - extractOp.getType().getRank();

    SmallVector<Value> newIndices;
    // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
    // indices first.
    for (int64_t i = 0; i < rankReduced; ++i) {
      OpFoldResult offset = extractOp.getMixedOffsets()[i];
      newIndices.push_back(getValueOrCreateConstantIndexOp(
          rewriter, extractOp.getLoc(), offset));
    }
    // Remaining indices are the original transfer indices shifted by the
    // corresponding slice offsets.
    for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
      OpFoldResult offset =
          extractOp.getMixedOffsets()[it.index() + rankReduced];
      newIndices.push_back(rewriter.create<arith::AddIOp>(
          xferOp->getLoc(), it.value(),
          getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
                                          offset)));
    }
    // The folded read stays fully in-bounds: the original read was in-bounds
    // within the slice (checked above).
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferReadOp>(
        xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
        xferOp.getPadding(), ArrayRef<bool>{inBounds});

    return success();
  }
};
3688 | ||||
3689 | /// Store to load forwarding for transfer operations with permuation maps. | |||
3690 | /// Even if the permutation maps are different we can still propagate the store | |||
3691 | /// into the load if the size of the dimensions read and written match. Then we | |||
3692 | /// can replace the transfer_read + transfer_write by vector.broadcast and | |||
3693 | /// vector.transpose. | |||
3694 | /// Example: | |||
3695 | /// ``` | |||
3696 | /// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0] | |||
3697 | /// {in_bounds = [true, true], | |||
3698 | /// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} : | |||
3699 | /// vector<4x1xf32>, tensor<4x4x4xf32> | |||
3700 | /// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0 | |||
3701 | /// {in_bounds = [true, true, true, true], | |||
3702 | /// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} : | |||
3703 | /// tensor<4x4x4xf32>, vector<1x100x4x5xf32> | |||
3704 | /// ``` | |||
3705 | /// To: | |||
3706 | /// ``` | |||
3707 | /// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32> | |||
3708 | /// %r = vector.transpose %0, [3, 0, 2, 1] : | |||
3709 | /// vector<100x5x4x1xf32> to vector<1x100x4x5xf32> | |||
3710 | /// ``` | |||
struct TransferReadAfterWriteToBroadcast
    : public OpRewritePattern<TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    // Only in-bounds reads of ranked tensors participate: out-of-bounds dims
    // would require padding semantics, and memref reads cannot be forwarded
    // through SSA use-def chains.
    if (readOp.hasOutOfBoundsDim() ||
        !readOp.getShapedType().isa<RankedTensorType>())
      return failure();
    auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    if (!defWrite)
      return failure();

    SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
    // `vec` is null-initialized here and only set when the write provably
    // covers exactly the elements the read accesses.
    Value vec;
    if (readOp.getIndices() == defWrite.getIndices() &&
        readOp.getMask() == defWrite.getMask()) {
      SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
      // TODO: If the writeDim is a superset of the read dims we could do an
      // extract_strided_slice.
      if (writeDims == readDims)
        vec = defWrite.getVector();
    }
    // TODO: loop through the chain of transfer_write if we can prove that they
    // don't overlap with the transfer_read. This requires improving
    // `isDisjointTransferIndices` helper.
    if (!vec)
      return failure();
    SmallVector<unsigned> permutation;
    // Drop unused dims so the two maps compose in a common space, then map
    // the written vector's space into the read vector's space.
    AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
    AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
    AffineMap map = readMap.compose(writeMap);
    if (map.getNumResults() == 0)
      return failure();
    // Calculate the permuation to apply to go from the vector stored to the
    // vector read.
    if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
      return failure();

    Location loc = readOp.getLoc();
    // Calculate the broadcast shape by applying the reverse permuation to the
    // final shape we want.
    ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
    SmallVector<int64_t> broadcastShape(destShape.size());
    for (const auto &pos : llvm::enumerate(permutation))
      broadcastShape[pos.value()] = destShape[pos.index()];
    VectorType broadcastedType = VectorType::get(
        broadcastShape, defWrite.getVectorType().getElementType());
    // Broadcast the stored vector to the pre-transpose shape, then transpose
    // into the exact vector type the read produced.
    vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
    SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
    rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
                                                     transposePerm);
    return success();
  }
};
3766 | } // namespace | |||
3767 | ||||
3768 | void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
3769 | MLIRContext *context) { | |||
3770 | results | |||
3771 | .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>( | |||
3772 | context); | |||
3773 | } | |||
3774 | ||||
3775 | //===----------------------------------------------------------------------===// | |||
3776 | // TransferWriteOp | |||
3777 | //===----------------------------------------------------------------------===// | |||
3778 | ||||
3779 | /// 1. Builder with type inference. | |||
3780 | void TransferWriteOp::build(OpBuilder &builder, OperationState &result, | |||
3781 | Value vector, Value dest, ValueRange indices, | |||
3782 | AffineMapAttr permutationMapAttr, | |||
3783 | /*optional*/ Value mask, | |||
3784 | /*optional*/ ArrayAttr inBoundsAttr) { | |||
3785 | Type resultType = dest.getType().dyn_cast<RankedTensorType>(); | |||
3786 | build(builder, result, resultType, vector, dest, indices, permutationMapAttr, | |||
3787 | mask, inBoundsAttr); | |||
3788 | } | |||
3789 | ||||
3790 | /// 2. Builder with type inference that sets an empty mask (variant with attrs). | |||
3791 | void TransferWriteOp::build(OpBuilder &builder, OperationState &result, | |||
3792 | Value vector, Value dest, ValueRange indices, | |||
3793 | AffineMapAttr permutationMapAttr, | |||
3794 | /*optional*/ ArrayAttr inBoundsAttr) { | |||
3795 | build(builder, result, vector, dest, indices, permutationMapAttr, | |||
3796 | /*mask=*/Value(), inBoundsAttr); | |||
3797 | } | |||
3798 | ||||
3799 | /// 3. Builder with type inference that sets an empty mask (variant without | |||
3800 | /// attrs) | |||
3801 | void TransferWriteOp::build(OpBuilder &builder, OperationState &result, | |||
3802 | Value vector, Value dest, ValueRange indices, | |||
3803 | AffineMap permutationMap, | |||
3804 | std::optional<ArrayRef<bool>> inBounds) { | |||
3805 | auto permutationMapAttr = AffineMapAttr::get(permutationMap); | |||
3806 | auto inBoundsAttr = (inBounds && !inBounds.value().empty()) | |||
3807 | ? builder.getBoolArrayAttr(inBounds.value()) | |||
3808 | : ArrayAttr(); | |||
3809 | build(builder, result, vector, dest, indices, permutationMapAttr, | |||
3810 | /*mask=*/Value(), inBoundsAttr); | |||
3811 | } | |||
3812 | ||||
3813 | /// 4. Builder with type inference that sets an empty mask and sets permutation | |||
3814 | /// map to 'getMinorIdentityMap'. | |||
3815 | void TransferWriteOp::build(OpBuilder &builder, OperationState &result, | |||
3816 | Value vector, Value dest, ValueRange indices, | |||
3817 | std::optional<ArrayRef<bool>> inBounds) { | |||
3818 | auto vectorType = vector.getType().cast<VectorType>(); | |||
3819 | AffineMap permutationMap = getTransferMinorIdentityMap( | |||
3820 | dest.getType().cast<ShapedType>(), vectorType); | |||
3821 | build(builder, result, vector, dest, indices, permutationMap, inBounds); | |||
3822 | } | |||
3823 | ||||
/// Infers the mask type for a transfer write given its vector type and
/// permutation map. The mask in a transfer write operation applies to the
/// tensor/buffer writing part of it and its type should match the shape
/// written *after* any permutation.
static VectorType inferTransferWriteMaskType(VectorType vecType,
                                             AffineMap permMap) {
  // The mask is a vector of i1 whose shape is the vector shape mapped
  // through the (compressed) permutation map.
  auto i1Type = IntegerType::get(permMap.getContext(), 1);
  SmallVector<int64_t, 8> maskShape =
      compressUnusedDims(permMap).compose(vecType.getShape());
  return VectorType::get(maskShape, i1Type);
}
3835 | ||||
ParseResult TransferWriteOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  auto &builder = parser.getBuilder();
  SMLoc typesLoc;
  OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
  SmallVector<Type, 2> types;
  OpAsmParser::UnresolvedOperand maskInfo;
  // Parse `%vector, %source [ %indices ]`.
  if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
      parser.parseOperand(sourceInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
    return failure();
  // An optional trailing `, %mask` operand.
  ParseResult hasMask = parser.parseOptionalComma();
  if (hasMask.succeeded() && parser.parseOperand(maskInfo))
    return failure();
  if (parser.parseOptionalAttrDict(result.attributes) ||
      parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
    return failure();
  // Exactly two types are expected: the written vector type and the
  // destination memref/tensor type.
  if (types.size() != 2)
    return parser.emitError(typesLoc, "requires two types");
  auto indexType = builder.getIndexType();
  VectorType vectorType = types[0].dyn_cast<VectorType>();
  if (!vectorType)
    return parser.emitError(typesLoc, "requires vector type");
  ShapedType shapedType = types[1].dyn_cast<ShapedType>();
  if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
    return parser.emitError(typesLoc, "requires memref or ranked tensor type");
  // When the permutation map is elided in the assembly, default to the minor
  // identity map and record it on the op's attribute dictionary.
  auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
  auto permMapAttr = result.attributes.get(permMapAttrName);
  AffineMap permMap;
  if (!permMapAttr) {
    permMap = getTransferMinorIdentityMap(shapedType, vectorType);
    result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
  } else {
    permMap = permMapAttr.cast<AffineMapAttr>().getValue();
  }
  if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
      parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
      parser.resolveOperands(indexInfo, indexType, result.operands))
    return failure();
  if (hasMask.succeeded()) {
    if (shapedType.getElementType().dyn_cast<VectorType>())
      return parser.emitError(
          maskInfo.location, "does not support masks with vector element type");
    // The mask type is not spelled in the assembly; infer it from the vector
    // type and the permutation map.
    auto maskType = inferTransferWriteMaskType(vectorType, permMap);
    if (parser.resolveOperand(maskInfo, maskType, result.operands))
      return failure();
  }
  // Record the operand segment sizes: vector, source, indices, optional mask.
  result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {1, 1, static_cast<int32_t>(indexInfo.size()),
                           static_cast<int32_t>(hasMask.succeeded())}));
  // Tensor writes return the updated tensor; memref writes return nothing.
  return failure(shapedType.isa<RankedTensorType>() &&
                 parser.addTypeToList(shapedType, result.types));
}
3891 | ||||
3892 | void TransferWriteOp::print(OpAsmPrinter &p) { | |||
3893 | p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]"; | |||
3894 | if (getMask()) | |||
3895 | p << ", " << getMask(); | |||
3896 | printTransferAttrs(p, *this); | |||
3897 | p << " : " << getVectorType() << ", " << getShapedType(); | |||
3898 | } | |||
3899 | ||||
LogicalResult TransferWriteOp::verify() {
  // Consistency of elemental types in shape and vector.
  ShapedType shapedType = getShapedType();
  VectorType vectorType = getVectorType();
  VectorType maskType = getMaskType();
  auto permutationMap = getPermutationMap();
  // Only compute the expected mask type when a mask operand is present.
  VectorType inferredMaskType =
      maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
               : VectorType();

  // One index operand is required per dimension of the shaped type.
  if (llvm::size(getIndices()) != shapedType.getRank())
    return emitOpError("requires ") << shapedType.getRank() << " indices";

  // We do not allow broadcast dimensions on TransferWriteOps for the moment,
  // as the semantics is unclear. This can be revisited later if necessary.
  if (hasBroadcastDim())
    return emitOpError("should not have broadcast dimensions");

  // Shared transfer-op verification (types, permutation map, in_bounds,
  // mask consistency).
  if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
                              shapedType, vectorType, maskType,
                              inferredMaskType, permutationMap,
                              getInBounds() ? *getInBounds() : ArrayAttr())))
    return failure();

  return verifyPermutationMap(permutationMap,
                              [&](Twine t) { return emitOpError(t); });
}
3927 | ||||
3928 | // MaskableOpInterface methods. | |||
3929 | ||||
3930 | /// Returns the mask type expected by this operation. Mostly used for | |||
3931 | /// verification purposes. | |||
3932 | Type TransferWriteOp::getExpectedMaskType() { | |||
3933 | return inferTransferWriteMaskType(getVectorType(), getPermutationMap()); | |||
3934 | } | |||
3935 | ||||
3936 | /// Fold: | |||
3937 | /// ``` | |||
3938 | /// %t1 = ... | |||
3939 | /// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} : | |||
3940 | /// tensor<static_sizesxf32>, vector<static_sizesxf32> | |||
3941 | /// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} : | |||
3942 | /// vector<static_sizesxf32>, tensor<static_sizesxf32> | |||
3943 | /// ``` | |||
3944 | /// | |||
3945 | /// into: | |||
3946 | /// | |||
3947 | /// ``` | |||
3948 | /// %t0 | |||
3949 | /// ``` | |||
3950 | /// | |||
3951 | /// The producer of t1 may or may not be DCE'd depending on whether it is a | |||
3952 | /// block argument or has side effects. | |||
static LogicalResult foldReadInitWrite(TransferWriteOp write,
                                       ArrayRef<Attribute>,
                                       SmallVectorImpl<OpFoldResult> &results) {
  // TODO: support 0-d corner case.
  if (write.getTransferRank() == 0)
    return failure();
  auto rankedTensorType =
      write.getSource().getType().dyn_cast<RankedTensorType>();
  // If not operating on tensors, bail.
  if (!rankedTensorType)
    return failure();
  // If no read, bail.
  auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
  if (!read)
    return failure();
  // TODO: support 0-d corner case.
  if (read.getTransferRank() == 0)
    return failure();
  // For now, only accept minor identity. Future: composition is minor identity.
  if (!read.getPermutationMap().isMinorIdentity() ||
      !write.getPermutationMap().isMinorIdentity())
    return failure();
  // Bail on mismatching ranks.
  if (read.getTransferRank() != write.getTransferRank())
    return failure();
  // Bail on potential out-of-bounds accesses.
  if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
    return failure();
  // Tensor types must be the same.
  if (read.getSource().getType() != rankedTensorType)
    return failure();
  // Vector types must be the same.
  if (read.getVectorType() != write.getVectorType())
    return failure();
  // Vector and Tensor shapes must match, i.e. the transfer covers the whole
  // tensor in one piece.
  if (read.getVectorType().getShape() != rankedTensorType.getShape())
    return failure();
  // If any index is nonzero, the transfer does not start at the origin and
  // the whole-tensor equivalence above does not hold.
  auto isNotConstantZero = [](Value v) {
    auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
    return !cstOp || cstOp.value() != 0;
  };
  if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
      llvm::any_of(write.getIndices(), isNotConstantZero))
    return failure();
  // Success: the write reproduces exactly what was read from `t0`, so the
  // result folds to the read's source tensor.
  results.push_back(read.getSource());
  return success();
}
4002 | ||||
4003 | static bool checkSameValueWAR(vector::TransferReadOp read, | |||
4004 | vector::TransferWriteOp write) { | |||
4005 | return read.getSource() == write.getSource() && | |||
4006 | read.getIndices() == write.getIndices() && | |||
4007 | read.getPermutationMap() == write.getPermutationMap() && | |||
4008 | read.getVectorType() == write.getVectorType() && !read.getMask() && | |||
4009 | !write.getMask(); | |||
4010 | } | |||
4011 | /// Fold transfer_write write after read: | |||
4012 | /// ``` | |||
4013 | /// %t0 = ... | |||
4014 | /// %v = vector.transfer_read %t0[%c0...] : | |||
4015 | /// tensor<static_sizesxf32>, vector<static_sizesxf32> | |||
4016 | /// %t1 = vector.transfer_write %v, %t0[%c0...] : | |||
4017 | /// vector<static_sizesxf32>, tensor<static_sizesxf32> | |||
4018 | /// ``` | |||
4019 | /// | |||
4020 | /// into: | |||
4021 | /// | |||
4022 | /// ``` | |||
4023 | /// %t0 | |||
4024 | /// ``` | |||
4025 | static LogicalResult foldWAR(TransferWriteOp write, | |||
4026 | SmallVectorImpl<OpFoldResult> &results) { | |||
4027 | if (!write.getSource().getType().isa<RankedTensorType>()) | |||
4028 | return failure(); | |||
4029 | auto read = write.getVector().getDefiningOp<vector::TransferReadOp>(); | |||
4030 | if (!read) | |||
4031 | return failure(); | |||
4032 | ||||
4033 | if (!checkSameValueWAR(read, write)) | |||
4034 | return failure(); | |||
4035 | results.push_back(read.getSource()); | |||
4036 | return success(); | |||
4037 | } | |||
4038 | ||||
4039 | LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands, | |||
4040 | SmallVectorImpl<OpFoldResult> &results) { | |||
4041 | if (succeeded(foldReadInitWrite(*this, operands, results))) | |||
4042 | return success(); | |||
4043 | if (succeeded(foldWAR(*this, results))) | |||
4044 | return success(); | |||
4045 | if (succeeded(foldTransferInBoundsAttribute(*this))) | |||
4046 | return success(); | |||
4047 | return memref::foldMemRefCast(*this); | |||
4048 | } | |||
4049 | ||||
4050 | std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() { | |||
4051 | return llvm::to_vector<4>(getVectorType().getShape()); | |||
4052 | } | |||
4053 | ||||
4054 | void TransferWriteOp::getEffects( | |||
4055 | SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>> | |||
4056 | &effects) { | |||
4057 | if (getShapedType().isa<MemRefType>()) | |||
4058 | effects.emplace_back(MemoryEffects::Write::get(), getSource(), | |||
4059 | SideEffects::DefaultResource::get()); | |||
4060 | } | |||
4061 | ||||
4062 | namespace { | |||
4063 | /// Remove dead transfer write from the SSA chain so that it an be eliminated by | |||
4064 | /// DCE | |||
4065 | /// ``` | |||
4066 | /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} | |||
4067 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4068 | /// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]} | |||
4069 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4070 | /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} | |||
4071 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4072 | /// ``` | |||
4073 | /// | |||
4074 | /// into: | |||
4075 | /// | |||
4076 | /// ``` | |||
4077 | /// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]} | |||
4078 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4079 | /// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]} | |||
4080 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4081 | /// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]} | |||
4082 | /// : vector<1x4xf32>, tensor<4x4xf32> | |||
4083 | /// ``` | |||
4084 | /// | |||
4085 | /// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have | |||
4086 | /// any other uses. | |||
class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    // Only SSA tensor chains can be rewired this way; memref writes have
    // observable side effects that cannot be bypassed.
    if (!writeOp.getShapedType().isa<RankedTensorType>())
      return failure();
    vector::TransferWriteOp writeToModify = writeOp;

    // Walk up the chain of transfer_writes feeding this one.
    auto defWrite =
        writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
    while (defWrite) {
      // An earlier write that `writeOp` fully overwrites is dead: bypass it
      // in the chain so DCE can remove it.
      if (checkSameValueWAW(writeOp, defWrite)) {
        writeToModify.getSourceMutable().assign(defWrite.getSource());
        return success();
      }
      // Stop if the two writes may touch overlapping elements.
      if (!isDisjointTransferIndices(
              cast<VectorTransferOpInterface>(defWrite.getOperation()),
              cast<VectorTransferOpInterface>(writeOp.getOperation())))
        break;
      // If the previous write op doesn't have any other use we can safely look
      // at the previous store to see if it can be removed.
      if (!defWrite->hasOneUse())
        break;
      writeToModify = defWrite;
      defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
    }
    return failure();
  }
};
4117 | ||||
4118 | /// Fold tensor.insert_slice into vector.transfer_write if the transfer_write | |||
4119 | /// could directly write to the insert_slice's destination. E.g.: | |||
4120 | /// | |||
4121 | /// ``` | |||
4122 | /// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]} | |||
4123 | /// : vector<4x5xf32>, tensor<4x5xf32> | |||
4124 | /// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1] | |||
4125 | /// : tensor<4x5xf32> into tensor<?x?xf32> | |||
4126 | /// ``` | |||
4127 | /// is rewritten to: | |||
4128 | /// ``` | |||
4129 | /// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]} | |||
4130 | /// : vector<4x5xf32>, tensor<?x?xf32> | |||
4131 | /// ``` | |||
struct FoldInsertSliceIntoTransferWrite
    : public OpRewritePattern<tensor::InsertSliceOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    // Non-unit strides cannot be expressed by a transfer_write's indices.
    if (!insertOp.hasUnitStride())
      return failure();

    auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
    if (!xferOp)
      return failure();
    // TODO: support 0-d corner case.
    if (xferOp.getTransferRank() == 0)
      return failure();

    if (xferOp.hasOutOfBoundsDim())
      return failure();
    if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
      return failure();
    if (xferOp.getMask())
      return failure();
    // Fold only if the TransferWriteOp completely overwrites the `source` with
    // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
    // content is the data of the vector.
    if (!llvm::equal(xferOp.getVectorType().getShape(),
                     xferOp.getShapedType().getShape()))
      return failure();
    if (!xferOp.getPermutationMap().isIdentity())
      return failure();

    // Bail on illegal rank-reduction: we need to check that the rank-reduced
    // dims are exactly the leading dims. I.e. the following is illegal:
    // ```
    //    %0 = vector.transfer_write %v, %t[0,0], %cst :
    //      vector<2x4xf32>, tensor<2x4xf32>
    //    %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
    //      tensor<2x4xf32> into tensor<2x1x4xf32>
    // ```
    //
    // Cannot fold into:
    // ```
    //    %0 = vector.transfer_write %v, %t[0,0,0], %cst :
    //      vector<2x4xf32>, tensor<2x1x4xf32>
    // ```
    // For this, check the trailing `vectorRank` dims of the insert_slice result
    // tensor match the trailing dims of the inferred result tensor.
    int64_t rankReduced =
        insertOp.getType().getRank() - insertOp.getSourceType().getRank();
    int64_t vectorRank = xferOp.getVectorType().getRank();
    // The non-rank-reduced type an extract of this slice would produce; its
    // trailing dims are the ones the transfer must line up with.
    RankedTensorType inferredSourceTensorType =
        tensor::ExtractSliceOp::inferResultType(
            insertOp.getType(), insertOp.getMixedOffsets(),
            insertOp.getMixedSizes(), insertOp.getMixedStrides());
    auto actualSourceTensorShape = insertOp.getSourceType().getShape();
    if (rankReduced > 0 &&
        actualSourceTensorShape.take_back(vectorRank) !=
            inferredSourceTensorType.getShape().take_back(vectorRank))
      return failure();

    // Rewrite: write the vector directly into the insert_slice destination at
    // the slice's offsets.
    SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
        rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
    SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
    rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
                                                 insertOp.getDest(), indices,
                                                 ArrayRef<bool>{inBounds});
    return success();
  }
};
4202 | ||||
4203 | /// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to | |||
4204 | /// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is | |||
4205 | /// overwritten and inserted into another tensor. After this rewrite, the | |||
4206 | /// operations bufferize in-place since all of them work on the same slice. | |||
4207 | /// | |||
4208 | /// For example: | |||
4209 | /// ```mlir | |||
4210 | /// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0] | |||
4211 | /// : vector<8x16xf32>, tensor<8x16xf32> | |||
4212 | /// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1] | |||
4213 | /// : tensor<8x16xf32> to tensor<?x?xf32> | |||
4214 | /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1] | |||
4215 | /// : tensor<?x?xf32> into tensor<27x37xf32> | |||
4216 | /// ``` | |||
4217 | /// folds to | |||
4218 | /// ```mlir | |||
4219 | /// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1] | |||
4220 | /// : tensor<27x37xf32> to tensor<?x?xf32> | |||
4221 | /// %1 = vector.transfer_write %vec, %0[%c0, %c0] | |||
4222 | /// : vector<8x16xf32>, tensor<?x?xf32> | |||
4223 | /// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1] | |||
4224 | /// : tensor<?x?xf32> into tensor<27x37xf32> | |||
4225 | /// ``` | |||
4226 | struct SwapExtractSliceOfTransferWrite | |||
4227 | : public OpRewritePattern<tensor::InsertSliceOp> { | |||
4228 | public: | |||
4229 | using OpRewritePattern::OpRewritePattern; | |||
4230 | ||||
4231 | LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp, | |||
4232 | PatternRewriter &rewriter) const override { | |||
4233 | if (!insertOp.hasUnitStride()) | |||
4234 | return failure(); | |||
4235 | auto extractOp = | |||
4236 | insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>(); | |||
4237 | if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse()) | |||
4238 | return failure(); | |||
4239 | auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>(); | |||
4240 | if (!transferOp || !transferOp->hasOneUse()) | |||
4241 | return failure(); | |||
4242 | ||||
4243 | // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is | |||
4244 | // rank-reducing. | |||
4245 | if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) { | |||
4246 | return rewriter.notifyMatchFailure(insertOp, | |||
4247 | "use-def chain is rank-reducing"); | |||
4248 | } | |||
4249 | ||||
4250 | // Fail if tensor::ExtractSliceOp has non-zero offset. | |||
4251 | if (!extractOp.hasZeroOffset()) { | |||
4252 | return rewriter.notifyMatchFailure(insertOp, | |||
4253 | "ExtractSliceOp has non-zero offset"); | |||
4254 | } | |||
4255 | ||||
4256 | // Fail if tensor::TransferWriteOp has non-zero offset. | |||
4257 | if (!llvm::all_of(transferOp.getIndices(), [](Value value) { | |||
4258 | return getConstantIntValue(value) == static_cast<int64_t>(0); | |||
4259 | })) { | |||
4260 | return rewriter.notifyMatchFailure(insertOp, | |||
4261 | "TranferWriteOp has non-zero offset"); | |||
4262 | } | |||
4263 | ||||
4264 | // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ. | |||
4265 | for (auto [insertSize, extractSize] : | |||
4266 | llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) { | |||
4267 | if (!isEqualConstantIntOrValue(insertSize, extractSize)) { | |||
4268 | return rewriter.notifyMatchFailure( | |||
4269 | insertOp, "InsertSliceOp and ExtractSliceOp sizes differ"); | |||
4270 | } | |||
4271 | } | |||
4272 | ||||
4273 | // Fail if the vector::TransferWriteOp may not overwrite the full tensor. | |||
4274 | assert(transferOp.getVectorType().hasStaticShape() &&(static_cast <bool> (transferOp.getVectorType().hasStaticShape () && "expected vector to have a static shape") ? void (0) : __assert_fail ("transferOp.getVectorType().hasStaticShape() && \"expected vector to have a static shape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4275, __extension__ __PRETTY_FUNCTION__)) | |||
4275 | "expected vector to have a static shape")(static_cast <bool> (transferOp.getVectorType().hasStaticShape () && "expected vector to have a static shape") ? void (0) : __assert_fail ("transferOp.getVectorType().hasStaticShape() && \"expected vector to have a static shape\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4275, __extension__ __PRETTY_FUNCTION__)); | |||
4276 | ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape(); | |||
4277 | SmallVector<int64_t> resultShape = applyPermutationMap( | |||
4278 | transferOp.getPermutationMap(), transferOp.getShapedType().getShape()); | |||
4279 | if (transferOp.getMask() || !vectorShape.equals(resultShape)) { | |||
4280 | return rewriter.notifyMatchFailure( | |||
4281 | insertOp, "TransferWriteOp may not write the full tensor."); | |||
4282 | } | |||
4283 | ||||
4284 | // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp. | |||
4285 | SmallVector<int64_t> newResultShape = applyPermutationMap( | |||
4286 | transferOp.getPermutationMap(), insertOp.getSourceType().getShape()); | |||
4287 | SmallVector<bool> newInBounds; | |||
4288 | for (const auto &en : enumerate(newResultShape)) | |||
4289 | newInBounds.push_back(en.value() == vectorShape[en.index()]); | |||
4290 | auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>( | |||
4291 | extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(), | |||
4292 | insertOp.getMixedOffsets(), insertOp.getMixedSizes(), | |||
4293 | insertOp.getMixedStrides()); | |||
4294 | auto newTransferWriteOp = rewriter.create<TransferWriteOp>( | |||
4295 | transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(), | |||
4296 | transferOp.getIndices(), transferOp.getPermutationMapAttr(), | |||
4297 | rewriter.getBoolArrayAttr(newInBounds)); | |||
4298 | rewriter.updateRootInPlace(insertOp, [&]() { | |||
4299 | insertOp.getSourceMutable().assign(newTransferWriteOp.getResult()); | |||
4300 | }); | |||
4301 | return success(); | |||
4302 | } | |||
4303 | }; | |||
4304 | ||||
4305 | } // namespace | |||
4306 | ||||
4307 | void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
4308 | MLIRContext *context) { | |||
4309 | results.add<FoldWaw, FoldInsertSliceIntoTransferWrite, | |||
4310 | SwapExtractSliceOfTransferWrite>(context); | |||
4311 | } | |||
4312 | ||||
4313 | //===----------------------------------------------------------------------===// | |||
4314 | // LoadOp | |||
4315 | //===----------------------------------------------------------------------===// | |||
4316 | ||||
4317 | static LogicalResult verifyLoadStoreMemRefLayout(Operation *op, | |||
4318 | MemRefType memRefTy) { | |||
4319 | if (!isLastMemrefDimUnitStride(memRefTy)) | |||
4320 | return op->emitOpError("most minor memref dim must have unit stride"); | |||
4321 | return success(); | |||
4322 | } | |||
4323 | ||||
4324 | LogicalResult vector::LoadOp::verify() { | |||
4325 | VectorType resVecTy = getVectorType(); | |||
4326 | MemRefType memRefTy = getMemRefType(); | |||
4327 | ||||
4328 | if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) | |||
4329 | return failure(); | |||
4330 | ||||
4331 | // Checks for vector memrefs. | |||
4332 | Type memElemTy = memRefTy.getElementType(); | |||
4333 | if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { | |||
4334 | if (memVecTy != resVecTy) | |||
4335 | return emitOpError("base memref and result vector types should match"); | |||
4336 | memElemTy = memVecTy.getElementType(); | |||
4337 | } | |||
4338 | ||||
4339 | if (resVecTy.getElementType() != memElemTy) | |||
4340 | return emitOpError("base and result element types should match"); | |||
4341 | if (llvm::size(getIndices()) != memRefTy.getRank()) | |||
4342 | return emitOpError("requires ") << memRefTy.getRank() << " indices"; | |||
4343 | return success(); | |||
4344 | } | |||
4345 | ||||
4346 | OpFoldResult LoadOp::fold(ArrayRef<Attribute>) { | |||
4347 | if (succeeded(memref::foldMemRefCast(*this))) | |||
4348 | return getResult(); | |||
4349 | return OpFoldResult(); | |||
4350 | } | |||
4351 | ||||
4352 | //===----------------------------------------------------------------------===// | |||
4353 | // StoreOp | |||
4354 | //===----------------------------------------------------------------------===// | |||
4355 | ||||
4356 | LogicalResult vector::StoreOp::verify() { | |||
4357 | VectorType valueVecTy = getVectorType(); | |||
4358 | MemRefType memRefTy = getMemRefType(); | |||
4359 | ||||
4360 | if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy))) | |||
4361 | return failure(); | |||
4362 | ||||
4363 | // Checks for vector memrefs. | |||
4364 | Type memElemTy = memRefTy.getElementType(); | |||
4365 | if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) { | |||
4366 | if (memVecTy != valueVecTy) | |||
4367 | return emitOpError( | |||
4368 | "base memref and valueToStore vector types should match"); | |||
4369 | memElemTy = memVecTy.getElementType(); | |||
4370 | } | |||
4371 | ||||
4372 | if (valueVecTy.getElementType() != memElemTy) | |||
4373 | return emitOpError("base and valueToStore element type should match"); | |||
4374 | if (llvm::size(getIndices()) != memRefTy.getRank()) | |||
4375 | return emitOpError("requires ") << memRefTy.getRank() << " indices"; | |||
4376 | return success(); | |||
4377 | } | |||
4378 | ||||
/// Folding only absorbs `memref.cast` producers into the base memref. The op
/// has no results, hence the LogicalResult-returning fold overload.
LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
                            SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
4383 | ||||
4384 | //===----------------------------------------------------------------------===// | |||
4385 | // MaskedLoadOp | |||
4386 | //===----------------------------------------------------------------------===// | |||
4387 | ||||
4388 | LogicalResult MaskedLoadOp::verify() { | |||
4389 | VectorType maskVType = getMaskVectorType(); | |||
4390 | VectorType passVType = getPassThruVectorType(); | |||
4391 | VectorType resVType = getVectorType(); | |||
4392 | MemRefType memType = getMemRefType(); | |||
4393 | ||||
4394 | if (resVType.getElementType() != memType.getElementType()) | |||
4395 | return emitOpError("base and result element type should match"); | |||
4396 | if (llvm::size(getIndices()) != memType.getRank()) | |||
4397 | return emitOpError("requires ") << memType.getRank() << " indices"; | |||
4398 | if (resVType.getDimSize(0) != maskVType.getDimSize(0)) | |||
4399 | return emitOpError("expected result dim to match mask dim"); | |||
4400 | if (resVType != passVType) | |||
4401 | return emitOpError("expected pass_thru of same type as result type"); | |||
4402 | return success(); | |||
4403 | } | |||
4404 | ||||
namespace {
/// Simplifies a masked load whose mask is statically known: an all-true mask
/// becomes a plain vector.load, and an all-false mask yields the pass-through
/// vector directly. Unknown masks are left untouched.
class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedLoadOp load,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(load.getMask())) {
    case MaskFormat::AllTrue:
      // Every lane is read, so the mask is redundant.
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          load, load.getType(), load.getBase(), load.getIndices());
      return success();
    case MaskFormat::AllFalse:
      // No lane is read; the result is exactly the pass-through vector.
      rewriter.replaceOp(load, load.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on MaskedLoad" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4422);
  }
};
} // namespace
4426 | ||||
void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Simplify masked loads whose mask is statically all-true or all-false.
  results.add<MaskedLoadFolder>(context);
}
4431 | ||||
4432 | OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) { | |||
4433 | if (succeeded(memref::foldMemRefCast(*this))) | |||
4434 | return getResult(); | |||
4435 | return OpFoldResult(); | |||
4436 | } | |||
4437 | ||||
4438 | //===----------------------------------------------------------------------===// | |||
4439 | // MaskedStoreOp | |||
4440 | //===----------------------------------------------------------------------===// | |||
4441 | ||||
4442 | LogicalResult MaskedStoreOp::verify() { | |||
4443 | VectorType maskVType = getMaskVectorType(); | |||
4444 | VectorType valueVType = getVectorType(); | |||
4445 | MemRefType memType = getMemRefType(); | |||
4446 | ||||
4447 | if (valueVType.getElementType() != memType.getElementType()) | |||
4448 | return emitOpError("base and valueToStore element type should match"); | |||
4449 | if (llvm::size(getIndices()) != memType.getRank()) | |||
4450 | return emitOpError("requires ") << memType.getRank() << " indices"; | |||
4451 | if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) | |||
4452 | return emitOpError("expected valueToStore dim to match mask dim"); | |||
4453 | return success(); | |||
4454 | } | |||
4455 | ||||
namespace {
/// Simplifies a masked store whose mask is statically known: an all-true mask
/// becomes a plain vector.store, and an all-false mask erases the store (no
/// lane is written). Unknown masks are left untouched.
class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MaskedStoreOp store,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(store.getMask())) {
    case MaskFormat::AllTrue:
      // Every lane is written, so the mask is redundant.
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          store, store.getValueToStore(), store.getBase(), store.getIndices());
      return success();
    case MaskFormat::AllFalse:
      // No lane is written; the whole store is dead.
      rewriter.eraseOp(store);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on MaskedStore" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4473);
  }
};
} // namespace
4477 | ||||
void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  // Simplify masked stores whose mask is statically all-true or all-false.
  results.add<MaskedStoreFolder>(context);
}
4482 | ||||
/// Folding only absorbs `memref.cast` producers into the base memref. The op
/// has no results, hence the LogicalResult-returning fold overload.
LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  return memref::foldMemRefCast(*this);
}
4487 | ||||
4488 | //===----------------------------------------------------------------------===// | |||
4489 | // GatherOp | |||
4490 | //===----------------------------------------------------------------------===// | |||
4491 | ||||
4492 | LogicalResult GatherOp::verify() { | |||
4493 | VectorType indVType = getIndexVectorType(); | |||
4494 | VectorType maskVType = getMaskVectorType(); | |||
4495 | VectorType resVType = getVectorType(); | |||
4496 | ShapedType baseType = getBaseType(); | |||
4497 | ||||
4498 | if (!baseType.isa<MemRefType, RankedTensorType>()) | |||
4499 | return emitOpError("requires base to be a memref or ranked tensor type"); | |||
4500 | ||||
4501 | if (resVType.getElementType() != baseType.getElementType()) | |||
4502 | return emitOpError("base and result element type should match"); | |||
4503 | if (llvm::size(getIndices()) != baseType.getRank()) | |||
4504 | return emitOpError("requires ") << baseType.getRank() << " indices"; | |||
4505 | if (resVType.getShape() != indVType.getShape()) | |||
4506 | return emitOpError("expected result dim to match indices dim"); | |||
4507 | if (resVType.getShape() != maskVType.getShape()) | |||
4508 | return emitOpError("expected result dim to match mask dim"); | |||
4509 | if (resVType != getPassThruVectorType()) | |||
4510 | return emitOpError("expected pass_thru of same type as result type"); | |||
4511 | return success(); | |||
4512 | } | |||
4513 | ||||
namespace {
/// Simplifies a gather whose mask is statically all-false: no lane is read,
/// so the result is exactly the pass-through vector. An all-true mask has no
/// unmasked gather equivalent, so it cannot be simplified here.
class GatherFolder final : public OpRewritePattern<GatherOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(GatherOp gather,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(gather.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.replaceOp(gather, gather.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on GatherFolder" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4529);
  }
};
} // namespace
4533 | ||||
void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Simplify gathers whose mask is statically all-false.
  results.add<GatherFolder>(context);
}
4538 | ||||
4539 | //===----------------------------------------------------------------------===// | |||
4540 | // ScatterOp | |||
4541 | //===----------------------------------------------------------------------===// | |||
4542 | ||||
4543 | LogicalResult ScatterOp::verify() { | |||
4544 | VectorType indVType = getIndexVectorType(); | |||
4545 | VectorType maskVType = getMaskVectorType(); | |||
4546 | VectorType valueVType = getVectorType(); | |||
4547 | MemRefType memType = getMemRefType(); | |||
4548 | ||||
4549 | if (valueVType.getElementType() != memType.getElementType()) | |||
4550 | return emitOpError("base and valueToStore element type should match"); | |||
4551 | if (llvm::size(getIndices()) != memType.getRank()) | |||
4552 | return emitOpError("requires ") << memType.getRank() << " indices"; | |||
4553 | if (valueVType.getDimSize(0) != indVType.getDimSize(0)) | |||
4554 | return emitOpError("expected valueToStore dim to match indices dim"); | |||
4555 | if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) | |||
4556 | return emitOpError("expected valueToStore dim to match mask dim"); | |||
4557 | return success(); | |||
4558 | } | |||
4559 | ||||
namespace {
/// Simplifies a scatter whose mask is statically all-false: no lane is
/// written, so the whole op is dead and can be erased. An all-true mask has
/// no unmasked scatter equivalent, so it cannot be simplified here.
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ScatterOp scatter,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(scatter.getMask())) {
    case MaskFormat::AllTrue:
      return failure(); // no unmasked equivalent
    case MaskFormat::AllFalse:
      rewriter.eraseOp(scatter);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on ScatterFolder" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4575);
  }
};
} // namespace
4579 | ||||
void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Simplify scatters whose mask is statically all-false.
  results.add<ScatterFolder>(context);
}
4584 | ||||
4585 | //===----------------------------------------------------------------------===// | |||
4586 | // ExpandLoadOp | |||
4587 | //===----------------------------------------------------------------------===// | |||
4588 | ||||
4589 | LogicalResult ExpandLoadOp::verify() { | |||
4590 | VectorType maskVType = getMaskVectorType(); | |||
4591 | VectorType passVType = getPassThruVectorType(); | |||
4592 | VectorType resVType = getVectorType(); | |||
4593 | MemRefType memType = getMemRefType(); | |||
4594 | ||||
4595 | if (resVType.getElementType() != memType.getElementType()) | |||
4596 | return emitOpError("base and result element type should match"); | |||
4597 | if (llvm::size(getIndices()) != memType.getRank()) | |||
4598 | return emitOpError("requires ") << memType.getRank() << " indices"; | |||
4599 | if (resVType.getDimSize(0) != maskVType.getDimSize(0)) | |||
4600 | return emitOpError("expected result dim to match mask dim"); | |||
4601 | if (resVType != passVType) | |||
4602 | return emitOpError("expected pass_thru of same type as result type"); | |||
4603 | return success(); | |||
4604 | } | |||
4605 | ||||
namespace {
/// Simplifies an expanding load whose mask is statically known: an all-true
/// mask makes it equivalent to a plain vector.load, and an all-false mask
/// yields the pass-through vector directly.
class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ExpandLoadOp expand,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(expand.getMask())) {
    case MaskFormat::AllTrue:
      // With every lane enabled there is no expansion; it is a plain load.
      rewriter.replaceOpWithNewOp<vector::LoadOp>(
          expand, expand.getType(), expand.getBase(), expand.getIndices());
      return success();
    case MaskFormat::AllFalse:
      // No lane is read; the result is exactly the pass-through vector.
      rewriter.replaceOp(expand, expand.getPassThru());
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on ExpandLoadFolder" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4623);
  }
};
} // namespace
4627 | ||||
void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Simplify expanding loads whose mask is statically all-true or all-false.
  results.add<ExpandLoadFolder>(context);
}
4632 | ||||
4633 | //===----------------------------------------------------------------------===// | |||
4634 | // CompressStoreOp | |||
4635 | //===----------------------------------------------------------------------===// | |||
4636 | ||||
4637 | LogicalResult CompressStoreOp::verify() { | |||
4638 | VectorType maskVType = getMaskVectorType(); | |||
4639 | VectorType valueVType = getVectorType(); | |||
4640 | MemRefType memType = getMemRefType(); | |||
4641 | ||||
4642 | if (valueVType.getElementType() != memType.getElementType()) | |||
4643 | return emitOpError("base and valueToStore element type should match"); | |||
4644 | if (llvm::size(getIndices()) != memType.getRank()) | |||
4645 | return emitOpError("requires ") << memType.getRank() << " indices"; | |||
4646 | if (valueVType.getDimSize(0) != maskVType.getDimSize(0)) | |||
4647 | return emitOpError("expected valueToStore dim to match mask dim"); | |||
4648 | return success(); | |||
4649 | } | |||
4650 | ||||
namespace {
/// Simplifies a compressing store whose mask is statically known: an all-true
/// mask makes it equivalent to a plain vector.store, and an all-false mask
/// erases the op (no lane is written).
class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
public:
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(CompressStoreOp compress,
                                PatternRewriter &rewriter) const override {
    switch (getMaskFormat(compress.getMask())) {
    case MaskFormat::AllTrue:
      // With every lane enabled there is no compression; it is a plain store.
      rewriter.replaceOpWithNewOp<vector::StoreOp>(
          compress, compress.getValueToStore(), compress.getBase(),
          compress.getIndices());
      return success();
    case MaskFormat::AllFalse:
      // No lane is written; the whole store is dead.
      rewriter.eraseOp(compress);
      return success();
    case MaskFormat::Unknown:
      return failure();
    }
    llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder")::llvm::llvm_unreachable_internal("Unexpected 1DMaskFormat on CompressStoreFolder" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 4669);
  }
};
} // namespace
4673 | ||||
void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  // Simplify compressing stores whose mask is statically all-true/all-false.
  results.add<CompressStoreFolder>(context);
}
4678 | ||||
4679 | //===----------------------------------------------------------------------===// | |||
4680 | // ShapeCastOp | |||
4681 | //===----------------------------------------------------------------------===// | |||
4682 | ||||
4683 | /// Returns true if each element of 'a' is equal to the product of a contiguous | |||
4684 | /// sequence of the elements of 'b'. Returns false otherwise. | |||
4685 | static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) { | |||
4686 | unsigned rankA = a.size(); | |||
4687 | unsigned rankB = b.size(); | |||
4688 | assert(rankA < rankB)(static_cast <bool> (rankA < rankB) ? void (0) : __assert_fail ("rankA < rankB", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp" , 4688, __extension__ __PRETTY_FUNCTION__)); | |||
4689 | ||||
4690 | unsigned i = 0; | |||
4691 | unsigned j = 0; | |||
4692 | while (i < rankA && j < rankB) { | |||
4693 | int64_t dimA = a[i]; | |||
4694 | int64_t dimB = 1; | |||
4695 | while (dimB < dimA && j < rankB) | |||
4696 | dimB *= b[j++]; | |||
4697 | if (dimA != dimB) | |||
4698 | break; | |||
4699 | ++i; | |||
4700 | ||||
4701 | // Handle the case when trailing dimensions are of size 1. | |||
4702 | // Include them into the contiguous sequence. | |||
4703 | auto isOne = [](int64_t v) { return v == 1; }; | |||
4704 | if (i < rankA && llvm::all_of(a.slice(i), isOne)) | |||
4705 | i = rankA; | |||
4706 | if (j < rankB && llvm::all_of(b.slice(j), isOne)) | |||
4707 | j = rankB; | |||
4708 | } | |||
4709 | ||||
4710 | return i == rankA && j == rankB; | |||
4711 | } | |||
4712 | ||||
4713 | static LogicalResult verifyVectorShapeCast(Operation *op, | |||
4714 | VectorType sourceVectorType, | |||
4715 | VectorType resultVectorType) { | |||
4716 | // Check that element type is the same. | |||
4717 | if (sourceVectorType.getElementType() != resultVectorType.getElementType()) | |||
4718 | return op->emitOpError("source/result vectors must have same element type"); | |||
4719 | auto sourceShape = sourceVectorType.getShape(); | |||
4720 | auto resultShape = resultVectorType.getShape(); | |||
4721 | ||||
4722 | // Check that product of source dim sizes matches product of result dim sizes. | |||
4723 | int64_t sourceDimProduct = std::accumulate( | |||
4724 | sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{}); | |||
4725 | int64_t resultDimProduct = std::accumulate( | |||
4726 | resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{}); | |||
4727 | if (sourceDimProduct != resultDimProduct) | |||
4728 | return op->emitOpError("source/result number of elements must match"); | |||
4729 | ||||
4730 | // Check that expanding/contracting rank cases. | |||
4731 | unsigned sourceRank = sourceVectorType.getRank(); | |||
4732 | unsigned resultRank = resultVectorType.getRank(); | |||
4733 | if (sourceRank < resultRank) { | |||
4734 | if (!isValidShapeCast(sourceShape, resultShape)) | |||
4735 | return op->emitOpError("invalid shape cast"); | |||
4736 | } else if (sourceRank > resultRank) { | |||
4737 | if (!isValidShapeCast(resultShape, sourceShape)) | |||
4738 | return op->emitOpError("invalid shape cast"); | |||
4739 | } | |||
4740 | return success(); | |||
4741 | } | |||
4742 | ||||
4743 | LogicalResult ShapeCastOp::verify() { | |||
4744 | auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>(); | |||
4745 | auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>(); | |||
4746 | ||||
4747 | // Check if source/result are of vector type. | |||
4748 | if (sourceVectorType && resultVectorType) | |||
4749 | return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType); | |||
4750 | ||||
4751 | return success(); | |||
4752 | } | |||
4753 | ||||
/// Folds shape_cast in three ordered steps: (1) identity cast, (2) collapsing
/// a chain of two shape_casts (either cancelling entirely or re-rooting this
/// op onto the chain's original source), (3) cancelling against a broadcast
/// whose source type already equals the result type.
OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
  // No-op shape cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling shape casts.
  if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
    // Exact cancellation: cast(cast(x)) where the outer result type equals x.
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Only allows valid transitive folding.
    // The combined cast must itself be a valid rank-changing shape_cast from
    // the chain's original source type to this op's result type; equal ranks
    // are rejected (they would have been caught by the checks above).
    VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
    VectorType resultType = getResult().getType().cast<VectorType>();
    if (srcType.getRank() < resultType.getRank()) {
      if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
        return {};
    } else if (srcType.getRank() > resultType.getRank()) {
      if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
        return {};
    } else {
      return {};
    }

    // In-place fold: re-point this op's operand at the chain's source and
    // report the (unchanged) result as folded.
    setOperand(otherOp.getSource());
    return getResult();
  }

  // Cancelling broadcast and shape cast ops.
  if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
    if (bcastOp.getSourceType() == getType())
      return bcastOp.getSource();
  }

  return {};
}
4789 | ||||
4790 | namespace { | |||
4791 | // Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp. | |||
4792 | class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> { | |||
4793 | public: | |||
4794 | using OpRewritePattern::OpRewritePattern; | |||
4795 | ||||
4796 | LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp, | |||
4797 | PatternRewriter &rewriter) const override { | |||
4798 | auto constantOp = | |||
4799 | shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>(); | |||
4800 | if (!constantOp) | |||
4801 | return failure(); | |||
4802 | // Only handle splat for now. | |||
4803 | auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>(); | |||
4804 | if (!dense) | |||
4805 | return failure(); | |||
4806 | auto newAttr = | |||
4807 | DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(), | |||
4808 | dense.getSplatValue<Attribute>()); | |||
4809 | rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr); | |||
4810 | return success(); | |||
4811 | } | |||
4812 | }; | |||
4813 | ||||
/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
/// This only applies when the shape of the broadcast source is a suffix of the
/// shape of the result (i.e. when broadcast without reshape is expressive
/// enough to capture the result in a single op).
class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
                                PatternRewriter &rewriter) const override {
    auto broadcastOp =
        shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
    if (!broadcastOp)
      return failure();

    // A scalar broadcast source is treated as an empty shape (suffix of any
    // result shape).
    auto broadcastSourceVectorType =
        broadcastOp.getSourceType().dyn_cast<VectorType>();
    auto broadcastSourceShape = broadcastSourceVectorType
                                    ? broadcastSourceVectorType.getShape()
                                    : ArrayRef<int64_t>{};
    auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();

    // Bail if `broadcastSourceShape` is not a suffix of the result.
    bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
                                                 broadcastSourceShape.size()));
    if (!isSuffix)
      return failure();

    // Broadcast straight to the shape_cast's result type, bypassing the cast.
    rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
        shapeCastOp, shapeCastOp.getResultVectorType(),
        broadcastOp.getSource());
    return success();
  }
};
4848 | ||||
4849 | } // namespace | |||
4850 | ||||
4851 | void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
4852 | MLIRContext *context) { | |||
4853 | results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context); | |||
4854 | } | |||
4855 | ||||
4856 | //===----------------------------------------------------------------------===// | |||
4857 | // VectorBitCastOp | |||
4858 | //===----------------------------------------------------------------------===// | |||
4859 | ||||
4860 | LogicalResult BitCastOp::verify() { | |||
4861 | auto sourceVectorType = getSourceVectorType(); | |||
4862 | auto resultVectorType = getResultVectorType(); | |||
4863 | ||||
4864 | for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) { | |||
4865 | if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i)) | |||
4866 | return emitOpError("dimension size mismatch at: ") << i; | |||
4867 | } | |||
4868 | ||||
4869 | DataLayout dataLayout = DataLayout::closest(*this); | |||
4870 | auto sourceElementBits = | |||
4871 | dataLayout.getTypeSizeInBits(sourceVectorType.getElementType()); | |||
4872 | auto resultElementBits = | |||
4873 | dataLayout.getTypeSizeInBits(resultVectorType.getElementType()); | |||
4874 | ||||
4875 | if (sourceVectorType.getRank() == 0) { | |||
4876 | if (sourceElementBits != resultElementBits) | |||
4877 | return emitOpError("source/result bitwidth of the 0-D vector element " | |||
4878 | "types must be equal"); | |||
4879 | } else if (sourceElementBits * sourceVectorType.getShape().back() != | |||
4880 | resultElementBits * resultVectorType.getShape().back()) { | |||
4881 | return emitOpError( | |||
4882 | "source/result bitwidth of the minor 1-D vectors must be equal"); | |||
4883 | } | |||
4884 | ||||
4885 | return success(); | |||
4886 | } | |||
4887 | ||||
/// Folds bitcasts in order: identity cast, cancelling/re-rooting a chain of
/// two bitcasts, and finally constant-folding a splat f16 constant bitcast to
/// f32.
OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
  // Nop cast.
  if (getSource().getType() == getResult().getType())
    return getSource();

  // Canceling bitcasts.
  if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
    // Exact cancellation: bitcast(bitcast(x)) with matching end types.
    if (getResult().getType() == otherOp.getSource().getType())
      return otherOp.getSource();

    // Otherwise fold transitively in place: re-point this op's operand at the
    // chain's original source and report the (unchanged) result as folded.
    setOperand(otherOp.getSource());
    return getResult();
  }

  // The single operand's constant attribute (null when non-constant).
  Attribute sourceConstant = operands.front();
  if (!sourceConstant)
    return {};

  Type srcElemType = getSourceVectorType().getElementType();
  Type dstElemType = getResultVectorType().getElementType();

  if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
    if (floatPack.isSplat()) {
      auto splat = floatPack.getSplatValue<FloatAttr>();

      // Casting fp16 into fp32.
      if (srcElemType.isF16() && dstElemType.isF32()) {
        uint32_t bits = static_cast<uint32_t>(
            splat.getValue().bitcastToAPInt().getZExtValue());
        // Duplicate the 16-bit pattern.
        // Because the source is a splat, both f16 halves of each f32 lane
        // carry the same bit pattern, so the packed word is byte-order
        // independent for this case.
        bits = (bits << 16) | (bits & 0xffff);
        APInt intBits(32, bits);
        APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
        return DenseElementsAttr::get(getResultVectorType(), floatBits);
      }
    }
  }

  return {};
}
4928 | ||||
4929 | //===----------------------------------------------------------------------===// | |||
4930 | // TypeCastOp | |||
4931 | //===----------------------------------------------------------------------===// | |||
4932 | ||||
4933 | static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) { | |||
4934 | auto vectorType = memRefType.getElementType().dyn_cast<VectorType>(); | |||
4935 | SmallVector<int64_t, 8> res(memRefType.getShape().begin(), | |||
4936 | memRefType.getShape().end()); | |||
4937 | if (vectorType) | |||
4938 | res.append(vectorType.getShape().begin(), vectorType.getShape().end()); | |||
4939 | return res; | |||
4940 | } | |||
4941 | ||||
4942 | /// Build the canonical memRefType with a single vector. | |||
4943 | /// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>. | |||
4944 | void TypeCastOp::build(OpBuilder &builder, OperationState &result, | |||
4945 | Value source) { | |||
4946 | result.addOperands(source); | |||
4947 | MemRefType memRefType = source.getType().cast<MemRefType>(); | |||
4948 | VectorType vectorType = | |||
4949 | VectorType::get(extractShape(memRefType), | |||
4950 | getElementTypeOrSelf(getElementTypeOrSelf(memRefType))); | |||
4951 | result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(), | |||
4952 | memRefType.getMemorySpace())); | |||
4953 | } | |||
4954 | ||||
/// Verifies that operand and result are identity-layout memrefs in the same
/// memory space, with the same underlying scalar type and the same
/// concatenated (memref + element-vector) shape.
LogicalResult TypeCastOp::verify() {
  MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
  if (!canonicalType.getLayout().isIdentity())
    return emitOpError("expects operand to be a memref with identity layout");
  if (!getResultMemRefType().getLayout().isIdentity())
    return emitOpError("expects result to be a memref with identity layout");
  if (getResultMemRefType().getMemorySpace() !=
      getMemRefType().getMemorySpace())
    return emitOpError("expects result in same memory space");

  auto sourceType = getMemRefType();
  auto resultType = getResultMemRefType();
  // Nested getElementTypeOrSelf unwraps memref-of-vector down to the scalar.
  if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
      getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
    return emitOpError(
               "expects result and operand with same underlying scalar type: ")
           << resultType;
  if (extractShape(sourceType) != extractShape(resultType))
    return emitOpError(
               "expects concatenated result and operand shapes to be equal: ")
           << resultType;
  return success();
}
4978 | ||||
4979 | //===----------------------------------------------------------------------===// | |||
4980 | // TransposeOp | |||
4981 | //===----------------------------------------------------------------------===// | |||
4982 | ||||
4983 | void vector::TransposeOp::build(OpBuilder &builder, OperationState &result, | |||
4984 | Value vector, ArrayRef<int64_t> transp) { | |||
4985 | VectorType vt = vector.getType().cast<VectorType>(); | |||
4986 | SmallVector<int64_t, 4> transposedShape(vt.getRank()); | |||
4987 | for (unsigned i = 0; i < transp.size(); ++i) | |||
4988 | transposedShape[i] = vt.getShape()[transp[i]]; | |||
4989 | ||||
4990 | result.addOperands(vector); | |||
4991 | result.addTypes(VectorType::get(transposedShape, vt.getElementType())); | |||
4992 | result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp)); | |||
4993 | } | |||
4994 | ||||
4995 | OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) { | |||
4996 | // Eliminate splat constant transpose ops. | |||
4997 | if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>()) | |||
4998 | if (attr.isSplat()) | |||
4999 | return attr.reshape(getResultType()); | |||
5000 | ||||
5001 | // Eliminate identity transpose ops. This happens when the dimensions of the | |||
5002 | // input vector remain in their original order after the transpose operation. | |||
5003 | SmallVector<int64_t, 4> transp; | |||
5004 | getTransp(transp); | |||
5005 | ||||
5006 | // Check if the permutation of the dimensions contains sequential values: | |||
5007 | // {0, 1, 2, ...}. | |||
5008 | for (int64_t i = 0, e = transp.size(); i < e; i++) { | |||
5009 | if (transp[i] != i) | |||
5010 | return {}; | |||
5011 | } | |||
5012 | ||||
5013 | return getVector(); | |||
5014 | } | |||
5015 | ||||
5016 | LogicalResult vector::TransposeOp::verify() { | |||
5017 | VectorType vectorType = getVectorType(); | |||
5018 | VectorType resultType = getResultType(); | |||
5019 | int64_t rank = resultType.getRank(); | |||
5020 | if (vectorType.getRank() != rank) | |||
5021 | return emitOpError("vector result rank mismatch: ") << rank; | |||
5022 | // Verify transposition array. | |||
5023 | auto transpAttr = getTransp().getValue(); | |||
5024 | int64_t size = transpAttr.size(); | |||
5025 | if (rank != size) | |||
5026 | return emitOpError("transposition length mismatch: ") << size; | |||
5027 | SmallVector<bool, 8> seen(rank, false); | |||
5028 | for (const auto &ta : llvm::enumerate(transpAttr)) { | |||
5029 | int64_t i = ta.value().cast<IntegerAttr>().getInt(); | |||
5030 | if (i < 0 || i >= rank) | |||
5031 | return emitOpError("transposition index out of range: ") << i; | |||
5032 | if (seen[i]) | |||
5033 | return emitOpError("duplicate position index: ") << i; | |||
5034 | seen[i] = true; | |||
5035 | if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i)) | |||
5036 | return emitOpError("dimension size mismatch at: ") << i; | |||
5037 | } | |||
5038 | return success(); | |||
5039 | } | |||
5040 | ||||
5041 | std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() { | |||
5042 | return llvm::to_vector<4>(getResultType().getShape()); | |||
5043 | } | |||
5044 | ||||
5045 | namespace { | |||
5046 | ||||
5047 | // Rewrites two back-to-back TransposeOp operations into a single TransposeOp. | |||
5048 | class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> { | |||
5049 | public: | |||
5050 | using OpRewritePattern::OpRewritePattern; | |||
5051 | ||||
5052 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, | |||
5053 | PatternRewriter &rewriter) const override { | |||
5054 | // Wrapper around vector::TransposeOp::getTransp() for cleaner code. | |||
5055 | auto getPermutation = [](vector::TransposeOp transpose) { | |||
5056 | SmallVector<int64_t, 4> permutation; | |||
5057 | transpose.getTransp(permutation); | |||
5058 | return permutation; | |||
5059 | }; | |||
5060 | ||||
5061 | // Composes two permutations: result[i] = permutation1[permutation2[i]]. | |||
5062 | auto composePermutations = [](ArrayRef<int64_t> permutation1, | |||
5063 | ArrayRef<int64_t> permutation2) { | |||
5064 | SmallVector<int64_t, 4> result; | |||
5065 | for (auto index : permutation2) | |||
5066 | result.push_back(permutation1[index]); | |||
5067 | return result; | |||
5068 | }; | |||
5069 | ||||
5070 | // Return if the input of 'transposeOp' is not defined by another transpose. | |||
5071 | vector::TransposeOp parentTransposeOp = | |||
5072 | transposeOp.getVector().getDefiningOp<vector::TransposeOp>(); | |||
5073 | if (!parentTransposeOp) | |||
5074 | return failure(); | |||
5075 | ||||
5076 | SmallVector<int64_t, 4> permutation = composePermutations( | |||
5077 | getPermutation(parentTransposeOp), getPermutation(transposeOp)); | |||
5078 | // Replace 'transposeOp' with a new transpose operation. | |||
5079 | rewriter.replaceOpWithNewOp<vector::TransposeOp>( | |||
5080 | transposeOp, transposeOp.getResult().getType(), | |||
5081 | parentTransposeOp.getVector(), | |||
5082 | vector::getVectorSubscriptAttr(rewriter, permutation)); | |||
5083 | return success(); | |||
5084 | } | |||
5085 | }; | |||
5086 | ||||
5087 | // Folds transpose(broadcast(<scalar>)) into brodcast(<scalar>). | |||
5088 | struct FoldTransposedScalarBroadcast final | |||
5089 | : public OpRewritePattern<vector::TransposeOp> { | |||
5090 | using OpRewritePattern::OpRewritePattern; | |||
5091 | ||||
5092 | LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, | |||
5093 | PatternRewriter &rewriter) const override { | |||
5094 | auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>(); | |||
5095 | if (!bcastOp) | |||
5096 | return failure(); | |||
5097 | ||||
5098 | auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>(); | |||
5099 | if (!srcVectorType || srcVectorType.getNumElements() == 1) { | |||
5100 | rewriter.replaceOpWithNewOp<vector::BroadcastOp>( | |||
5101 | transposeOp, transposeOp.getResultType(), bcastOp.getSource()); | |||
5102 | return success(); | |||
5103 | } | |||
5104 | ||||
5105 | return failure(); | |||
5106 | } | |||
5107 | }; | |||
5108 | ||||
5109 | // Folds transpose(splat x : src_type) : res_type into splat x : res_type. | |||
5110 | class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> { | |||
5111 | public: | |||
5112 | using OpRewritePattern::OpRewritePattern; | |||
5113 | ||||
5114 | LogicalResult matchAndRewrite(TransposeOp transposeOp, | |||
5115 | PatternRewriter &rewriter) const override { | |||
5116 | auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>(); | |||
5117 | if (!splatOp) | |||
5118 | return failure(); | |||
5119 | ||||
5120 | rewriter.replaceOpWithNewOp<vector::SplatOp>( | |||
5121 | transposeOp, transposeOp.getResultType(), splatOp.getInput()); | |||
5122 | return success(); | |||
5123 | } | |||
5124 | }; | |||
5125 | ||||
5126 | } // namespace | |||
5127 | ||||
5128 | void vector::TransposeOp::getCanonicalizationPatterns( | |||
5129 | RewritePatternSet &results, MLIRContext *context) { | |||
5130 | results | |||
5131 | .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>( | |||
5132 | context); | |||
5133 | } | |||
5134 | ||||
/// Materializes the op's `transp` array attribute into `results` as plain
/// int64_t permutation indices.
void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
  populateFromInt64AttrArray(getTransp(), results);
}
5138 | ||||
5139 | //===----------------------------------------------------------------------===// | |||
5140 | // ConstantMaskOp | |||
5141 | //===----------------------------------------------------------------------===// | |||
5142 | ||||
5143 | LogicalResult ConstantMaskOp::verify() { | |||
5144 | auto resultType = getResult().getType().cast<VectorType>(); | |||
5145 | // Check the corner case of 0-D vectors first. | |||
5146 | if (resultType.getRank() == 0) { | |||
5147 | if (getMaskDimSizes().size() != 1) | |||
5148 | return emitError("array attr must have length 1 for 0-D vectors"); | |||
5149 | auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt(); | |||
5150 | if (dim != 0 && dim != 1) | |||
5151 | return emitError("mask dim size must be either 0 or 1 for 0-D vectors"); | |||
5152 | return success(); | |||
5153 | } | |||
5154 | ||||
5155 | // Verify that array attr size matches the rank of the vector result. | |||
5156 | if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank()) | |||
5157 | return emitOpError( | |||
5158 | "must specify array attr of size equal vector result rank"); | |||
5159 | // Verify that each array attr element is in bounds of corresponding vector | |||
5160 | // result dimension size. | |||
5161 | auto resultShape = resultType.getShape(); | |||
5162 | SmallVector<int64_t, 4> maskDimSizes; | |||
5163 | for (const auto &it : llvm::enumerate(getMaskDimSizes())) { | |||
5164 | int64_t attrValue = it.value().cast<IntegerAttr>().getInt(); | |||
5165 | if (attrValue < 0 || attrValue > resultShape[it.index()]) | |||
5166 | return emitOpError( | |||
5167 | "array attr of size out of bounds of vector result dimension size"); | |||
5168 | maskDimSizes.push_back(attrValue); | |||
5169 | } | |||
5170 | // Verify that if one mask dim size is zero, they all should be zero (because | |||
5171 | // the mask region is a conjunction of each mask dimension interval). | |||
5172 | bool anyZeros = llvm::is_contained(maskDimSizes, 0); | |||
5173 | bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; }); | |||
5174 | if (anyZeros && !allZeros) | |||
5175 | return emitOpError("expected all mask dim sizes to be zeros, " | |||
5176 | "as a result of conjunction with zero mask dim"); | |||
5177 | // Verify that if the mask type is scalable, dimensions should be zero because | |||
5178 | // constant scalable masks can only be defined for the "none set" or "all set" | |||
5179 | // cases, and there is no VLA way to define an "all set" case for | |||
5180 | // `vector.constant_mask`. In the future, a convention could be established | |||
5181 | // to decide if a specific dimension value could be considered as "all set". | |||
5182 | if (resultType.isScalable() && | |||
5183 | getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0) | |||
5184 | return emitOpError("expected mask dim sizes for scalable masks to be 0"); | |||
5185 | return success(); | |||
5186 | } | |||
5187 | ||||
5188 | //===----------------------------------------------------------------------===// | |||
5189 | // CreateMaskOp | |||
5190 | //===----------------------------------------------------------------------===// | |||
5191 | ||||
5192 | LogicalResult CreateMaskOp::verify() { | |||
5193 | auto vectorType = getResult().getType().cast<VectorType>(); | |||
5194 | // Verify that an operand was specified for each result vector each dimension. | |||
5195 | if (vectorType.getRank() == 0) { | |||
5196 | if (getNumOperands() != 1) | |||
5197 | return emitOpError( | |||
5198 | "must specify exactly one operand for 0-D create_mask"); | |||
5199 | } else if (getNumOperands() != | |||
5200 | getResult().getType().cast<VectorType>().getRank()) { | |||
5201 | return emitOpError( | |||
5202 | "must specify an operand for each result vector dimension"); | |||
5203 | } | |||
5204 | return success(); | |||
5205 | } | |||
5206 | ||||
5207 | namespace { | |||
5208 | ||||
// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
                                PatternRewriter &rewriter) const override {
    // Return if any of 'createMaskOp' operands are not defined by a constant.
    auto isNotDefByConstant = [](Value operand) {
      return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
    };
    if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
      return failure();

    // CreateMaskOp for scalable vectors can be folded only if all dimensions
    // are negative or zero.
    if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
      if (vType.isScalable())
        for (auto opDim : createMaskOp.getOperands()) {
          APInt intVal;
          // Any strictly positive constant dimension makes the fold invalid:
          // a non-empty constant mask cannot be expressed for scalable
          // vectors (see ConstantMaskOp::verify).
          if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
              intVal.isStrictlyPositive())
            return failure();
        }
    }

    // Gather constant mask dimension sizes.
    SmallVector<int64_t, 4> maskDimSizes;
    maskDimSizes.reserve(createMaskOp->getNumOperands());
    for (auto [operand, maxDimSize] : llvm::zip_equal(
             createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
      Operation *defOp = operand.getDefiningOp();
      // Clamp each constant operand to the static result dimension size.
      int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
      dimSize = std::min(dimSize, maxDimSize);
      // If one of dim sizes is zero, set all dims to zero.
      if (dimSize <= 0) {
        maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
        break;
      }
      maskDimSizes.push_back(dimSize);
    }
    // Replace 'createMaskOp' with ConstantMaskOp.
    rewriter.replaceOpWithNewOp<ConstantMaskOp>(
        createMaskOp, createMaskOp.getResult().getType(),
        vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
    return success();
  }
};
5257 | ||||
5258 | } // namespace | |||
5259 | ||||
5260 | void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results, | |||
5261 | MLIRContext *context) { | |||
5262 | results.add<CreateMaskFolder>(context); | |||
5263 | } | |||
5264 | ||||
5265 | //===----------------------------------------------------------------------===// | |||
5266 | // MaskOp | |||
5267 | //===----------------------------------------------------------------------===// | |||
5268 | ||||
5269 | void MaskOp::build( | |||
5270 | OpBuilder &builder, OperationState &result, Value mask, | |||
5271 | function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { | |||
5272 | assert(maskRegionBuilder &&(static_cast <bool> (maskRegionBuilder && "builder callback for 'maskRegion' must be present" ) ? void (0) : __assert_fail ("maskRegionBuilder && \"builder callback for 'maskRegion' must be present\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5273, __extension__ __PRETTY_FUNCTION__)) | |||
5273 | "builder callback for 'maskRegion' must be present")(static_cast <bool> (maskRegionBuilder && "builder callback for 'maskRegion' must be present" ) ? void (0) : __assert_fail ("maskRegionBuilder && \"builder callback for 'maskRegion' must be present\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5273, __extension__ __PRETTY_FUNCTION__)); | |||
5274 | ||||
5275 | result.addOperands(mask); | |||
5276 | OpBuilder::InsertionGuard guard(builder); | |||
5277 | Region *maskRegion = result.addRegion(); | |||
5278 | builder.createBlock(maskRegion); | |||
5279 | maskRegionBuilder(builder, result.location); | |||
5280 | } | |||
5281 | ||||
5282 | void MaskOp::build( | |||
5283 | OpBuilder &builder, OperationState &result, Type resultType, Value mask, | |||
5284 | function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { | |||
5285 | build(builder, result, resultType, mask, /*passthru=*/Value(), | |||
5286 | maskRegionBuilder); | |||
5287 | } | |||
5288 | ||||
5289 | void MaskOp::build( | |||
5290 | OpBuilder &builder, OperationState &result, Type resultType, Value mask, | |||
5291 | Value passthru, | |||
5292 | function_ref<void(OpBuilder &, Location)> maskRegionBuilder) { | |||
5293 | build(builder, result, mask, maskRegionBuilder); | |||
5294 | if (passthru) | |||
5295 | result.addOperands(passthru); | |||
5296 | result.addTypes(resultType); | |||
5297 | } | |||
5298 | ||||
5299 | ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) { | |||
5300 | // Create the op region. | |||
5301 | result.regions.reserve(1); | |||
5302 | Region &maskRegion = *result.addRegion(); | |||
5303 | ||||
5304 | auto &builder = parser.getBuilder(); | |||
5305 | ||||
5306 | // Parse all the operands. | |||
5307 | OpAsmParser::UnresolvedOperand mask; | |||
5308 | if (parser.parseOperand(mask)) | |||
5309 | return failure(); | |||
5310 | ||||
5311 | // Optional passthru operand. | |||
5312 | OpAsmParser::UnresolvedOperand passthru; | |||
5313 | ParseResult parsePassthru = parser.parseOptionalComma(); | |||
5314 | if (parsePassthru.succeeded() && parser.parseOperand(passthru)) | |||
5315 | return failure(); | |||
5316 | ||||
5317 | // Parse op region. | |||
5318 | if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{})) | |||
5319 | return failure(); | |||
5320 | ||||
5321 | MaskOp::ensureTerminator(maskRegion, builder, result.location); | |||
5322 | ||||
5323 | // Parse the optional attribute list. | |||
5324 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
5325 | return failure(); | |||
5326 | ||||
5327 | // Parse all the types. | |||
5328 | Type maskType; | |||
5329 | if (parser.parseColonType(maskType)) | |||
5330 | return failure(); | |||
5331 | ||||
5332 | SmallVector<Type> resultTypes; | |||
5333 | if (parser.parseOptionalArrowTypeList(resultTypes)) | |||
5334 | return failure(); | |||
5335 | result.types.append(resultTypes); | |||
5336 | ||||
5337 | // Resolve operands. | |||
5338 | if (parser.resolveOperand(mask, maskType, result.operands)) | |||
5339 | return failure(); | |||
5340 | ||||
5341 | if (parsePassthru.succeeded()) | |||
5342 | if (parser.resolveOperand(passthru, resultTypes[0], result.operands)) | |||
5343 | return failure(); | |||
5344 | ||||
5345 | return success(); | |||
5346 | } | |||
5347 | ||||
5348 | void mlir::vector::MaskOp::print(OpAsmPrinter &p) { | |||
5349 | p << " " << getMask(); | |||
5350 | if (getPassthru()) | |||
5351 | p << ", " << getPassthru(); | |||
5352 | ||||
5353 | // Print single masked operation and skip terminator. | |||
5354 | p << " { "; | |||
5355 | Block *singleBlock = &getMaskRegion().getBlocks().front(); | |||
5356 | if (singleBlock && singleBlock->getOperations().size() > 1) | |||
5357 | p.printCustomOrGenericOp(&singleBlock->front()); | |||
5358 | p << " }"; | |||
5359 | ||||
5360 | p.printOptionalAttrDict(getOperation()->getAttrs()); | |||
5361 | ||||
5362 | p << " : " << getMask().getType(); | |||
5363 | if (getNumResults() > 0) | |||
5364 | p << " -> " << getResultTypes(); | |||
5365 | } | |||
5366 | ||||
/// Ensures the mask region ends in a vector.yield that forwards the masked
/// operation's results. Falls back to the default (empty) yield when the
/// block does not hold exactly one masked op, which the verifier will reject.
void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
  // First let the trait implementation attach a default empty yield if the
  // block has no terminator yet.
  OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
      MaskOp>::ensureTerminator(region, builder, loc);
  // Keep the default yield terminator if the number of masked operations is not
  // the expected. This case will trigger a verification failure.
  if (region.front().getOperations().size() != 2)
    return;

  // Replace default yield terminator with a new one that returns the results
  // from the masked operation.
  OpBuilder opBuilder(builder.getContext());
  Operation *maskedOp = &region.front().front();
  Operation *oldYieldOp = &region.front().back();
  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp")(static_cast <bool> (isa<vector::YieldOp>(oldYieldOp ) && "Expected vector::YieldOp") ? void (0) : __assert_fail ("isa<vector::YieldOp>(oldYieldOp) && \"Expected vector::YieldOp\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5380, __extension__ __PRETTY_FUNCTION__));

  // Insert the forwarding yield right before the old one, then erase the old
  // terminator (dropping its uses first so erasure is safe).
  opBuilder.setInsertionPoint(oldYieldOp);
  opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
  oldYieldOp->dropAllReferences();
  oldYieldOp->erase();
}
5387 | ||||
5388 | LogicalResult MaskOp::verify() { | |||
5389 | // Structural checks. | |||
5390 | Block &block = getMaskRegion().getBlocks().front(); | |||
5391 | if (block.getOperations().size() < 2) | |||
5392 | return emitOpError("expects an operation to mask"); | |||
5393 | if (block.getOperations().size() > 2) | |||
5394 | return emitOpError("expects only one operation to mask"); | |||
5395 | ||||
5396 | auto maskableOp = dyn_cast<MaskableOpInterface>(block.front()); | |||
5397 | if (!maskableOp) | |||
5398 | return emitOpError("expects a maskable operation"); | |||
5399 | ||||
5400 | // Result checks. | |||
5401 | if (maskableOp->getNumResults() != getNumResults()) | |||
5402 | return emitOpError("expects number of results to match maskable operation " | |||
5403 | "number of results"); | |||
5404 | ||||
5405 | if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes())) | |||
5406 | return emitOpError( | |||
5407 | "expects result type to match maskable operation result type"); | |||
5408 | ||||
5409 | // Mask checks. | |||
5410 | Type expectedMaskType = maskableOp.getExpectedMaskType(); | |||
5411 | if (getMask().getType() != expectedMaskType) | |||
5412 | return emitOpError("expects a ") | |||
5413 | << expectedMaskType << " mask for the maskable operation"; | |||
5414 | ||||
5415 | // Passthru checks. | |||
5416 | Value passthru = getPassthru(); | |||
5417 | if (passthru) { | |||
5418 | if (!maskableOp.supportsPassthru()) | |||
5419 | return emitOpError( | |||
5420 | "doesn't expect a passthru argument for this maskable operation"); | |||
5421 | ||||
5422 | if (maskableOp->getNumResults() != 1) | |||
5423 | return emitOpError("expects result when passthru argument is provided"); | |||
5424 | ||||
5425 | if (passthru.getType() != maskableOp->getResultTypes()[0]) | |||
5426 | return emitOpError("expects passthru type to match result type"); | |||
5427 | } | |||
5428 | ||||
5429 | return success(); | |||
5430 | } | |||
5431 | ||||
5432 | // MaskingOpInterface definitions. | |||
5433 | ||||
5434 | /// Returns the operation masked by this 'vector.mask'. | |||
5435 | Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); } | |||
5436 | ||||
5437 | /// Returns true if 'vector.mask' has a passthru value. | |||
5438 | bool MaskOp::hasPassthru() { return getPassthru() != Value(); } | |||
5439 | ||||
5440 | //===----------------------------------------------------------------------===// | |||
5441 | // ScanOp | |||
5442 | //===----------------------------------------------------------------------===// | |||
5443 | ||||
5444 | LogicalResult ScanOp::verify() { | |||
5445 | VectorType srcType = getSourceType(); | |||
5446 | VectorType initialType = getInitialValueType(); | |||
5447 | // Check reduction dimension < rank. | |||
5448 | int64_t srcRank = srcType.getRank(); | |||
5449 | int64_t reductionDim = getReductionDim(); | |||
5450 | if (reductionDim >= srcRank) | |||
5451 | return emitOpError("reduction dimension ") | |||
5452 | << reductionDim << " has to be less than " << srcRank; | |||
5453 | ||||
5454 | // Check that rank(initial_value) = rank(src) - 1. | |||
5455 | int64_t initialValueRank = initialType.getRank(); | |||
5456 | if (initialValueRank != srcRank - 1) | |||
5457 | return emitOpError("initial value rank ") | |||
5458 | << initialValueRank << " has to be equal to " << srcRank - 1; | |||
5459 | ||||
5460 | // Check shapes of initial value and src. | |||
5461 | ArrayRef<int64_t> srcShape = srcType.getShape(); | |||
5462 | ArrayRef<int64_t> initialValueShapes = initialType.getShape(); | |||
5463 | SmallVector<int64_t> expectedShape; | |||
5464 | for (int i = 0; i < srcRank; i++) { | |||
5465 | if (i != reductionDim) | |||
5466 | expectedShape.push_back(srcShape[i]); | |||
5467 | } | |||
5468 | if (!llvm::equal(initialValueShapes, expectedShape)) { | |||
5469 | return emitOpError("incompatible input/initial value shapes"); | |||
5470 | } | |||
5471 | ||||
5472 | // Verify supported reduction kind. | |||
5473 | Type eltType = getDestType().getElementType(); | |||
5474 | if (!isSupportedCombiningKind(getKind(), eltType)) | |||
5475 | return emitOpError("unsupported reduction type ") | |||
5476 | << eltType << " for kind '" << stringifyCombiningKind(getKind()) | |||
5477 | << "'"; | |||
5478 | ||||
5479 | return success(); | |||
5480 | } | |||
5481 | ||||
/// Registers the vector-to-vector folding patterns that rewrite mask-, load-,
/// store-, and transpose-related ops with constant inputs into simpler forms.
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
  patterns
      .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
           ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
           StridedSliceConstantMaskFolder, TransposeFolder>(
          patterns.getContext(), benefit);
}
5490 | ||||
5491 | //===----------------------------------------------------------------------===// | |||
5492 | // SplatOp | |||
5493 | //===----------------------------------------------------------------------===// | |||
5494 | ||||
5495 | OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) { | |||
5496 | auto constOperand = operands.front(); | |||
5497 | if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>()) | |||
5498 | return {}; | |||
5499 | ||||
5500 | // SplatElementsAttr::get treats single value for second arg as being a splat. | |||
5501 | return SplatElementsAttr::get(getType(), {constOperand}); | |||
5502 | } | |||
5503 | ||||
5504 | //===----------------------------------------------------------------------===// | |||
5505 | // WarpExecuteOnLane0Op | |||
5506 | //===----------------------------------------------------------------------===// | |||
5507 | ||||
5508 | void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) { | |||
5509 | p << "(" << getLaneid() << ")"; | |||
5510 | ||||
5511 | SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()}; | |||
5512 | auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName()); | |||
5513 | p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]"; | |||
5514 | ||||
5515 | if (!getArgs().empty()) | |||
5516 | p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")"; | |||
5517 | if (!getResults().empty()) | |||
5518 | p << " -> (" << getResults().getTypes() << ')'; | |||
5519 | p << " "; | |||
5520 | p.printRegion(getRegion(), | |||
5521 | /*printEntryBlockArgs=*/true, | |||
5522 | /*printBlockTerminators=*/!getResults().empty()); | |||
5523 | p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr); | |||
5524 | } | |||
5525 | ||||
5526 | ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser, | |||
5527 | OperationState &result) { | |||
5528 | // Create the region. | |||
5529 | result.regions.reserve(1); | |||
5530 | Region *warpRegion = result.addRegion(); | |||
5531 | ||||
5532 | auto &builder = parser.getBuilder(); | |||
5533 | OpAsmParser::UnresolvedOperand laneId; | |||
5534 | ||||
5535 | // Parse predicate operand. | |||
5536 | if (parser.parseLParen() || | |||
| ||||
5537 | parser.parseOperand(laneId, /*allowResultNumber=*/false) || | |||
5538 | parser.parseRParen()) | |||
5539 | return failure(); | |||
5540 | ||||
5541 | int64_t warpSize; | |||
5542 | if (parser.parseLSquare() || parser.parseInteger(warpSize) || | |||
5543 | parser.parseRSquare()) | |||
5544 | return failure(); | |||
5545 | result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(), | |||
5546 | builder.getContext())), | |||
5547 | builder.getI64IntegerAttr(warpSize)); | |||
| ||||
5548 | ||||
5549 | if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands)) | |||
5550 | return failure(); | |||
5551 | ||||
5552 | llvm::SMLoc inputsOperandsLoc; | |||
5553 | SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands; | |||
5554 | SmallVector<Type> inputTypes; | |||
5555 | if (succeeded(parser.parseOptionalKeyword("args"))) { | |||
5556 | if (parser.parseLParen()) | |||
5557 | return failure(); | |||
5558 | ||||
5559 | inputsOperandsLoc = parser.getCurrentLocation(); | |||
5560 | if (parser.parseOperandList(inputsOperands) || | |||
5561 | parser.parseColonTypeList(inputTypes) || parser.parseRParen()) | |||
5562 | return failure(); | |||
5563 | } | |||
5564 | if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc, | |||
5565 | result.operands)) | |||
5566 | return failure(); | |||
5567 | ||||
5568 | // Parse optional results type list. | |||
5569 | if (parser.parseOptionalArrowTypeList(result.types)) | |||
5570 | return failure(); | |||
5571 | // Parse the region. | |||
5572 | if (parser.parseRegion(*warpRegion, /*arguments=*/{}, | |||
5573 | /*argTypes=*/{})) | |||
5574 | return failure(); | |||
5575 | WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location); | |||
5576 | ||||
5577 | // Parse the optional attribute list. | |||
5578 | if (parser.parseOptionalAttrDict(result.attributes)) | |||
5579 | return failure(); | |||
5580 | return success(); | |||
5581 | } | |||
5582 | ||||
5583 | void WarpExecuteOnLane0Op::getSuccessorRegions( | |||
5584 | std::optional<unsigned> index, ArrayRef<Attribute> operands, | |||
5585 | SmallVectorImpl<RegionSuccessor> ®ions) { | |||
5586 | if (index) { | |||
5587 | regions.push_back(RegionSuccessor(getResults())); | |||
5588 | return; | |||
5589 | } | |||
5590 | ||||
5591 | // The warp region is always executed | |||
5592 | regions.push_back(RegionSuccessor(&getWarpRegion())); | |||
5593 | } | |||
5594 | ||||
/// Builds a warp_execute_on_lane_0 op with no forwarded arguments by
/// delegating to the general builder with empty operand/arg-type lists.
void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
                                 TypeRange resultTypes, Value laneId,
                                 int64_t warpSize) {
  build(builder, result, resultTypes, laneId, warpSize,
        /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
}
5601 | ||||
5602 | void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result, | |||
5603 | TypeRange resultTypes, Value laneId, | |||
5604 | int64_t warpSize, ValueRange args, | |||
5605 | TypeRange blockArgTypes) { | |||
5606 | result.addOperands(laneId); | |||
5607 | result.addAttribute(getAttributeNames()[0], | |||
5608 | builder.getI64IntegerAttr(warpSize)); | |||
5609 | result.addTypes(resultTypes); | |||
5610 | result.addOperands(args); | |||
5611 | assert(args.size() == blockArgTypes.size())(static_cast <bool> (args.size() == blockArgTypes.size( )) ? void (0) : __assert_fail ("args.size() == blockArgTypes.size()" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5611, __extension__ __PRETTY_FUNCTION__)); | |||
5612 | OpBuilder::InsertionGuard guard(builder); | |||
5613 | Region *warpRegion = result.addRegion(); | |||
5614 | Block *block = builder.createBlock(warpRegion); | |||
5615 | for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args)) | |||
5616 | block->addArgument(type, arg.getLoc()); | |||
5617 | } | |||
5618 | ||||
5619 | /// Helper check if the distributed vector type is consistent with the expanded | |||
5620 | /// type and distributed size. | |||
5621 | static LogicalResult verifyDistributedType(Type expanded, Type distributed, | |||
5622 | int64_t warpSize, Operation *op) { | |||
5623 | // If the types matches there is no distribution. | |||
5624 | if (expanded == distributed) | |||
5625 | return success(); | |||
5626 | auto expandedVecType = expanded.dyn_cast<VectorType>(); | |||
5627 | auto distributedVecType = distributed.dyn_cast<VectorType>(); | |||
5628 | if (!expandedVecType || !distributedVecType) | |||
5629 | return op->emitOpError("expected vector type for distributed operands."); | |||
5630 | if (expandedVecType.getRank() != distributedVecType.getRank() || | |||
5631 | expandedVecType.getElementType() != distributedVecType.getElementType()) | |||
5632 | return op->emitOpError( | |||
5633 | "expected distributed vectors to have same rank and element type."); | |||
5634 | bool foundDistributedDim = false; | |||
5635 | for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) { | |||
5636 | if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i)) | |||
5637 | continue; | |||
5638 | if (expandedVecType.getDimSize(i) == | |||
5639 | distributedVecType.getDimSize(i) * warpSize) { | |||
5640 | if (foundDistributedDim) | |||
5641 | return op->emitOpError() | |||
5642 | << "expected only one dimension to be distributed from " | |||
5643 | << expandedVecType << " to " << distributedVecType; | |||
5644 | foundDistributedDim = true; | |||
5645 | continue; | |||
5646 | } | |||
5647 | return op->emitOpError() << "incompatible distribution dimensions from " | |||
5648 | << expandedVecType << " to " << distributedVecType; | |||
5649 | } | |||
5650 | return success(); | |||
5651 | } | |||
5652 | ||||
LogicalResult WarpExecuteOnLane0Op::verify() {
  // Captured operands must map 1:1 onto region block arguments.
  if (getArgs().size() != getWarpRegion().getNumArguments())
    return emitOpError(
        "expected same number op arguments and block arguments.");
  // The terminator of the region's first block is expected to be a
  // vector.yield (presumably enforced by the op's region constraints — the
  // cast would assert otherwise). Its operands must map 1:1 onto results.
  auto yield =
      cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
  if (yield.getNumOperands() != getNumResults())
    return emitOpError(
        "expected same number of yield operands and return values.");
  int64_t warpSize = getWarpSize();
  // Check each (expanded block arg type, distributed operand type) pair.
  for (auto [regionArg, arg] :
       llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
    if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  // Likewise for each (expanded yielded type, distributed result type) pair.
  for (auto [yieldOperand, result] :
       llvm::zip_equal(yield.getOperands(), getResults())) {
    if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
                                     warpSize, getOperation())))
      return failure();
  }
  return success();
}
5677 | ||||
/// Two types are compatible if `rhs` is a valid warp-distributed form of
/// `lhs` (including the identical-types case) for this op's warp size.
bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
  return succeeded(
      verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
}
5682 | ||||
5683 | Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, | |||
5684 | CombiningKind kind, Value v1, Value v2) { | |||
5685 | Type t1 = getElementTypeOrSelf(v1.getType()); | |||
5686 | Type t2 = getElementTypeOrSelf(v2.getType()); | |||
5687 | switch (kind) { | |||
5688 | case CombiningKind::ADD: | |||
5689 | if (t1.isIntOrIndex() && t2.isIntOrIndex()) | |||
5690 | return b.createOrFold<arith::AddIOp>(loc, v1, v2); | |||
5691 | else if (t1.isa<FloatType>() && t2.isa<FloatType>()) | |||
5692 | return b.createOrFold<arith::AddFOp>(loc, v1, v2); | |||
5693 | llvm_unreachable("invalid value types for ADD reduction")::llvm::llvm_unreachable_internal("invalid value types for ADD reduction" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5693); | |||
5694 | case CombiningKind::AND: | |||
5695 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5695, __extension__ __PRETTY_FUNCTION__)); | |||
5696 | return b.createOrFold<arith::AndIOp>(loc, v1, v2); | |||
5697 | case CombiningKind::MAXF: | |||
5698 | assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&(static_cast <bool> (t1.isa<FloatType>() && t2.isa<FloatType>() && "expected float values" ) ? void (0) : __assert_fail ("t1.isa<FloatType>() && t2.isa<FloatType>() && \"expected float values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5699, __extension__ __PRETTY_FUNCTION__)) | |||
5699 | "expected float values")(static_cast <bool> (t1.isa<FloatType>() && t2.isa<FloatType>() && "expected float values" ) ? void (0) : __assert_fail ("t1.isa<FloatType>() && t2.isa<FloatType>() && \"expected float values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5699, __extension__ __PRETTY_FUNCTION__)); | |||
5700 | return b.createOrFold<arith::MaxFOp>(loc, v1, v2); | |||
5701 | case CombiningKind::MINF: | |||
5702 | assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&(static_cast <bool> (t1.isa<FloatType>() && t2.isa<FloatType>() && "expected float values" ) ? void (0) : __assert_fail ("t1.isa<FloatType>() && t2.isa<FloatType>() && \"expected float values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5703, __extension__ __PRETTY_FUNCTION__)) | |||
5703 | "expected float values")(static_cast <bool> (t1.isa<FloatType>() && t2.isa<FloatType>() && "expected float values" ) ? void (0) : __assert_fail ("t1.isa<FloatType>() && t2.isa<FloatType>() && \"expected float values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5703, __extension__ __PRETTY_FUNCTION__)); | |||
5704 | return b.createOrFold<arith::MinFOp>(loc, v1, v2); | |||
5705 | case CombiningKind::MAXSI: | |||
5706 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5706, __extension__ __PRETTY_FUNCTION__)); | |||
5707 | return b.createOrFold<arith::MaxSIOp>(loc, v1, v2); | |||
5708 | case CombiningKind::MINSI: | |||
5709 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5709, __extension__ __PRETTY_FUNCTION__)); | |||
5710 | return b.createOrFold<arith::MinSIOp>(loc, v1, v2); | |||
5711 | case CombiningKind::MAXUI: | |||
5712 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5712, __extension__ __PRETTY_FUNCTION__)); | |||
5713 | return b.createOrFold<arith::MaxUIOp>(loc, v1, v2); | |||
5714 | case CombiningKind::MINUI: | |||
5715 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5715, __extension__ __PRETTY_FUNCTION__)); | |||
5716 | return b.createOrFold<arith::MinUIOp>(loc, v1, v2); | |||
5717 | case CombiningKind::MUL: | |||
5718 | if (t1.isIntOrIndex() && t2.isIntOrIndex()) | |||
5719 | return b.createOrFold<arith::MulIOp>(loc, v1, v2); | |||
5720 | else if (t1.isa<FloatType>() && t2.isa<FloatType>()) | |||
5721 | return b.createOrFold<arith::MulFOp>(loc, v1, v2); | |||
5722 | llvm_unreachable("invalid value types for MUL reduction")::llvm::llvm_unreachable_internal("invalid value types for MUL reduction" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5722); | |||
5723 | case CombiningKind::OR: | |||
5724 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5724, __extension__ __PRETTY_FUNCTION__)); | |||
5725 | return b.createOrFold<arith::OrIOp>(loc, v1, v2); | |||
5726 | case CombiningKind::XOR: | |||
5727 | assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values")(static_cast <bool> (t1.isIntOrIndex() && t2.isIntOrIndex () && "expected int values") ? void (0) : __assert_fail ("t1.isIntOrIndex() && t2.isIntOrIndex() && \"expected int values\"" , "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5727, __extension__ __PRETTY_FUNCTION__)); | |||
5728 | return b.createOrFold<arith::XOrIOp>(loc, v1, v2); | |||
5729 | }; | |||
5730 | llvm_unreachable("unknown CombiningKind")::llvm::llvm_unreachable_internal("unknown CombiningKind", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp" , 5730); | |||
5731 | } | |||
5732 | ||||
5733 | //===----------------------------------------------------------------------===// | |||
5734 | // TableGen'd op method definitions | |||
5735 | //===----------------------------------------------------------------------===// | |||
5736 | ||||
5737 | #define GET_ATTRDEF_CLASSES | |||
5738 | #include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc" | |||
5739 | ||||
5740 | #define GET_OP_CLASSES | |||
5741 | #include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc" |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | class AsmParsedResourceEntry; |
24 | class AsmResourceBuilder; |
25 | class Builder; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // AsmDialectResourceHandle |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// This class represents an opaque handle to a dialect resource entry. |
32 | class AsmDialectResourceHandle { |
33 | public: |
34 | AsmDialectResourceHandle() = default; |
35 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
36 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
37 | bool operator==(const AsmDialectResourceHandle &other) const { |
38 | return resource == other.resource; |
39 | } |
40 | |
41 | /// Return an opaque pointer to the referenced resource. |
42 | void *getResource() const { return resource; } |
43 | |
44 | /// Return the type ID of the resource. |
45 | TypeID getTypeID() const { return opaqueID; } |
46 | |
47 | /// Return the dialect that owns the resource. |
48 | Dialect *getDialect() const { return dialect; } |
49 | |
50 | private: |
51 | /// The opaque handle to the dialect resource. |
52 | void *resource = nullptr; |
53 | /// The type of the resource referenced. |
54 | TypeID opaqueID; |
55 | /// The dialect owning the given resource. |
56 | Dialect *dialect; |
57 | }; |
58 | |
/// This class represents a CRTP base class for dialect resource handles. It
/// abstracts away various utilities necessary for defined derived resource
/// handles.
template <typename DerivedT, typename ResourceT, typename DialectT>
class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
public:
  using Dialect = DialectT;

  /// Construct a handle from a pointer to the resource. The given pointer
  /// should be guaranteed to live beyond the life of this handle.
  AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
      : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
  /// Rebuild a typed handle from an opaque one. The opaque handle must have
  /// been created for this derived type (checked in debug builds).
  AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
      : AsmDialectResourceHandle(handle) {
    assert(handle.getTypeID() == TypeID::get<DerivedT>());
  }

  /// Return the resource referenced by this handle.
  ResourceT *getResource() {
    return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
  }
  const ResourceT *getResource() const {
    // Delegate to the non-const overload; the cast only strips constness of
    // `this`, the returned pointer is re-qualified const.
    return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
  }

  /// Return the dialect that owns the resource.
  DialectT *getDialect() const {
    return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
  }

  /// Support llvm style casting.
  static bool classof(const AsmDialectResourceHandle *handle) {
    return handle->getTypeID() == TypeID::get<DerivedT>();
  }
};
94 | |
95 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
96 | return llvm::hash_value(param.getResource()); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // AsmPrinter |
101 | //===----------------------------------------------------------------------===// |
102 | |
/// This base class exposes generic asm printer hooks, usable across the various
/// derived printers.
class AsmPrinter {
public:
  /// This class contains the internal default implementation of the base
  /// printer methods.
  class Impl;

  /// Initialize the printer with the given internal implementation.
  AsmPrinter(Impl &impl) : impl(&impl) {}
  virtual ~AsmPrinter();

  /// Return the raw output stream used by this printer.
  virtual raw_ostream &getStream() const;

  /// Print the given floating point value in a stabilized form that can be
  /// roundtripped through the IR. This is the companion to the 'parseFloat'
  /// hook on the AsmParser.
  virtual void printFloat(const APFloat &value);

  /// Print the given type or attribute in its textual IR form.
  virtual void printType(Type type);
  virtual void printAttribute(Attribute attr);

  /// Trait to check if `AttrType` provides a `print` method.
  template <typename AttrOrType>
  using has_print_method =
      decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
  template <typename AttrOrType>
  using detect_has_print_method =
      llvm::is_detected<has_print_method, AttrOrType>;

  /// Print the provided attribute in the context of an operation custom
  /// printer/parser: this will invoke directly the print method on the
  /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    // Prefer printing a registered alias over the inline form when possible.
    if (succeeded(printAlias(attrOrType)))
      return;
    attrOrType.print(*this);
  }

  /// Print the provided array of attributes or types in the context of an
  /// operation custom printer/parser: this will invoke directly the print
  /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
  /// most cases.
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
    llvm::interleaveComma(
        attrOrTypes, getStream(),
        [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
  }

  /// SFINAE for printing the provided attribute in the context of an operation
  /// custom printer in the case where the attribute does not define a print
  /// method.
  template <typename AttrOrType,
            std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    *this << attrOrType;
  }

  /// Print the given attribute without its type. The corresponding parser must
  /// provide a valid type for the attribute.
  virtual void printAttributeWithoutType(Attribute attr);

  /// Print the given string as a keyword, or a quoted and escaped string if it
  /// has any special or non-printable characters in it.
  virtual void printKeywordOrString(StringRef keyword);

  /// Print the given string as a symbol reference, i.e. a form representable by
  /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
  /// with '@'. The reference is surrounded with ""'s and escaped if it has any
  /// special or non-printable characters in it.
  virtual void printSymbolName(StringRef symbolRef);

  /// Print a handle to the given dialect resource.
  virtual void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Print an optional arrow followed by a type list.
  template <typename TypeRange>
  void printOptionalArrowTypeList(TypeRange &&types) {
    if (types.begin() != types.end())
      printArrowTypeList(types);
  }
  template <typename TypeRange>
  void printArrowTypeList(TypeRange &&types) {
    auto &os = getStream() << " -> ";

    // Parenthesize unless there is exactly one non-function type, so the
    // printed form parses back unambiguously.
    bool wrapped = !llvm::hasSingleElement(types) ||
                   (*types.begin()).template isa<FunctionType>();
    if (wrapped)
      os << '(';
    llvm::interleaveComma(types, *this);
    if (wrapped)
      os << ')';
  }

  /// Print the two given type ranges in a functional form.
  template <typename InputRangeT, typename ResultRangeT>
  void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
    auto &os = getStream();
    os << '(';
    llvm::interleaveComma(inputs, *this);
    os << ')';
    printArrowTypeList(results);
  }

protected:
  /// Initialize the printer with no internal implementation. In this case, all
  /// virtual methods of this class must be overriden.
  AsmPrinter() = default;

private:
  // Printers are neither copyable nor assignable.
  AsmPrinter(const AsmPrinter &) = delete;
  void operator=(const AsmPrinter &) = delete;

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Attribute attr);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Type type);

  /// The internal implementation of the printer.
  Impl *impl{nullptr};
};
235 | |
/// Stream a Type through any printer derived from AsmPrinter.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, Type type) {
  p.printType(type);
  return p;
}
243 | |
/// Stream an Attribute through any printer derived from AsmPrinter.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, Attribute attr) {
  p.printAttribute(attr);
  return p;
}
251 | |
/// Stream a floating point value using the printer's roundtrippable form.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const APFloat &value) {
  p.printFloat(value);
  return p;
}
/// Builtin float/double are routed through APFloat so they also print in a
/// roundtrippable, stabilized form.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, float value) {
  return p << APFloat(value);
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, double value) {
  return p << APFloat(value);
}
271 | |
// Support printing anything that isn't convertible to one of the other
// streamable types, even if it isn't exactly one of them. For example, we want
// to print FunctionType with the Type version above, not have it match this.
template <typename AsmPrinterT, typename T,
          std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
                               !std::is_convertible<T &, Type &>::value &&
                               !std::is_convertible<T &, Attribute &>::value &&
                               !std::is_convertible<T &, ValueRange>::value &&
                               !std::is_convertible<T &, APFloat &>::value &&
                               !llvm::is_one_of<T, bool, float, double>::value,
                           T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const T &other) {
  // Fall through to the printer's raw output stream.
  p.getStream() << other;
  return p;
}

/// Booleans print as the `true`/`false` keywords.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, bool value) {
  return p << (value ? StringRef("true") : "false");
}
296 | |
/// Stream the various type-range flavors as comma separated lists.
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
  llvm::interleaveComma(types, p);
  return p;
}
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const TypeRange &types) {
  llvm::interleaveComma(types, p);
  return p;
}
template <typename AsmPrinterT, typename ElementT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
  llvm::interleaveComma(types, p);
  return p;
}
318 | |
319 | //===----------------------------------------------------------------------===// |
320 | // OpAsmPrinter |
321 | //===----------------------------------------------------------------------===// |
322 | |
/// This is a pure-virtual base class that exposes the asmprinter hooks
/// necessary to implement a custom print() method.
class OpAsmPrinter : public AsmPrinter {
public:
  using AsmPrinter::AsmPrinter;
  ~OpAsmPrinter() override;

  /// Print a loc(...) specifier if printing debug info is enabled.
  virtual void printOptionalLocationSpecifier(Location loc) = 0;

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  virtual void printNewline() = 0;

  /// Increase indentation.
  virtual void increaseIndent() = 0;

  /// Decrease indentation.
  virtual void decreaseIndent() = 0;

  /// Print a block argument in the usual format of:
  /// %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  virtual void printRegionArgument(BlockArgument arg,
                                   ArrayRef<NamedAttribute> argAttrs = {},
                                   bool omitType = false) = 0;

  /// Print implementations for various things an operation contains.
  virtual void printOperand(Value value) = 0;
  virtual void printOperand(Value value, raw_ostream &os) = 0;

  /// Print a comma separated list of operands.
  template <typename ContainerType>
  void printOperands(const ContainerType &container) {
    printOperands(container.begin(), container.end());
  }

  /// Print a comma separated list of operands.
  template <typename IteratorType>
  void printOperands(IteratorType it, IteratorType end) {
    llvm::interleaveComma(llvm::make_range(it, end), getStream(),
                          [this](Value value) { printOperand(value); });
  }

  /// Print the given successor.
  virtual void printSuccessor(Block *successor) = 0;

  /// Print the successor and its operands.
  virtual void printSuccessorAndUseList(Block *successor,
                                        ValueRange succOperands) = 0;

  /// If the specified operation has attributes, print out an attribute
  /// dictionary with their values. elidedAttrs allows the client to ignore
  /// specific well known attributes, commonly used if the attribute value is
  /// printed some other way (like as a fixed operand).
  virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;

  /// If the specified operation has attributes, print out an attribute
  /// dictionary prefixed with 'attributes'.
  virtual void
  printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                   ArrayRef<StringRef> elidedAttrs = {}) = 0;

  /// Prints the entire operation with the custom assembly form, if available,
  /// or the generic assembly form, otherwise.
  virtual void printCustomOrGenericOp(Operation *op) = 0;

  /// Print the entire operation with the default generic assembly form.
  /// If `printOpName` is true, then the operation name is printed (the default)
  /// otherwise it is omitted and the print will start with the operand list.
  virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;

  /// Prints a region.
  /// If 'printEntryBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed. If printEmptyBlock is true, then
  /// the block header is printed even if the block is empty.
  virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
                           bool printBlockTerminators = true,
                           bool printEmptyBlock = false) = 0;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;

  /// Prints an affine map of SSA ids, where SSA id names are used in place
  /// of dims/symbols.
  /// Operand values must come from single-result sources, and be valid
  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
  virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                                      ValueRange operands) = 0;

  /// Prints an affine expression of SSA ids with SSA id names used instead of
  /// dims and symbols.
  /// Operand values must come from single-result sources, and be valid
  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
  virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                                       ValueRange symOperands) = 0;

  /// Print the complete type of an operation in functional form.
  void printFunctionalType(Operation *op);
  // Keep the range-based overloads from AsmPrinter visible alongside the
  // Operation* overload above.
  using AsmPrinter::printFunctionalType;
};
431 | |
// Make the implementations convenient to use.
/// Stream a single SSA operand.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
  p.printOperand(value);
  return p;
}

/// Stream any operand range; excluded for a single Value so the overload
/// above is selected instead.
template <typename T,
          std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
                               !std::is_convertible<T &, Value &>::value,
                           T> * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
  p.printOperands(values);
  return p;
}

/// Stream a block successor.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
  p.printSuccessor(value);
  return p;
}
451 | |
452 | //===----------------------------------------------------------------------===// |
453 | // AsmParser |
454 | //===----------------------------------------------------------------------===// |
455 | |
456 | /// This base class exposes generic asm parser hooks, usable across the various |
457 | /// derived parsers. |
458 | class AsmParser { |
459 | public: |
460 | AsmParser() = default; |
461 | virtual ~AsmParser(); |
462 | |
463 | MLIRContext *getContext() const; |
464 | |
465 | /// Return the location of the original name token. |
466 | virtual SMLoc getNameLoc() const = 0; |
467 | |
468 | //===--------------------------------------------------------------------===// |
469 | // Utilities |
470 | //===--------------------------------------------------------------------===// |
471 | |
472 | /// Emit a diagnostic at the specified location and return failure. |
473 | virtual InFlightDiagnostic emitError(SMLoc loc, |
474 | const Twine &message = {}) = 0; |
475 | |
476 | /// Return a builder which provides useful access to MLIRContext, global |
477 | /// objects like types and attributes. |
478 | virtual Builder &getBuilder() const = 0; |
479 | |
480 | /// Get the location of the next token and store it into the argument. This |
481 | /// always succeeds. |
482 | virtual SMLoc getCurrentLocation() = 0; |
483 | ParseResult getCurrentLocation(SMLoc *loc) { |
484 | *loc = getCurrentLocation(); |
485 | return success(); |
486 | } |
487 | |
488 | /// Re-encode the given source location as an MLIR location and return it. |
489 | /// Note: This method should only be used when a `Location` is necessary, as |
490 | /// the encoding process is not efficient. |
491 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
492 | |
493 | //===--------------------------------------------------------------------===// |
494 | // Token Parsing |
495 | //===--------------------------------------------------------------------===// |
496 | |
497 | /// Parse a '->' token. |
498 | virtual ParseResult parseArrow() = 0; |
499 | |
500 | /// Parse a '->' token if present |
501 | virtual ParseResult parseOptionalArrow() = 0; |
502 | |
503 | /// Parse a `{` token. |
504 | virtual ParseResult parseLBrace() = 0; |
505 | |
506 | /// Parse a `{` token if present. |
507 | virtual ParseResult parseOptionalLBrace() = 0; |
508 | |
509 | /// Parse a `}` token. |
510 | virtual ParseResult parseRBrace() = 0; |
511 | |
512 | /// Parse a `}` token if present. |
513 | virtual ParseResult parseOptionalRBrace() = 0; |
514 | |
515 | /// Parse a `:` token. |
516 | virtual ParseResult parseColon() = 0; |
517 | |
518 | /// Parse a `:` token if present. |
519 | virtual ParseResult parseOptionalColon() = 0; |
520 | |
521 | /// Parse a `,` token. |
522 | virtual ParseResult parseComma() = 0; |
523 | |
524 | /// Parse a `,` token if present. |
525 | virtual ParseResult parseOptionalComma() = 0; |
526 | |
527 | /// Parse a `=` token. |
528 | virtual ParseResult parseEqual() = 0; |
529 | |
530 | /// Parse a `=` token if present. |
531 | virtual ParseResult parseOptionalEqual() = 0; |
532 | |
533 | /// Parse a '<' token. |
534 | virtual ParseResult parseLess() = 0; |
535 | |
536 | /// Parse a '<' token if present. |
537 | virtual ParseResult parseOptionalLess() = 0; |
538 | |
539 | /// Parse a '>' token. |
540 | virtual ParseResult parseGreater() = 0; |
541 | |
542 | /// Parse a '>' token if present. |
543 | virtual ParseResult parseOptionalGreater() = 0; |
544 | |
545 | /// Parse a '?' token. |
546 | virtual ParseResult parseQuestion() = 0; |
547 | |
548 | /// Parse a '?' token if present. |
549 | virtual ParseResult parseOptionalQuestion() = 0; |
550 | |
551 | /// Parse a '+' token. |
552 | virtual ParseResult parsePlus() = 0; |
553 | |
554 | /// Parse a '+' token if present. |
555 | virtual ParseResult parseOptionalPlus() = 0; |
556 | |
557 | /// Parse a '*' token. |
558 | virtual ParseResult parseStar() = 0; |
559 | |
560 | /// Parse a '*' token if present. |
561 | virtual ParseResult parseOptionalStar() = 0; |
562 | |
563 | /// Parse a '|' token. |
564 | virtual ParseResult parseVerticalBar() = 0; |
565 | |
566 | /// Parse a '|' token if present. |
567 | virtual ParseResult parseOptionalVerticalBar() = 0; |
568 | |
569 | /// Parse a quoted string token. |
570 | ParseResult parseString(std::string *string) { |
571 | auto loc = getCurrentLocation(); |
572 | if (parseOptionalString(string)) |
573 | return emitError(loc, "expected string"); |
574 | return success(); |
575 | } |
576 | |
577 | /// Parse a quoted string token if present. |
578 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
579 | |
580 | /// Parses a Base64 encoded string of bytes. |
581 | virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; |
582 | |
583 | /// Parse a `(` token. |
584 | virtual ParseResult parseLParen() = 0; |
585 | |
586 | /// Parse a `(` token if present. |
587 | virtual ParseResult parseOptionalLParen() = 0; |
588 | |
589 | /// Parse a `)` token. |
590 | virtual ParseResult parseRParen() = 0; |
591 | |
592 | /// Parse a `)` token if present. |
593 | virtual ParseResult parseOptionalRParen() = 0; |
594 | |
595 | /// Parse a `[` token. |
596 | virtual ParseResult parseLSquare() = 0; |
597 | |
598 | /// Parse a `[` token if present. |
599 | virtual ParseResult parseOptionalLSquare() = 0; |
600 | |
601 | /// Parse a `]` token. |
602 | virtual ParseResult parseRSquare() = 0; |
603 | |
604 | /// Parse a `]` token if present. |
605 | virtual ParseResult parseOptionalRSquare() = 0; |
606 | |
607 | /// Parse a `...` token. |
608 | virtual ParseResult parseEllipsis() = 0; |
609 | |
/// Parse a `...` token if present.
611 | virtual ParseResult parseOptionalEllipsis() = 0; |
612 | |
613 | /// Parse a floating point value from the stream. |
614 | virtual ParseResult parseFloat(double &result) = 0; |
615 | |
616 | /// Parse an integer value from the stream. |
617 | template <typename IntT> |
618 | ParseResult parseInteger(IntT &result) { |
619 | auto loc = getCurrentLocation(); |
620 | OptionalParseResult parseResult = parseOptionalInteger(result); |
621 | if (!parseResult.has_value()) |
622 | return emitError(loc, "expected integer value"); |
623 | return *parseResult; |
624 | } |
625 | |
626 | /// Parse an optional integer value from the stream. |
627 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
628 | |
629 | template <typename IntT> |
630 | OptionalParseResult parseOptionalInteger(IntT &result) { |
631 | auto loc = getCurrentLocation(); |
632 | |
633 | // Parse the unsigned variant. |
634 | APInt uintResult; |
635 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
636 | if (!parseResult.has_value() || failed(*parseResult)) |
637 | return parseResult; |
638 | |
639 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
640 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
641 | // zero for non-negated integers. |
642 | result = |
643 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); |
644 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
645 | return emitError(loc, "integer value too large"); |
646 | return success(); |
647 | } |
648 | |
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
enum class Delimiter {
  /// Zero or more operands with no delimiters.
  None,
  /// Parens surrounding zero or more operands.
  Paren,
  /// Square brackets surrounding zero or more operands.
  Square,
  /// <> brackets surrounding zero or more operands.
  LessGreater,
  /// {} brackets surrounding zero or more operands.
  Braces,
  /// Parens supporting zero or more operands, or nothing.
  OptionalParen,
  /// Square brackets supporting zero or more operands, or nothing.
  OptionalSquare,
  /// <> brackets supporting zero or more operands, or nothing.
  OptionalLessGreater,
  /// {} brackets surrounding zero or more operands, or nothing.
  OptionalBraces,
};
671 | |
672 | /// Parse a list of comma-separated items with an optional delimiter. If a |
673 | /// delimiter is provided, then an empty list is allowed. If not, then at |
674 | /// least one element will be parsed. |
675 | /// |
676 | /// contextMessage is an optional message appended to "expected '('" sorts of |
677 | /// diagnostics when parsing the delimeters. |
678 | virtual ParseResult |
679 | parseCommaSeparatedList(Delimiter delimiter, |
680 | function_ref<ParseResult()> parseElementFn, |
681 | StringRef contextMessage = StringRef()) = 0; |
682 | |
/// Parse a comma separated list of elements that must have at least one entry
/// in it.
ParseResult
parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
  // Delimiter::None allows no surrounding tokens but requires at least
  // one element to be parsed.
  return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
689 | |
690 | //===--------------------------------------------------------------------===// |
691 | // Keyword Parsing |
692 | //===--------------------------------------------------------------------===// |
693 | |
/// This class represents a StringSwitch like class that is useful for parsing
/// expected keywords. On construction, it invokes `parseKeyword` and
/// processes each of the provided cases statements until a match is hit. The
/// provided `ResultT` must be assignable from `failure()`.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
  // Eagerly parses the keyword (or a code-completion placeholder) so the
  // subsequent Case/Default chain only has to compare strings.
  KeywordSwitch(AsmParser &parser)
      : parser(parser), loc(parser.getCurrentLocation()) {
    if (failed(parser.parseKeywordOrCompletion(&keyword)))
      result = failure();
  }

  /// Case that uses the provided value when true.
  KeywordSwitch &Case(StringLiteral str, ResultT value) {
    return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Fallback that uses the provided value when no earlier case matched.
  KeywordSwitch &Default(ResultT value) {
    return Default([&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Case that invokes the provided functor when true. The parameters passed
  /// to the functor are the keyword, and the location of the keyword (in case
  /// any errors need to be emitted).
  // SFINAE-disabled when FnT converts to ResultT so the value overload above
  // is preferred for plain values.
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Case(StringLiteral str, FnT &&fn) {
    // A previous case already matched (or construction failed); skip.
    if (result)
      return *this;

    // If the word was empty, record this as a completion.
    if (keyword.empty())
      parser.codeCompleteExpectedTokens(str);
    else if (keyword == str)
      result.emplace(std::move(fn(keyword, loc)));
    return *this;
  }
  /// Fallback functor invoked when no case matched.
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Default(FnT &&fn) {
    if (!result)
      result.emplace(fn(keyword, loc));
    return *this;
  }

  /// Returns true if this switch has a value yet.
  bool hasValue() const { return result.has_value(); }

  /// Return the result of the switch.
  // Converting without a matched case emits "unexpected keyword" and
  // returns the resulting (failed) diagnostic as the ResultT.
  [[nodiscard]] operator ResultT() {
    if (!result)
      return parser.emitError(loc, "unexpected keyword: ") << keyword;
    return std::move(*result);
  }

private:
  /// The parser used to construct this switch.
  AsmParser &parser;

  /// The location of the keyword, used to emit errors as necessary.
  SMLoc loc;

  /// The parsed keyword itself.
  StringRef keyword;

  /// The result of the switch statement or none if currently unknown.
  Optional<ResultT> result;
};
761 | |
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword) {
  // Delegates with an empty extra-message suffix.
  return parseKeyword(keyword, "");
}
/// Parse a given keyword, appending `msg` to the diagnostic on failure.
virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
767 | |
768 | /// Parse a keyword into 'keyword'. |
769 | ParseResult parseKeyword(StringRef *keyword) { |
770 | auto loc = getCurrentLocation(); |
771 | if (parseOptionalKeyword(keyword)) |
772 | return emitError(loc, "expected valid keyword"); |
773 | return success(); |
774 | } |
775 | |
776 | /// Parse the given keyword if present. |
777 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
778 | |
779 | /// Parse a keyword, if present, into 'keyword'. |
780 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
781 | |
782 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
783 | /// into 'keyword' |
784 | virtual ParseResult |
785 | parseOptionalKeyword(StringRef *keyword, |
786 | ArrayRef<StringRef> allowedValues) = 0; |
787 | |
788 | /// Parse a keyword or a quoted string. |
789 | ParseResult parseKeywordOrString(std::string *result) { |
790 | if (failed(parseOptionalKeywordOrString(result))) |
791 | return emitError(getCurrentLocation()) |
792 | << "expected valid keyword or string"; |
793 | return success(); |
794 | } |
795 | |
796 | /// Parse an optional keyword or string. |
797 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
798 | |
799 | //===--------------------------------------------------------------------===// |
800 | // Attribute/Type Parsing |
801 | //===--------------------------------------------------------------------===// |
802 | |
/// Invoke the `getChecked` method of the given Attribute or Type class, using
/// the provided location to emit errors in the case of failure. Note that
/// unlike `OpBuilder::getType`, this method does not implicitly insert a
/// context parameter.
template <typename T, typename... ParamsT>
auto getChecked(SMLoc loc, ParamsT &&...params) {
  // The lambda defers diagnostic construction until T's verifier actually
  // reports a failure.
  return T::getChecked([&] { return emitError(loc); },
                       std::forward<ParamsT>(params)...);
}
/// A variant of `getChecked` that uses the result of `getNameLoc` to emit
/// errors.
template <typename T, typename... ParamsT>
auto getChecked(ParamsT &&...params) {
  return T::getChecked([&] { return emitError(getNameLoc()); },
                       std::forward<ParamsT>(params)...);
}
819 | |
820 | //===--------------------------------------------------------------------===// |
821 | // Attribute Parsing |
822 | //===--------------------------------------------------------------------===// |
823 | |
824 | /// Parse an arbitrary attribute of a given type and return it in result. |
825 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
826 | |
827 | /// Parse a custom attribute with the provided callback, unless the next |
828 | /// token is `#`, in which case the generic parser is invoked. |
829 | virtual ParseResult parseCustomAttributeWithFallback( |
830 | Attribute &result, Type type, |
831 | function_ref<ParseResult(Attribute &result, Type type)> |
832 | parseAttribute) = 0; |
833 | |
834 | /// Parse an attribute of a specific kind and type. |
835 | template <typename AttrType> |
836 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
837 | SMLoc loc = getCurrentLocation(); |
838 | |
839 | // Parse any kind of attribute. |
840 | Attribute attr; |
841 | if (parseAttribute(attr, type)) |
842 | return failure(); |
843 | |
844 | // Check for the right kind of attribute. |
845 | if (!(result = attr.dyn_cast<AttrType>())) |
846 | return emitError(loc, "invalid kind of attribute specified"); |
847 | |
848 | return success(); |
849 | } |
850 | |
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name.
ParseResult parseAttribute(Attribute &result, StringRef attrName,
                           NamedAttrList &attrs) {
  // A null Type means "no expected type" for the parsed attribute.
  return parseAttribute(result, Type(), attrName, attrs);
}

/// Parse an attribute of a specific kind and type. This also adds the
/// attribute to the specified attribute list with the specified name.
template <typename AttrType>
ParseResult parseAttribute(AttrType &result, StringRef attrName,
                           NamedAttrList &attrs) {
  return parseAttribute(result, Type(), attrName, attrs);
}
864 | |
865 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
866 | /// This also adds the attribute to the specified attribute list with the |
867 | /// specified name. |
868 | template <typename AttrType> |
869 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
870 | NamedAttrList &attrs) { |
871 | SMLoc loc = getCurrentLocation(); |
872 | |
873 | // Parse any kind of attribute. |
874 | Attribute attr; |
875 | if (parseAttribute(attr, type)) |
876 | return failure(); |
877 | |
878 | // Check for the right kind of attribute. |
879 | result = attr.dyn_cast<AttrType>(); |
880 | if (!result) |
881 | return emitError(loc, "invalid kind of attribute specified"); |
882 | |
883 | attrs.append(attrName, result); |
884 | return success(); |
885 | } |
886 | |
/// Trait to check if `AttrType` provides a `parse` method.
// Well-formed only when AttrType::parse(AsmParser &, Type) exists; used
// with llvm::is_detected below to select the custom-parse overloads.
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
                                                  std::declval<Type>()));
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
893 | |
894 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
895 | /// which case the generic parser is invoked. The parsed attribute is |
896 | /// populated in `result` and also added to the specified attribute list with |
897 | /// the specified name. |
898 | template <typename AttrType> |
899 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
900 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
901 | StringRef attrName, NamedAttrList &attrs) { |
902 | SMLoc loc = getCurrentLocation(); |
903 | |
904 | // Parse any kind of attribute. |
905 | Attribute attr; |
906 | if (parseCustomAttributeWithFallback( |
907 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
908 | result = AttrType::parse(*this, type); |
909 | if (!result) |
910 | return failure(); |
911 | return success(); |
912 | })) |
913 | return failure(); |
914 | |
915 | // Check for the right kind of attribute. |
916 | result = attr.dyn_cast<AttrType>(); |
917 | if (!result) |
918 | return emitError(loc, "invalid kind of attribute specified"); |
919 | |
920 | attrs.append(attrName, result); |
921 | return success(); |
922 | } |
923 | |
/// SFINAE parsing method for Attribute that don't implement a parse method.
// Selected when AttrType has no `parse` method; simply defers to the
// generic attribute parser.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result, Type type,
                                 StringRef attrName, NamedAttrList &attrs) {
  return parseAttribute(result, type, attrName, attrs);
}
931 | |
932 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
933 | /// which case the generic parser is invoked. The parsed attribute is |
934 | /// populated in `result`. |
935 | template <typename AttrType> |
936 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
937 | parseCustomAttributeWithFallback(AttrType &result) { |
938 | SMLoc loc = getCurrentLocation(); |
939 | |
940 | // Parse any kind of attribute. |
941 | Attribute attr; |
942 | if (parseCustomAttributeWithFallback( |
943 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
944 | result = AttrType::parse(*this, type); |
945 | return success(!!result); |
946 | })) |
947 | return failure(); |
948 | |
949 | // Check for the right kind of attribute. |
950 | result = attr.dyn_cast<AttrType>(); |
951 | if (!result) |
952 | return emitError(loc, "invalid kind of attribute specified"); |
953 | return success(); |
954 | } |
955 | |
/// SFINAE parsing method for Attribute that don't implement a parse method.
// Selected when AttrType has no `parse` method; defers to the generic
// attribute parser.
template <typename AttrType>
std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
parseCustomAttributeWithFallback(AttrType &result) {
  return parseAttribute(result);
}
962 | |
963 | /// Parse an arbitrary optional attribute of a given type and return it in |
964 | /// result. |
965 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
966 | Type type = {}) = 0; |
967 | |
968 | /// Parse an optional array attribute and return it in result. |
969 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
970 | Type type = {}) = 0; |
971 | |
972 | /// Parse an optional string attribute and return it in result. |
973 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
974 | Type type = {}) = 0; |
975 | |
976 | /// Parse an optional symbol ref attribute and return it in result. |
977 | virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result, |
978 | Type type = {}) = 0; |
979 | |
/// Parse an optional attribute of a specific type and add it to the list with
/// the specified name.
template <typename AttrType>
OptionalParseResult parseOptionalAttribute(AttrType &result,
                                           StringRef attrName,
                                           NamedAttrList &attrs) {
  // A null Type means "no expected type" for the parsed attribute.
  return parseOptionalAttribute(result, Type(), attrName, attrs);
}
988 | |
989 | /// Parse an optional attribute of a specific type and add it to the list with |
990 | /// the specified name. |
991 | template <typename AttrType> |
992 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
993 | StringRef attrName, |
994 | NamedAttrList &attrs) { |
995 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
996 | if (parseResult.has_value() && succeeded(*parseResult)) |
997 | attrs.append(attrName, result); |
998 | return parseResult; |
999 | } |
1000 | |
1001 | /// Parse a named dictionary into 'result' if it is present. |
1002 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
1003 | |
1004 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
1005 | /// present. |
1006 | virtual ParseResult |
1007 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
1008 | |
1009 | /// Parse an affine map instance into 'map'. |
1010 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
1011 | |
1012 | /// Parse an integer set instance into 'set'. |
1013 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
1014 | |
1015 | //===--------------------------------------------------------------------===// |
1016 | // Identifier Parsing |
1017 | //===--------------------------------------------------------------------===// |
1018 | |
1019 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1020 | /// attribute. |
1021 | ParseResult parseSymbolName(StringAttr &result) { |
1022 | if (failed(parseOptionalSymbolName(result))) |
1023 | return emitError(getCurrentLocation()) |
1024 | << "expected valid '@'-identifier for symbol name"; |
1025 | return success(); |
1026 | } |
1027 | |
1028 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1029 | /// attribute named 'attrName'. |
1030 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1031 | NamedAttrList &attrs) { |
1032 | if (parseSymbolName(result)) |
1033 | return failure(); |
1034 | attrs.append(attrName, result); |
1035 | return success(); |
1036 | } |
1037 | |
1038 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1039 | /// string attribute. |
1040 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; |
1041 | |
1042 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1043 | /// string attribute named 'attrName'. |
1044 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
1045 | NamedAttrList &attrs) { |
1046 | if (succeeded(parseOptionalSymbolName(result))) { |
1047 | attrs.append(attrName, result); |
1048 | return success(); |
1049 | } |
1050 | return failure(); |
1051 | } |
1052 | |
1053 | //===--------------------------------------------------------------------===// |
1054 | // Resource Parsing |
1055 | //===--------------------------------------------------------------------===// |
1056 | |
/// Parse a handle to a resource within the assembly format.
/// Fails (with a diagnostic at the handle's location) if the owning dialect
/// cannot be loaded or the parsed handle is not a `ResourceT`.
template <typename ResourceT>
FailureOr<ResourceT> parseResourceHandle() {
  SMLoc handleLoc = getCurrentLocation();

  // Try to load the dialect that owns the handle.
  auto *dialect =
      getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
  if (!dialect) {
    return emitError(handleLoc)
           << "dialect '" << ResourceT::Dialect::getDialectNamespace()
           << "' is unknown";
  }

  // Parse the dialect-generic handle, then narrow it to ResourceT.
  FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
  if (failed(handle))
    return failure();
  if (auto *result = dyn_cast<ResourceT>(&*handle))
    return std::move(*result);
  return emitError(handleLoc) << "provided resource handle differs from the "
                              "expected resource type";
}
1079 | |
1080 | //===--------------------------------------------------------------------===// |
1081 | // Type Parsing |
1082 | //===--------------------------------------------------------------------===// |
1083 | |
1084 | /// Parse a type. |
1085 | virtual ParseResult parseType(Type &result) = 0; |
1086 | |
1087 | /// Parse a custom type with the provided callback, unless the next |
1088 | /// token is `#`, in which case the generic parser is invoked. |
1089 | virtual ParseResult parseCustomTypeWithFallback( |
1090 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
1091 | |
1092 | /// Parse an optional type. |
1093 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
1094 | |
1095 | /// Parse a type of a specific type. |
1096 | template <typename TypeT> |
1097 | ParseResult parseType(TypeT &result) { |
1098 | SMLoc loc = getCurrentLocation(); |
1099 | |
1100 | // Parse any kind of type. |
1101 | Type type; |
1102 | if (parseType(type)) |
1103 | return failure(); |
1104 | |
1105 | // Check for the right kind of type. |
1106 | result = type.dyn_cast<TypeT>(); |
1107 | if (!result) |
1108 | return emitError(loc, "invalid kind of type specified"); |
1109 | |
1110 | return success(); |
1111 | } |
1112 | |
/// Trait to check if `TypeT` provides a `parse` method.
// Well-formed only when TypeT::parse(AsmParser &) exists; used with
// llvm::is_detected below to select the custom-parse overloads.
template <typename TypeT>
using type_has_parse_method =
    decltype(TypeT::parse(std::declval<AsmParser &>()));
template <typename TypeT>
using detect_type_has_parse_method =
    llvm::is_detected<type_has_parse_method, TypeT>;
1120 | |
1121 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1122 | /// which case the generic parser is invoked. The parsed Type is |
1123 | /// populated in `result`. |
1124 | template <typename TypeT> |
1125 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1126 | parseCustomTypeWithFallback(TypeT &result) { |
1127 | SMLoc loc = getCurrentLocation(); |
1128 | |
1129 | // Parse any kind of Type. |
1130 | Type type; |
1131 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1132 | result = TypeT::parse(*this); |
1133 | return success(!!result); |
1134 | })) |
1135 | return failure(); |
1136 | |
1137 | // Check for the right kind of Type. |
1138 | result = type.dyn_cast<TypeT>(); |
1139 | if (!result) |
1140 | return emitError(loc, "invalid kind of Type specified"); |
1141 | return success(); |
1142 | } |
1143 | |
/// SFINAE parsing method for Type that don't implement a parse method.
// Selected when TypeT has no `parse` method; defers to the generic type
// parser.
template <typename TypeT>
std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
parseCustomTypeWithFallback(TypeT &result) {
  return parseType(result);
}
1150 | |
1151 | /// Parse a type list. |
1152 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
1153 | return parseCommaSeparatedList( |
1154 | [&]() { return parseType(result.emplace_back()); }); |
1155 | } |
1156 | |
1157 | /// Parse an arrow followed by a type list. |
1158 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1159 | |
1160 | /// Parse an optional arrow followed by a type list. |
1161 | virtual ParseResult |
1162 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1163 | |
1164 | /// Parse a colon followed by a type. |
1165 | virtual ParseResult parseColonType(Type &result) = 0; |
1166 | |
1167 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1168 | template <typename TypeType> |
1169 | ParseResult parseColonType(TypeType &result) { |
1170 | SMLoc loc = getCurrentLocation(); |
1171 | |
1172 | // Parse any kind of type. |
1173 | Type type; |
1174 | if (parseColonType(type)) |
1175 | return failure(); |
1176 | |
1177 | // Check for the right kind of type. |
1178 | result = type.dyn_cast<TypeType>(); |
1179 | if (!result) |
1180 | return emitError(loc, "invalid kind of type specified"); |
1181 | |
1182 | return success(); |
1183 | } |
1184 | |
1185 | /// Parse a colon followed by a type list, which must have at least one type. |
1186 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1187 | |
1188 | /// Parse an optional colon followed by a type list, which if present must |
1189 | /// have at least one type. |
1190 | virtual ParseResult |
1191 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1192 | |
1193 | /// Parse a keyword followed by a type. |
1194 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1195 | return failure(parseKeyword(keyword) || parseType(result)); |
1196 | } |
1197 | |
/// Add the specified type to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
  result.push_back(type);
  return success();
}

/// Add the specified types to the end of the specified type list and return
/// success. This is a helper designed to allow parse methods to be simple
/// and chain through || operators.
ParseResult addTypesToList(ArrayRef<Type> types,
                           SmallVectorImpl<Type> &result) {
  result.append(types.begin(), types.end());
  return success();
}
1214 | |
1215 | /// Parse a dimension list of a tensor or memref type. This populates the |
1216 | /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set |
1217 | /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable. |
1218 | /// |
1219 | /// dimension-list ::= eps | dimension (`x` dimension)* |
1220 | /// dimension-list-with-trailing-x ::= (dimension `x`)* |
1221 | /// dimension ::= `?` | decimal-literal |
1222 | /// |
1223 | /// When `allowDynamic` is not set, this is used to parse: |
1224 | /// |
1225 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* |
1226 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* |
1227 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1228 | bool allowDynamic = true, |
1229 | bool withTrailingX = true) = 0; |
1230 | |
1231 | /// Parse an 'x' token in a dimension list, handling the case where the x is |
1232 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
1233 | /// next token. |
1234 | virtual ParseResult parseXInDimensionList() = 0; |
1235 | |
1236 | protected: |
1237 | /// Parse a handle to a resource within the assembly format for the given |
1238 | /// dialect. |
1239 | virtual FailureOr<AsmDialectResourceHandle> |
1240 | parseResourceHandle(Dialect *dialect) = 0; |
1241 | |
1242 | //===--------------------------------------------------------------------===// |
1243 | // Code Completion |
1244 | //===--------------------------------------------------------------------===// |
1245 | |
1246 | /// Parse a keyword, or an empty string if the current location signals a code |
1247 | /// completion. |
1248 | virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; |
1249 | |
1250 | /// Signal the code completion of a set of expected tokens. |
1251 | virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0; |
1252 | |
1253 | private: |
1254 | AsmParser(const AsmParser &) = delete; |
1255 | void operator=(const AsmParser &) = delete; |
1256 | }; |
1257 | |
1258 | //===----------------------------------------------------------------------===// |
1259 | // OpAsmParser |
1260 | //===----------------------------------------------------------------------===// |
1261 | |
1262 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1263 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1264 | /// that is designed to reduce/constrain syntax innovation in individual |
1265 | /// operations. |
1266 | /// |
1267 | /// For example, consider an op like this: |
1268 | /// |
1269 | /// %x = load %p[%1, %2] : memref<...> |
1270 | /// |
1271 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
1272 | /// custom op parser. This can be supported by calling `parseOperandList` to |
1273 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to |
1274 | /// parse the indices, then calling `parseColonTypeList` to parse the result |
1275 | /// type. |
1276 | /// |
1277 | class OpAsmParser : public AsmParser { |
1278 | public: |
1279 | using AsmParser::AsmParser; |
1280 | ~OpAsmParser() override; |
1281 | |
1282 | /// Parse a loc(...) specifier if present, filling in result if so. |
1283 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1284 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1285 | /// completes. |
1286 | virtual ParseResult |
1287 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1288 | |
1289 | /// Return the name of the specified result in the specified syntax, as well |
1290 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1291 | /// invalid result numbers. For example, in this operation: |
1292 | /// |
1293 | /// %x, %y:2, %z = foo.op |
1294 | /// |
1295 | /// getResultName(0) == {"x", 0 } |
1296 | /// getResultName(1) == {"y", 0 } |
1297 | /// getResultName(2) == {"y", 1 } |
1298 | /// getResultName(3) == {"z", 0 } |
1299 | /// getResultName(4) == {"", ~0U } |
1300 | virtual std::pair<StringRef, unsigned> |
1301 | getResultName(unsigned resultNo) const = 0; |
1302 | |
1303 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1304 | /// example in the comment for `getResultName`. |
1305 | virtual size_t getNumResults() const = 0; |
1306 | |
1307 | // These methods emit an error and return failure or success. This allows |
1308 | // these to be chained together into a linear sequence of || expressions in |
1309 | // many cases. |
1310 | |
1311 | /// Parse an operation in its generic form. |
1312 | /// The parsed operation is parsed in the current context and inserted in the |
1313 | /// provided block and insertion point. The results produced by this operation |
1314 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1315 | /// failure. |
1316 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1317 | Block::iterator insertPt) = 0; |
1318 | |
1319 | /// Parse the name of an operation, in the custom form. On success, return a |
1320 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1321 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1322 | |
1323 | //===--------------------------------------------------------------------===// |
1324 | // Operand Parsing |
1325 | //===--------------------------------------------------------------------===// |
1326 | |
1327 | /// This is the representation of an operand reference. |
1328 | struct UnresolvedOperand { |
1329 | SMLoc location; // Location of the token. |
1330 | StringRef name; // Value name, e.g. %42 or %abc |
1331 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1332 | }; |
1333 | |
1334 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1335 | /// region(s), attribute(s) and function-type, of the generic form of an |
1336 | /// operation instance and populate the input operation-state 'result' with |
1337 | /// those components. If any of the components is explicitly provided, then |
1338 | /// skip parsing that component. |
1339 | virtual ParseResult parseGenericOperationAfterOpName( |
1340 | OperationState &result, |
1341 | Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt, |
1342 | Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt, |
1343 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1344 | std::nullopt, |
1345 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt, |
1346 | Optional<FunctionType> parsedFnType = std::nullopt) = 0; |
1347 | |
1348 | /// Parse a single SSA value operand name along with a result number if |
1349 | /// `allowResultNumber` is true. |
1350 | virtual ParseResult parseOperand(UnresolvedOperand &result, |
1351 | bool allowResultNumber = true) = 0; |
1352 | |
1353 | /// Parse a single operand if present. |
1354 | virtual OptionalParseResult |
1355 | parseOptionalOperand(UnresolvedOperand &result, |
1356 | bool allowResultNumber = true) = 0; |
1357 | |
1358 | /// Parse zero or more SSA comma-separated operand references with a specified |
1359 | /// surrounding delimiter, and an optional required operand count. |
1360 | virtual ParseResult |
1361 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1362 | Delimiter delimiter = Delimiter::None, |
1363 | bool allowResultNumber = true, |
1364 | int requiredOperandCount = -1) = 0; |
1365 | |
1366 | /// Parse a specified number of comma separated operands. |
1367 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1368 | int requiredOperandCount, |
1369 | Delimiter delimiter = Delimiter::None) { |
1370 | return parseOperandList(result, delimiter, |
1371 | /*allowResultNumber=*/true, requiredOperandCount); |
1372 | } |
1373 | |
1374 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1375 | /// references with a specified surrounding delimiter, and an optional |
1376 | /// required operand count. A leading comma is expected before the |
1377 | /// operands. |
1378 | ParseResult |
1379 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1380 | Delimiter delimiter = Delimiter::None) { |
1381 | if (failed(parseOptionalComma())) |
1382 | return success(); // The comma is optional. |
1383 | return parseOperandList(result, delimiter); |
1384 | } |
1385 | |
1386 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1387 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1388 | Type type, |
1389 | SmallVectorImpl<Value> &result) = 0; |
1390 | |
1391 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1392 | /// appending the results to the list on success. This method should be used |
1393 | /// when all operands have the same type. |
1394 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1395 | ParseResult resolveOperands(Operands &&operands, Type type, |
1396 | SmallVectorImpl<Value> &result) { |
1397 | for (const UnresolvedOperand &operand : operands) |
1398 | if (resolveOperand(operand, type, result)) |
1399 | return failure(); |
1400 | return success(); |
1401 | } |
1402 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1403 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1404 | SmallVectorImpl<Value> &result) { |
1405 | return resolveOperands(std::forward<Operands>(operands), type, result); |
1406 | } |
1407 | |
1408 | /// Resolve a list of operands and a list of operand types to SSA values, |
1409 | /// emitting an error and returning failure, or appending the results |
1410 | /// to the list on success. |
1411 | template <typename Operands = ArrayRef<UnresolvedOperand>, |
1412 | typename Types = ArrayRef<Type>> |
1413 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1414 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1415 | SmallVectorImpl<Value> &result) { |
1416 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1417 | size_t typeSize = std::distance(types.begin(), types.end()); |
1418 | if (operandSize != typeSize) |
1419 | return emitError(loc) |
1420 | << operandSize << " operands present, but expected " << typeSize; |
1421 | |
1422 | for (auto [operand, type] : llvm::zip(operands, types)) |
1423 | if (resolveOperand(operand, type, result)) |
1424 | return failure(); |
1425 | return success(); |
1426 | } |
1427 | |
1428 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1429 | /// Operand values must come from single-result sources, and be valid |
1430 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1431 | virtual ParseResult |
1432 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1433 | Attribute &map, StringRef attrName, |
1434 | NamedAttrList &attrs, |
1435 | Delimiter delimiter = Delimiter::Square) = 0; |
1436 | |
1437 | /// Parses an affine expression where dims and symbols are SSA operands. |
1438 | /// Operand values must come from single-result sources, and be valid |
1439 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1440 | virtual ParseResult |
1441 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1442 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1443 | AffineExpr &expr) = 0; |
1444 | |
1445 | //===--------------------------------------------------------------------===// |
1446 | // Argument Parsing |
1447 | //===--------------------------------------------------------------------===// |
1448 | |
1449 | struct Argument { |
1450 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1451 | Type type; // Type. |
1452 | DictionaryAttr attrs; // Attributes if present. |
1453 | Optional<Location> sourceLoc; // Source location specifier if present. |
1454 | }; |
1455 | |
1456 | /// Parse a single argument with the following syntax: |
1457 | /// |
1458 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` |
1459 | /// |
1460 | /// If `allowType` is false or `allowAttrs` are false then the respective |
1461 | /// parts of the grammar are not parsed. |
1462 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, |
1463 | bool allowAttrs = false) = 0; |
1464 | |
1465 | /// Parse a single argument if present. |
1466 | virtual OptionalParseResult |
1467 | parseOptionalArgument(Argument &result, bool allowType = false, |
1468 | bool allowAttrs = false) = 0; |
1469 | |
1470 | /// Parse zero or more arguments with a specified surrounding delimiter. |
1471 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, |
1472 | Delimiter delimiter = Delimiter::None, |
1473 | bool allowType = false, |
1474 | bool allowAttrs = false) = 0; |
1475 | |
1476 | //===--------------------------------------------------------------------===// |
1477 | // Region Parsing |
1478 | //===--------------------------------------------------------------------===// |
1479 | |
1480 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1481 | /// moved to the op regions after the op is created. The first block of the |
1482 | /// region takes 'arguments'. |
1483 | /// |
1484 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1485 | /// shadow the names of other existing SSA values defined above the region |
1486 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1487 | /// to operations that are 'IsolatedFromAbove'. |
1488 | virtual ParseResult parseRegion(Region ®ion, |
1489 | ArrayRef<Argument> arguments = {}, |
1490 | bool enableNameShadowing = false) = 0; |
1491 | |
1492 | /// Parses a region if present. |
1493 | virtual OptionalParseResult |
1494 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, |
1495 | bool enableNameShadowing = false) = 0; |
1496 | |
1497 | /// Parses a region if present. If the region is present, a new region is |
1498 | /// allocated and placed in `region`. If no region is present or on failure, |
1499 | /// `region` remains untouched. |
1500 | virtual OptionalParseResult |
1501 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1502 | ArrayRef<Argument> arguments = {}, |
1503 | bool enableNameShadowing = false) = 0; |
1504 | |
1505 | //===--------------------------------------------------------------------===// |
1506 | // Successor Parsing |
1507 | //===--------------------------------------------------------------------===// |
1508 | |
1509 | /// Parse a single operation successor. |
1510 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1511 | |
1512 | /// Parse an optional operation successor. |
1513 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1514 | |
1515 | /// Parse a single operation successor and its operand list. |
1516 | virtual ParseResult |
1517 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1518 | |
1519 | //===--------------------------------------------------------------------===// |
1520 | // Type Parsing |
1521 | //===--------------------------------------------------------------------===// |
1522 | |
1523 | /// Parse a list of assignments of the form |
1524 | /// (%x1 = %y1, %x2 = %y2, ...) |
1525 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1526 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1527 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1528 | if (!result.has_value()) |
1529 | return emitError(getCurrentLocation(), "expected '('"); |
1530 | return result.value(); |
1531 | } |
1532 | |
1533 | virtual OptionalParseResult |
1534 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, |
1535 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1536 | }; |
1537 | |
1538 | //===--------------------------------------------------------------------===// |
1539 | // Dialect OpAsm interface. |
1540 | //===--------------------------------------------------------------------===// |
1541 | |
/// A functor used to set the name of the start of a result group of an
/// operation. See 'getAsmResultNames' below for more details.
using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;

/// A functor used to set the name of blocks in regions directly nested under
/// an operation.
using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1549 | |
1550 | class OpAsmDialectInterface |
1551 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1552 | public: |
1553 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1554 | |
1555 | //===------------------------------------------------------------------===// |
1556 | // Aliases |
1557 | //===------------------------------------------------------------------===// |
1558 | |
1559 | /// Holds the result of `getAlias` hook call. |
1560 | enum class AliasResult { |
1561 | /// The object (type or attribute) is not supported by the hook |
1562 | /// and an alias was not provided. |
1563 | NoAlias, |
1564 | /// An alias was provided, but it might be overriden by other hook. |
1565 | OverridableAlias, |
1566 | /// An alias was provided and it should be used |
1567 | /// (no other hooks will be checked). |
1568 | FinalAlias |
1569 | }; |
1570 | |
1571 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1572 | /// not necessarily a part of this dialect. The identifier is used in place of |
1573 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1574 | /// end with a numeric digit([0-9]+). |
1575 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1576 | return AliasResult::NoAlias; |
1577 | } |
1578 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1579 | return AliasResult::NoAlias; |
1580 | } |
1581 | |
1582 | //===--------------------------------------------------------------------===// |
1583 | // Resources |
1584 | //===--------------------------------------------------------------------===// |
1585 | |
1586 | /// Declare a resource with the given key, returning a handle to use for any |
1587 | /// references of this resource key within the IR during parsing. The result |
1588 | /// of `getResourceKey` on the returned handle is permitted to be different |
1589 | /// than `key`. |
1590 | virtual FailureOr<AsmDialectResourceHandle> |
1591 | declareResource(StringRef key) const { |
1592 | return failure(); |
1593 | } |
1594 | |
1595 | /// Return a key to use for the given resource. This key should uniquely |
1596 | /// identify this resource within the dialect. |
1597 | virtual std::string |
1598 | getResourceKey(const AsmDialectResourceHandle &handle) const { |
1599 | llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1600) |
1600 | "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1600); |
1601 | } |
1602 | |
1603 | /// Hook for parsing resource entries. Returns failure if the entry was not |
1604 | /// valid, or could otherwise not be processed correctly. Any necessary errors |
1605 | /// can be emitted via the provided entry. |
1606 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; |
1607 | |
1608 | /// Hook for building resources to use during printing. The given `op` may be |
1609 | /// inspected to help determine what information to include. |
1610 | /// `referencedResources` contains all of the resources detected when printing |
1611 | /// 'op'. |
1612 | virtual void |
1613 | buildResources(Operation *op, |
1614 | const SetVector<AsmDialectResourceHandle> &referencedResources, |
1615 | AsmResourceBuilder &builder) const {} |
1616 | }; |
1617 | } // namespace mlir |
1618 | |
1619 | //===--------------------------------------------------------------------===// |
1620 | // Operation OpAsm interface. |
1621 | //===--------------------------------------------------------------------===// |
1622 | |
1623 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1624 | #include "mlir/IR/OpAsmInterface.h.inc" |
1625 | |
namespace llvm {
/// DenseMapInfo specialization that lets AsmDialectResourceHandle be used as a
/// DenseMap/DenseSet key.
template <>
struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
  // Empty/tombstone keys are built from the corresponding sentinel values of
  // the underlying resource pointer and TypeID; the dialect pointer is null.
  static inline mlir::AsmDialectResourceHandle getEmptyKey() {
    return {DenseMapInfo<void *>::getEmptyKey(),
            DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
  }
  static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
    return {DenseMapInfo<void *>::getTombstoneKey(),
            DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
  }
  // Hashing and equality key only on the opaque resource pointer; the TypeID
  // and dialect fields do not participate — presumably the resource pointer
  // alone uniquely identifies a handle (TODO confirm against
  // AsmDialectResourceHandle's ownership contract).
  static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
    return DenseMapInfo<void *>::getHashValue(handle.getResource());
  }
  static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
                      const mlir::AsmDialectResourceHandle &rhs) {
    return lhs.getResource() == rhs.getResource();
  }
};
} // namespace llvm
1646 | |
1647 | #endif |