Bug Summary

File: build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Warning: line 4929, column 23
1st function call argument is an uninitialized value

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name VectorOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Vector/IR -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Vector/IR -I include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/llvm/include -I /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics 
-vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-10-03-140002-15933-1 -x c++ /build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Vector/IR/VectorOps.cpp

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/lib/Dialect/Vector/IR/VectorOps.cpp

1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/IndexingUtils.h"
21#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/BlockAndValueMapping.h"
25#include "mlir/IR/Builders.h"
26#include "mlir/IR/BuiltinOps.h"
27#include "mlir/IR/BuiltinTypes.h"
28#include "mlir/IR/DialectImplementation.h"
29#include "mlir/IR/OpImplementation.h"
30#include "mlir/IR/PatternMatch.h"
31#include "mlir/IR/TypeUtilities.h"
32#include "mlir/Support/LLVM.h"
33#include "llvm/ADT/StringSet.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/ADT/bit.h"
36#include <numeric>
37
38#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
39// Pull in all enum type and utility function definitions.
40#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
41
42using namespace mlir;
43using namespace mlir::vector;
44
45/// Helper enum to classify a mask value.
46enum class MaskFormat {
47 AllTrue = 0,
48 AllFalse = 1,
49 Unknown = 2,
50};
51
52/// Helper method to classify a mask value. Currently, the method
53/// looks "under the hood" of a constant value with dense attributes
54/// and a constant mask operation (since the client may be called at
55/// various stages during progressive lowering).
56static MaskFormat getMaskFormat(Value mask) {
57 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
58 // Inspect constant dense values. We count up for bits that
59 // are set, count down for bits that are cleared, and bail
60 // when a mix is detected.
61 if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
62 int64_t val = 0;
63 for (bool b : denseElts.getValues<bool>())
64 if (b && val >= 0)
65 val++;
66 else if (!b && val <= 0)
67 val--;
68 else
69 return MaskFormat::Unknown;
70 if (val > 0)
71 return MaskFormat::AllTrue;
72 if (val < 0)
73 return MaskFormat::AllFalse;
74 }
75 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
76 // Inspect the constant mask index. If the index reaches or exceeds
77 // the dimension size, all bits are set. If the index is zero
78 // or less, no bits are set.
79 ArrayAttr masks = m.getMaskDimSizes();
80 auto shape = m.getType().getShape();
81 bool allTrue = true;
82 bool allFalse = true;
83 for (auto pair : llvm::zip(masks, shape)) {
84 int64_t i = std::get<0>(pair).cast<IntegerAttr>().getInt();
85 int64_t u = std::get<1>(pair);
86 if (i < u)
87 allTrue = false;
88 if (i > 0)
89 allFalse = false;
90 }
91 if (allTrue)
92 return MaskFormat::AllTrue;
93 if (allFalse)
94 return MaskFormat::AllFalse;
95 }
96 return MaskFormat::Unknown;
97}
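
As a standalone illustration of the count-up/count-down trick used above, here is a minimal sketch in plain C++ (independent of MLIR; the helper name is made up):

    #include <cstdint>
    #include <vector>

    enum class BitsFormat { AllTrue, AllFalse, Unknown };

    // Count up while only set bits have been seen, count down while only
    // cleared bits have been seen, and bail as soon as both kinds appear.
    static BitsFormat classifyBits(const std::vector<bool> &bits) {
      int64_t val = 0;
      for (bool b : bits) {
        if (b && val >= 0)
          ++val;
        else if (!b && val <= 0)
          --val;
        else
          return BitsFormat::Unknown;
      }
      if (val > 0)
        return BitsFormat::AllTrue;
      if (val < 0)
        return BitsFormat::AllFalse;
      return BitsFormat::Unknown; // empty input
    }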
98
99// Helper for verifying combining kinds in contractions and reductions.
100static bool isSupportedCombiningKind(CombiningKind combiningKind,
101 Type elementType) {
102 switch (combiningKind) {
103 case CombiningKind::ADD:
104 case CombiningKind::MUL:
105 return elementType.isIntOrIndexOrFloat();
106 case CombiningKind::MINUI:
107 case CombiningKind::MINSI:
108 case CombiningKind::MAXUI:
109 case CombiningKind::MAXSI:
110 case CombiningKind::AND:
111 case CombiningKind::OR:
112 case CombiningKind::XOR:
113 return elementType.isIntOrIndex();
114 case CombiningKind::MINF:
115 case CombiningKind::MAXF:
116 return elementType.isa<FloatType>();
117 }
118 return false;
119}
120
121/// Return true if the last dimension of the MemRefType has unit stride. Also
122/// return true for memrefs with no strides.
123bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
124 int64_t offset;
125 SmallVector<int64_t> strides;
126 auto successStrides = getStridesAndOffset(type, strides, offset);
127 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
128}
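
For a concrete instance (hedged, assuming the default identity layout): memref<4x8xf32> decomposes into offset 0 and strides [8, 1], so the trailing stride is 1 and the check succeeds; a layout with strides [16, 2] over the same shape fails. A tiny standalone sketch of row-major stride computation (plain C++, hypothetical helper):

    #include <cstdint>
    #include <vector>

    // Row-major strides for a static shape, e.g. {4, 8} -> {8, 1}; the
    // identity layout of memref<4x8xf32> therefore has a unit trailing stride.
    static std::vector<int64_t>
    rowMajorStrides(const std::vector<int64_t> &shape) {
      std::vector<int64_t> strides(shape.size(), 1);
      for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * shape[i + 1];
      return strides;
    }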
129
130AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
131 VectorType vectorType) {
132 int64_t elementVectorRank = 0;
133 VectorType elementVectorType =
134 shapedType.getElementType().dyn_cast<VectorType>();
135 if (elementVectorType)
136 elementVectorRank += elementVectorType.getRank();
137 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
138 // TODO: replace once we have 0-d vectors.
139 if (shapedType.getRank() == 0 &&
140 vectorType.getShape() == ArrayRef<int64_t>{1})
141 return AffineMap::get(
142 /*numDims=*/0, /*numSymbols=*/0,
143 getAffineConstantExpr(0, shapedType.getContext()));
144 return AffineMap::getMinorIdentityMap(
145 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
146 shapedType.getContext());
147}
148
149bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
150 vector::TransferReadOp read) {
151 return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
152 !read.getMask() && defWrite.getIndices() == read.getIndices() &&
153 defWrite.getVectorType() == read.getVectorType() &&
154 defWrite.getPermutationMap() == read.getPermutationMap();
155}
156
157bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
158 vector::TransferWriteOp priorWrite) {
159 return priorWrite.getIndices() == write.getIndices() &&
160 priorWrite.getMask() == write.getMask() &&
161 priorWrite.getVectorType() == write.getVectorType() &&
162 priorWrite.getPermutationMap() == write.getPermutationMap();
163}
164
165bool mlir::vector::isDisjointTransferIndices(
166 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
167 // For simplicity, only look at transfers of the same type.
168 if (transferA.getVectorType() != transferB.getVectorType())
169 return false;
170 unsigned rankOffset = transferA.getLeadingShapedRank();
171 for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
172 auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
173 auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
174 // If any of the indices are dynamic we cannot prove anything.
175 if (!indexA || !indexB)
176 continue;
177
178 if (i < rankOffset) {
179 // For leading dimensions, if we can prove that the indices are
180 // different we know we are accessing disjoint slices.
181 if (indexA.getValue().cast<IntegerAttr>().getInt() !=
182 indexB.getValue().cast<IntegerAttr>().getInt())
183 return true;
184 } else {
185 // For this dimension we slice a part of the memref, so we need to make
186 // sure the intervals accessed don't overlap.
187 int64_t distance =
188 std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
189 indexB.getValue().cast<IntegerAttr>().getInt());
190 if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
191 return true;
192 }
193 }
194 return false;
195}
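
For the sliced trailing dimensions, the test above reduces to an interval-disjointness check; a minimal sketch (plain C++, hypothetical helper name):

    #include <cstdint>
    #include <cstdlib>

    // Two accesses of length dimSize starting at constant indices a and b
    // read the intervals [a, a + dimSize) and [b, b + dimSize); they are
    // provably disjoint exactly when the starts are at least dimSize apart.
    static bool slicesAreDisjoint(int64_t a, int64_t b, int64_t dimSize) {
      return std::abs(a - b) >= dimSize;
    }
    // e.g. dimSize == 4: starts 0 and 4 are disjoint, starts 0 and 3 overlap.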
196
197bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
198 VectorTransferOpInterface transferB) {
199 if (transferA.source() != transferB.source())
200 return false;
201 return isDisjointTransferIndices(transferA, transferB);
202}
203
204//===----------------------------------------------------------------------===//
205// CombiningKindAttr
206//===----------------------------------------------------------------------===//
207
208namespace mlir {
209namespace vector {
210namespace detail {
211struct BitmaskEnumStorage : public AttributeStorage {
212 using KeyTy = uint64_t;
213
214 BitmaskEnumStorage(KeyTy val) : value(val) {}
215
216 bool operator==(const KeyTy &key) const { return value == key; }
217
218 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
219 const KeyTy &key) {
220 return new (allocator.allocate<BitmaskEnumStorage>())
221 BitmaskEnumStorage(key);
222 }
223
224 KeyTy value = 0;
225};
226} // namespace detail
227} // namespace vector
228} // namespace mlir
229
230//===----------------------------------------------------------------------===//
231// VectorDialect
232//===----------------------------------------------------------------------===//
233
234void VectorDialect::initialize() {
235 addAttributes<
236#define GET_ATTRDEF_LIST
237#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
238 >();
239
240 addOperations<
241#define GET_OP_LIST
242#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
243 >();
244}
245
246/// Materialize a single constant operation from a given attribute value with
247/// the desired resultant type.
248Operation *VectorDialect::materializeConstant(OpBuilder &builder,
249 Attribute value, Type type,
250 Location loc) {
251 return builder.create<arith::ConstantOp>(loc, type, value);
252}
253
254IntegerType vector::getVectorSubscriptType(Builder &builder) {
255 return builder.getIntegerType(64);
256}
257
258ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
259 ArrayRef<int64_t> values) {
260 return builder.getI64ArrayAttr(values);
261}
262
263//===----------------------------------------------------------------------===//
264// MultiDimReductionOp
265//===----------------------------------------------------------------------===//
266
267void vector::MultiDimReductionOp::build(OpBuilder &builder,
268 OperationState &result, Value source,
269 Value acc, ArrayRef<bool> reductionMask,
270 CombiningKind kind) {
271 SmallVector<int64_t> reductionDims;
272 for (const auto &en : llvm::enumerate(reductionMask))
273 if (en.value())
274 reductionDims.push_back(en.index());
275 build(builder, result, kind, source, acc,
276 builder.getI64ArrayAttr(reductionDims));
277}
278
279OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
280 // Single parallel dim, this is a noop.
281 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
282 return getSource();
283 return {};
284}
285
286Optional<SmallVector<int64_t, 4>> MultiDimReductionOp::getShapeForUnroll() {
287 return llvm::to_vector<4>(getSourceVectorType().getShape());
288}
289
290LogicalResult MultiDimReductionOp::verify() {
291 SmallVector<int64_t> targetShape;
292 Type inferredReturnType;
293 for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
294 if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
295 return attr.cast<IntegerAttr>().getValue() == it.index();
296 }))
297 targetShape.push_back(it.value());
298 // TODO: update to also allow 0-d vectors when available.
299 if (targetShape.empty())
300 inferredReturnType = getSourceVectorType().getElementType();
301 else
302 inferredReturnType =
303 VectorType::get(targetShape, getSourceVectorType().getElementType());
304 if (getType() != inferredReturnType)
305 return emitOpError() << "destination type " << getType()
306 << " is incompatible with source type "
307 << getSourceVectorType();
308
309 return success();
310}
311
312//===----------------------------------------------------------------------===//
313// ReductionOp
314//===----------------------------------------------------------------------===//
315
316void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
317 CombiningKind kind, Value vector) {
318 build(builder, result, kind, vector, /*acc=*/Value());
319}
320
321void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
322 CombiningKind kind, Value vector, Value acc) {
323 build(builder, result, vector.getType().cast<VectorType>().getElementType(),
324 kind, vector, acc);
325}
326
327LogicalResult ReductionOp::verify() {
328 // Verify for 0-D and 1-D vector.
329 int64_t rank = getVectorType().getRank();
330 if (rank > 1)
331 return emitOpError("unsupported reduction rank: ") << rank;
332
333 // Verify supported reduction kind.
334 Type eltType = getDest().getType();
335 if (!isSupportedCombiningKind(getKind(), eltType))
336 return emitOpError("unsupported reduction type '")
337 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
338 << "'";
339
340 return success();
341}
342
343ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
344 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
345 Type redType;
346 Type resType;
347 CombiningKindAttr kindAttr;
348 if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
349 result.attributes) ||
350 parser.parseComma() || parser.parseOperandList(operandsInfo) ||
351 parser.parseColonType(redType) ||
352 parser.parseKeywordType("into", resType) ||
353 (!operandsInfo.empty() &&
354 parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
355 (operandsInfo.size() > 1 &&
356 parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
357 parser.addTypeToList(resType, result.types))
358 return failure();
359 if (operandsInfo.empty() || operandsInfo.size() > 2)
360 return parser.emitError(parser.getNameLoc(),
361 "unsupported number of operands");
362 return success();
363}
364
365void ReductionOp::print(OpAsmPrinter &p) {
366 p << " ";
367 getKindAttr().print(p);
368 p << ", " << getVector();
369 if (getAcc())
370 p << ", " << getAcc();
371 p << " : " << getVector().getType() << " into " << getDest().getType();
372}
373
374Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
375 OpBuilder &builder, Location loc,
376 Value vector) {
377 switch (op) {
378 case arith::AtomicRMWKind::addf:
379 case arith::AtomicRMWKind::addi:
380 return builder.create<vector::ReductionOp>(vector.getLoc(),
381 CombiningKind::ADD, vector);
382 case arith::AtomicRMWKind::mulf:
383 case arith::AtomicRMWKind::muli:
384 return builder.create<vector::ReductionOp>(vector.getLoc(),
385 CombiningKind::MUL, vector);
386 case arith::AtomicRMWKind::minf:
387 return builder.create<vector::ReductionOp>(vector.getLoc(),
388 CombiningKind::MINF, vector);
389 case arith::AtomicRMWKind::mins:
390 return builder.create<vector::ReductionOp>(vector.getLoc(),
391 CombiningKind::MINSI, vector);
392 case arith::AtomicRMWKind::minu:
393 return builder.create<vector::ReductionOp>(vector.getLoc(),
394 CombiningKind::MINUI, vector);
395 case arith::AtomicRMWKind::maxf:
396 return builder.create<vector::ReductionOp>(vector.getLoc(),
397 CombiningKind::MAXF, vector);
398 case arith::AtomicRMWKind::maxs:
399 return builder.create<vector::ReductionOp>(vector.getLoc(),
400 CombiningKind::MAXSI, vector);
401 case arith::AtomicRMWKind::maxu:
402 return builder.create<vector::ReductionOp>(vector.getLoc(),
403 CombiningKind::MAXUI, vector);
404 case arith::AtomicRMWKind::andi:
405 return builder.create<vector::ReductionOp>(vector.getLoc(),
406 CombiningKind::AND, vector);
407 case arith::AtomicRMWKind::ori:
408 return builder.create<vector::ReductionOp>(vector.getLoc(),
409 CombiningKind::OR, vector);
410 // TODO: Add remaining reduction operations.
411 default:
412 (void)emitOptionalError(loc, "Reduction operation type not supported");
413 break;
414 }
415 return nullptr;
416}
417
418Optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
419 return llvm::to_vector<4>(getVectorType().getShape());
420}
421
422namespace {
423struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
424 using OpRewritePattern::OpRewritePattern;
425
426 LogicalResult matchAndRewrite(ReductionOp reductionOp,
427 PatternRewriter &rewriter) const override {
428 if (reductionOp.getVectorType().getDimSize(0) != 1)
429 return failure();
430
431 Location loc = reductionOp.getLoc();
432 Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
433 reductionOp.getVector(),
434 rewriter.getI64ArrayAttr(0));
435
436 if (Value acc = reductionOp.getAcc())
437 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
438 result, acc);
439
440 rewriter.replaceOp(reductionOp, result);
441 return success();
442 }
443};
444} // namespace
445
446void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
447 MLIRContext *context) {
448 results.add<ElideSingleElementReduction>(context);
449}
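
The effect of ElideSingleElementReduction, sketched as illustrative before/after IR (hedged; the values and the use of arith.addf for the <add> kind on floats are made up for this example):

    // Before: a reduction over a single-element vector.
    //   %r = vector.reduction <add>, %v, %acc : vector<1xf32> into f32
    // After: the lone element is extracted and combined with the
    // accumulator via makeArithReduction.
    //   %e = vector.extract %v[0] : vector<1xf32>
    //   %r = arith.addf %e, %acc : f32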
450
451//===----------------------------------------------------------------------===//
452// ContractionOp
453//===----------------------------------------------------------------------===//
454
455void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
456 Value lhs, Value rhs, Value acc,
457 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
458 ArrayRef<IteratorType> iteratorTypes) {
459 result.addOperands({lhs, rhs, acc});
460 result.addTypes(acc.getType());
461 result.addAttribute(::mlir::getIndexingMapsAttrName(),
462 builder.getAffineMapArrayAttr(
463 AffineMap::inferFromExprList(indexingExprs)));
464 result.addAttribute(
465 ::mlir::getIteratorTypesAttrName(),
466 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
467 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
468 return IteratorTypeAttr::get(builder.getContext(), t);
469 }))));
470}
471
472void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
473 Value lhs, Value rhs, Value acc,
474 ArrayAttr indexingMaps,
475 ArrayAttr iteratorTypes) {
476 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
477 ContractionOp::getDefaultKind());
478}
479
480void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
481 Value lhs, Value rhs, Value acc,
482 ArrayAttr indexingMaps,
483 ArrayAttr iteratorTypes, CombiningKind kind) {
484 result.addOperands({lhs, rhs, acc});
485 result.addTypes(acc.getType());
486 result.addAttribute(::mlir::getIndexingMapsAttrName(), indexingMaps);
487 result.addAttribute(::mlir::getIteratorTypesAttrName(), iteratorTypes);
488 result.addAttribute(ContractionOp::getKindAttrStrName(),
489 CombiningKindAttr::get(builder.getContext(), kind));
490}
491
492ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
493 OpAsmParser::UnresolvedOperand lhsInfo;
494 OpAsmParser::UnresolvedOperand rhsInfo;
495 OpAsmParser::UnresolvedOperand accInfo;
496 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
497 SmallVector<Type, 2> types;
498 Type resultType;
499 auto loc = parser.getCurrentLocation();
500 DictionaryAttr dictAttr;
501 // TODO: Unify linalg op attribute parsing.
502 if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
503 parser.parseOperand(lhsInfo) || parser.parseComma() ||
504 parser.parseOperand(rhsInfo) || parser.parseComma() ||
505 parser.parseOperand(accInfo) ||
506 parser.parseTrailingOperandList(masksInfo) ||
507 parser.parseOptionalAttrDict(result.attributes) ||
508 parser.parseColonTypeList(types) ||
509 parser.parseKeywordType("into", resultType) ||
510 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
511 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
512 parser.resolveOperand(accInfo, resultType, result.operands) ||
513 parser.addTypeToList(resultType, result.types))
514 return failure();
515 result.attributes.assign(dictAttr.getValue().begin(),
516 dictAttr.getValue().end());
517
518 // Convert the array of strings into an array of IteratorType enums. This is
519 // needed because tests still use the old format, in which the
520 // 'iterator_types' attribute is represented as an array of strings.
521 // TODO: Remove this conversion once tests are fixed.
522 ArrayAttr iteratorTypes =
523 result.attributes.get("iterator_types").cast<ArrayAttr>();
524
525 SmallVector<Attribute> iteratorTypeAttrs;
526
527 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
528 auto maybeIteratorType = symbolizeIteratorType(s);
529 if (!maybeIteratorType.has_value())
530 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
531
532 iteratorTypeAttrs.push_back(IteratorTypeAttr::get(
533 parser.getContext(), maybeIteratorType.value()));
534 }
535 result.attributes.set("iterator_types",
536 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
537
538 if (!result.attributes.get(ContractionOp::getKindAttrStrName())) {
539 result.addAttribute(
540 ContractionOp::getKindAttrStrName(),
541 CombiningKindAttr::get(result.getContext(),
542 ContractionOp::getDefaultKind()));
543 }
544 if (masksInfo.empty())
545 return success();
546 if (masksInfo.size() != 2)
547 return parser.emitError(parser.getNameLoc(),
548 "expected zero or exactly 2 vector mask operands");
549 auto lhsType = types[0].cast<VectorType>();
550 auto rhsType = types[1].cast<VectorType>();
551 auto maskElementType = parser.getBuilder().getI1Type();
552 std::array<Type, 2> maskTypes = {
553 VectorType::Builder(lhsType).setElementType(maskElementType),
554 VectorType::Builder(rhsType).setElementType(maskElementType)};
555 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
556 return failure();
557 return success();
558}
559
560void ContractionOp::print(OpAsmPrinter &p) {
561 // TODO: Unify printing code with linalg ops.
562 auto attrNames = getTraitAttrNames();
563 llvm::StringSet<> traitAttrsSet;
564 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
565 SmallVector<NamedAttribute, 8> attrs;
566 for (auto attr : (*this)->getAttrs()) {
567 if (attr.getName() == getIteratorTypesAttrName()) {
568 auto iteratorTypes =
569 attr.getValue()
570 .cast<ArrayAttr>()
571 .getAsValueRange<IteratorTypeAttr, IteratorType>();
572 // Convert IteratorType enums into the string representation. This is
573 // needed, because tests still use the old format when 'iterator_types'
574 // attribute is represented as an array of strings.
575 // TODO: Remove this conversion once tests are fixed.
576 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
577 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
578 return StringAttr::get(getContext(), stringifyIteratorType(t));
579 }));
580
581 attrs.emplace_back(getIteratorTypesAttrName(),
582 ArrayAttr::get(getContext(), iteratorTypeNames));
583 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
584 attrs.push_back(attr);
585 }
586
587 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
588 p << " " << dictAttr << " " << getLhs() << ", ";
589 p << getRhs() << ", " << getAcc();
590 if (getMasks().size() == 2)
591 p << ", " << getMasks();
592
593 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
594 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
595 << getResultType();
596}
597
598static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
599 const std::vector<std::pair<int64_t, int64_t>> &map) {
600 for (auto &dimPair : map) {
601 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
602 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
603 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
604 return false;
605 }
606 return true;
607}
608
609static LogicalResult verifyOutputShape(
610 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
611 Type resType,
612 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
613 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
614 DenseSet<int64_t> lhsContractingDimSet;
615 DenseSet<int64_t> rhsContractingDimSet;
616 for (auto &dimPair : contractingDimMap) {
617 lhsContractingDimSet.insert(dimPair.first);
618 rhsContractingDimSet.insert(dimPair.second);
619 }
620 DenseSet<int64_t> rhsBatchDimSet;
621 for (auto &dimPair : batchDimMap)
622 rhsBatchDimSet.insert(dimPair.second);
623
624 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
625 SmallVector<int64_t, 4> expectedResultDims;
626 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
627 if (lhsContractingDimSet.count(i) > 0)
628 continue;
629 expectedResultDims.push_back(lhsType.getDimSize(i));
630 }
631
632 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
633 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
634 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
635 continue;
636 expectedResultDims.push_back(rhsType.getDimSize(i));
637 }
638
639 // Verify 'expectedResultDims'.
640 if (expectedResultDims.empty()) {
641 // No batch or free dimension implies a scalar result.
642 if (resType.isa<VectorType>() || accType.isa<VectorType>())
643 return op.emitOpError("invalid accumulator/result vector shape");
644 } else {
645 // At least one batch or free dimension implies a vector result.
646 auto resVectorType = resType.dyn_cast<VectorType>();
647 auto accVectorType = accType.dyn_cast<VectorType>();
648 if (!resVectorType || !accVectorType)
649 return op.emitOpError("invalid accumulator/result vector shape");
650
651 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
652 // types fully define the result vector type. This assumes the affine maps
653 // are well-formed, which must have been verified already.
654 MLIRContext *ctx = op.getContext();
655 AffineMap lhsMap = op.getIndexingMapsArray()[0];
656 AffineMap rhsMap = op.getIndexingMapsArray()[1];
657 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
658 return op.emitOpError(
659 "expected all dimensions to be either a LHS or a RHS dimension");
660 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
661 for (auto pair :
662 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
663 VectorType v = pair.first;
664 auto map = pair.second;
665 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
666 unsigned pos = map.getDimPosition(idx);
667 if (!extents[pos])
668 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
669 }
670 }
671 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
672 return op.emitOpError("expected all dimensions to get an extent as "
673 "either a LHS or a RHS dimension");
674
675 AffineMap resMap = op.getIndexingMapsArray()[2];
676 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
677 /*symCount=*/0, extents, ctx);
678 // Compose the resMap with the extentsMap, which is a constant map.
679 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
680 assert(llvm::all_of(
681            expectedMap.getResults(),
682            [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
683        "expected constant extent along all dimensions.");
684 // Extract the expected shape and build the type.
685 auto expectedShape = llvm::to_vector<4>(
686 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
687 return e.cast<AffineConstantExpr>().getValue();
688 }));
689 auto expected =
690 VectorType::get(expectedShape, resVectorType.getElementType());
691 if (resVectorType != expected || accVectorType != expected)
692 return op.emitOpError(
693 "invalid accumulator/result vector shape, expected: ")
694 << expected;
695 }
696 return success();
697}
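
As a worked instance of the extent inference above (hedged, standard matmul maps assumed): with indexing maps (m, n, k) -> (m, k), (m, n, k) -> (k, n), (m, n, k) -> (m, n), lhs vector<2x4xf32> and rhs vector<4x3xf32>, the extents become m = 2, n = 3, k = 4, and composing the result map with the constant extents yields the expected type vector<2x3xf32>. A standalone sketch of the same bookkeeping with plain index arrays instead of AffineMaps (hypothetical helper, no MLIR dependency):

    #include <cstdint>
    #include <vector>

    // Each "map" lists which iteration dimension every vector dimension
    // reads, mirroring a projected permutation such as (m, n, k) -> (m, k),
    // i.e. {0, 2}.
    static std::vector<int64_t>
    inferResultShape(const std::vector<unsigned> &lhsMap,  // e.g. {0, 2}
                     const std::vector<int64_t> &lhsShape, // e.g. {2, 4}
                     const std::vector<unsigned> &rhsMap,  // e.g. {2, 1}
                     const std::vector<int64_t> &rhsShape, // e.g. {4, 3}
                     const std::vector<unsigned> &resMap,  // e.g. {0, 1}
                     unsigned numDims) {                   // e.g. 3 (m, n, k)
      std::vector<int64_t> extents(numDims, 0);
      for (unsigned i = 0; i < lhsMap.size(); ++i)
        extents[lhsMap[i]] = lhsShape[i]; // m = 2, k = 4
      for (unsigned i = 0; i < rhsMap.size(); ++i)
        extents[rhsMap[i]] = rhsShape[i]; // k = 4 (again), n = 3
      std::vector<int64_t> shape;         // compose resMap with the extents
      for (unsigned pos : resMap)
        shape.push_back(extents[pos]);    // {2, 3} -> vector<2x3xf32>
      return shape;
    }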
698
699LogicalResult ContractionOp::verify() {
700 auto lhsType = getLhsType();
701 auto rhsType = getRhsType();
702 auto accType = getAccType();
703 auto resType = getResultType();
704
705 // Verify that an indexing map was specified for each vector operand.
706 if (getIndexingMapsArray().size() != 3)
707 return emitOpError("expected an indexing map for each vector operand");
708
709 // Verify that each index map has 'numIterators' inputs, no symbols, and
710 // that the number of map outputs equals the rank of its associated
711 // vector operand.
712 unsigned numIterators = getIteratorTypes().getValue().size();
713 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
714 auto index = it.index();
715 auto map = it.value();
716 if (map.getNumSymbols() != 0)
717 return emitOpError("expected indexing map ")
718 << index << " to have no symbols";
719 auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
720 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
721 // Verify that the map has the right number of inputs, outputs, and indices.
722 // This also correctly accounts for (..) -> () for rank-0 results.
723 if (map.getNumDims() != numIterators)
724 return emitOpError("expected indexing map ")
725 << index << " to have " << numIterators << " number of inputs";
726 if (map.getNumResults() != rank)
727 return emitOpError("expected indexing map ")
728 << index << " to have " << rank << " number of outputs";
729 if (!map.isProjectedPermutation())
730 return emitOpError("expected indexing map ")
731 << index << " to be a projected permutation of its inputs";
732 }
733
734 auto contractingDimMap = getContractingDimMap();
735 auto batchDimMap = getBatchDimMap();
736
737 // Verify at least one contracting dimension pair was specified.
738 if (contractingDimMap.empty())
739 return emitOpError("expected at least one contracting dimension pair");
740
741 // Verify contracting dimension map was properly constructed.
742 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
743 return emitOpError("invalid contracting dimension map");
744
745 // Verify batch dimension map was properly constructed.
746 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
747 return emitOpError("invalid batch dimension map");
748
749 // Verify 'accType' and 'resType' shape.
750 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
751 contractingDimMap, batchDimMap)))
752 return failure();
753
754 // Verify that either two vector masks are set or none are set.
755 auto lhsMaskType = getLHSVectorMaskType();
756 auto rhsMaskType = getRHSVectorMaskType();
757 if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
758 return emitOpError("invalid number of vector masks specified");
759 if (lhsMaskType && rhsMaskType) {
760 // Verify mask rank == argument rank.
761 if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
762 rhsMaskType.getShape().size() != rhsType.getShape().size())
763 return emitOpError("invalid vector mask rank");
764 }
765
766 // Verify supported combining kind.
767 auto vectorType = resType.dyn_cast<VectorType>();
768 auto elementType = vectorType ? vectorType.getElementType() : resType;
769 if (!isSupportedCombiningKind(getKind(), elementType))
770 return emitOpError("unsupported contraction type");
771
772 return success();
773}
774
775ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
776 static constexpr StringRef names[3] = {::mlir::getIndexingMapsAttrName(),
777 ::mlir::getIteratorTypesAttrName(),
778 ContractionOp::getKindAttrStrName()};
779 return llvm::makeArrayRef(names);
780}
781
782static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
783 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
784 if (targetExpr == map.getResult(i))
785 return i;
786 return -1;
787}
788
789static std::vector<std::pair<int64_t, int64_t>>
790getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
791 IteratorType targetIteratorType, MLIRContext *context) {
792 std::vector<std::pair<int64_t, int64_t>> dimMap;
793 for (const auto &it : llvm::enumerate(iteratorTypes)) {
794 auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
795 if (iteratorType != targetIteratorType)
796 continue;
797 // Search lhs/rhs map results for 'targetExpr'.
798 auto targetExpr = getAffineDimExpr(it.index(), context);
799 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
800 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
801 if (lhsDim >= 0 && rhsDim >= 0)
802 dimMap.emplace_back(lhsDim, rhsDim);
803 }
804 return dimMap;
805}
806
807void ContractionOp::getIterationBounds(
808 SmallVectorImpl<int64_t> &iterationBounds) {
809 auto lhsShape = getLhsType().getShape();
810 auto resVectorType = getResultType().dyn_cast<VectorType>();
811 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
812 SmallVector<int64_t, 2> iterationShape;
813 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
814 // Search lhs/rhs map results for 'targetExpr'.
815 auto targetExpr = getAffineDimExpr(it.index(), getContext());
816 auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
817 if (iteratorType == IteratorType::reduction) {
818 // Get reduction dim size from lhs shape (same size in rhsShape).
819 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
820 assert(lhsDimIndex >= 0);
821 iterationBounds.push_back(lhsShape[lhsDimIndex]);
822 continue;
823 }
824 // Get parallel dimension size from result shape.
825 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
826 assert(resDimIndex >= 0);
827 assert(resVectorType != nullptr);
828 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
829 }
830}
831
832void ContractionOp::getIterationIndexMap(
833 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
834 unsigned numMaps = getIndexingMapsArray().size();
835 iterationIndexMap.resize(numMaps);
836 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
837 auto index = it.index();
838 auto map = it.value();
839 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
840 auto dim = map.getResult(i).cast<AffineDimExpr>();
841 iterationIndexMap[index][dim.getPosition()] = i;
842 }
843 }
844}
845
846std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
847 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
848 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
849 getContext());
850}
851
852std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
853 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
854 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
855 getContext());
856}
857
858Optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
859 SmallVector<int64_t, 4> shape;
860 getIterationBounds(shape);
861 return shape;
862}
863
864/// Return a fused vector::ContractionOp which represents a pattern such as:
865///
866/// ```mlir
867/// %c0 = vector.constant 0: ...
868/// %c = vector.contract %a, %b, %c0: ...
869/// %e = add %c, %d: ...
870/// ```
871///
872/// by:
873///
874/// ```mlir
875/// %e = vector.contract %a, %b, %d: ...
876/// ```
877///
878/// Return null if the canonicalization does not apply.
879// TODO: This should be a folding of Add into Contract in core but while they
880// live in different dialects, it is not possible without unnatural
881// dependencies.
882template <typename AddOpType>
883struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
884 using OpRewritePattern<AddOpType>::OpRewritePattern;
885
886 LogicalResult matchAndRewrite(AddOpType addOp,
887 PatternRewriter &rewriter) const override {
888 auto canonicalize = [&](Value maybeContraction,
889 Value otherOperand) -> vector::ContractionOp {
890 vector::ContractionOp contractionOp =
891 dyn_cast_or_null<vector::ContractionOp>(
892 maybeContraction.getDefiningOp());
893 if (!contractionOp)
894 return vector::ContractionOp();
895 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
896 contractionOp.getAcc().getDefiningOp())) {
897 if (maybeZero.getValue() ==
898 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
899 BlockAndValueMapping bvm;
900 bvm.map(contractionOp.getAcc(), otherOperand);
901 auto newContraction =
902 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
903 rewriter.replaceOp(addOp, newContraction.getResult());
904 return newContraction;
905 }
906 }
907 return vector::ContractionOp();
908 };
909
910 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
911 vector::ContractionOp contract = canonicalize(a, b);
912 contract = contract ? contract : canonicalize(b, a);
913 return contract ? success() : failure();
914 }
915};
916
917void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
918 MLIRContext *context) {
919 results.add<CanonicalizeContractAdd<arith::AddIOp>,
920 CanonicalizeContractAdd<arith::AddFOp>>(context);
921}
922
923//===----------------------------------------------------------------------===//
924// ExtractElementOp
925//===----------------------------------------------------------------------===//
926
927void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
928 Value source) {
929 result.addOperands({source});
930 result.addTypes(source.getType().cast<VectorType>().getElementType());
931}
932
933void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
934 Value source, Value position) {
935 result.addOperands({source, position});
936 result.addTypes(source.getType().cast<VectorType>().getElementType());
937}
938
939LogicalResult vector::ExtractElementOp::verify() {
940 VectorType vectorType = getVectorType();
941 if (vectorType.getRank() == 0) {
942 if (getPosition())
943 return emitOpError("expected position to be empty with 0-D vector");
944 return success();
945 }
946 if (vectorType.getRank() != 1)
947 return emitOpError("unexpected >1 vector rank");
948 if (!getPosition())
949 return emitOpError("expected position for 1-D vector");
950 return success();
951}
952
953OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
954 // Skip the 0-D vector case for now.
955 if (operands.size() < 2)
956 return {};
957
958 Attribute src = operands[0];
959 Attribute pos = operands[1];
960
961 // Fold extractelement (splat X) -> X.
962 if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
963 return splat.getInput();
964
965 if (!pos || !src)
966 return {};
967
968 auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
969
970 auto attr = pos.dyn_cast<IntegerAttr>();
971 uint64_t posIdx = attr.getInt();
972
973 return srcElements[posIdx];
974}
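
An illustrative constant fold (hedged; the values are made up):

    // %v = arith.constant dense<[10, 20, 30]> : vector<3xi32>
    // %i = arith.constant 1 : i32
    // %e = vector.extractelement %v[%i : i32] : vector<3xi32>
    // folds to the attribute for 20, i.e. srcElements[1].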
975
976//===----------------------------------------------------------------------===//
977// ExtractOp
978//===----------------------------------------------------------------------===//
979
980void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
981 Value source, ArrayRef<int64_t> position) {
982 build(builder, result, source, getVectorSubscriptAttr(builder, position));
983}
984
985// Convenience builder which assumes the values are constant indices.
986void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
987 Value source, ValueRange position) {
988 SmallVector<int64_t, 4> positionConstants =
989 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
990 return pos.getDefiningOp<arith::ConstantIndexOp>().value();
991 }));
992 build(builder, result, source, positionConstants);
993}
994
995LogicalResult
996ExtractOp::inferReturnTypes(MLIRContext *, Optional<Location>,
997 ValueRange operands, DictionaryAttr attributes,
998 RegionRange,
999 SmallVectorImpl<Type> &inferredReturnTypes) {
1000 ExtractOp::Adaptor op(operands, attributes);
1001 auto vectorType = op.getVector().getType().cast<VectorType>();
1002 if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
1003 inferredReturnTypes.push_back(vectorType.getElementType());
1004 } else {
1005 auto n =
1006 std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
1007 inferredReturnTypes.push_back(VectorType::get(
1008 vectorType.getShape().drop_front(n), vectorType.getElementType()));
1009 }
1010 return success();
1011}
1012
1013bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1014 // Allow extracting 1-element vectors instead of scalars.
1015 auto isCompatible = [](TypeRange l, TypeRange r) {
1016 auto vectorType = l.front().dyn_cast<VectorType>();
1017 return vectorType && vectorType.getShape().equals({1}) &&
1018 vectorType.getElementType() == r.front();
1019 };
1020 if (l.size() == 1 && r.size() == 1 &&
1021 (isCompatible(l, r) || isCompatible(r, l)))
1022 return true;
1023 return l == r;
1024}
1025
1026LogicalResult vector::ExtractOp::verify() {
1027 auto positionAttr = getPosition().getValue();
1028 if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
1029 return emitOpError(
1030 "expected position attribute of rank smaller than vector rank");
1031 for (const auto &en : llvm::enumerate(positionAttr)) {
1032 auto attr = en.value().dyn_cast<IntegerAttr>();
1033 if (!attr || attr.getInt() < 0 ||
1034 attr.getInt() >= getVectorType().getDimSize(en.index()))
1035 return emitOpError("expected position attribute #")
1036 << (en.index() + 1)
1037 << " to be a non-negative integer smaller than the corresponding "
1038 "vector dimension";
1039 }
1040 return success();
1041}
1042
1043template <typename IntType>
1044static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1045 return llvm::to_vector<4>(llvm::map_range(
1046 arrayAttr.getAsRange<IntegerAttr>(),
1047 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1048}
1049
1050/// Fold the result of chains of ExtractOp in place by simply concatenating the
1051/// positions.
1052static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1053 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1054 return failure();
1055
1056 SmallVector<int64_t, 4> globalPosition;
1057 ExtractOp currentOp = extractOp;
1058 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1059 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1060 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1061 currentOp = nextOp;
1062 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1063 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1064 }
1065 extractOp.setOperand(currentOp.getVector());
1066 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1067 OpBuilder b(extractOp.getContext());
1068 std::reverse(globalPosition.begin(), globalPosition.end());
1069 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1070 b.getI64ArrayAttr(globalPosition));
1071 return success();
1072}
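
The reverse-append bookkeeping above concatenates inner positions before outer ones; a minimal standalone sketch (plain C++, hypothetical helper):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // extract %v[outer] of (extract %w[inner]) folds to
    // extract %w[inner ++ outer]. Appending each position reversed and
    // reversing once at the end yields the same concatenation,
    // e.g. outer {1} over inner {2, 3} gives {2, 3, 1}.
    static std::vector<int64_t>
    concatPositions(const std::vector<int64_t> &outer,
                    const std::vector<int64_t> &inner) {
      std::vector<int64_t> global(outer.rbegin(), outer.rend());
      global.insert(global.end(), inner.rbegin(), inner.rend());
      std::reverse(global.begin(), global.end());
      return global; // inner followed by outer
    }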
1073
1074namespace {
1075/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1076/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1077/// Compose TransposeOp permutations as we walk back.
1078/// This helper class keeps an updated extraction position `extractPosition`
1079/// with extra trailing sentinels.
1080/// The sentinels encode the internal transposition status of the result vector.
1081/// As we iterate, extractPosition is permuted and updated.
1082class ExtractFromInsertTransposeChainState {
1083public:
1084 ExtractFromInsertTransposeChainState(ExtractOp e);
1085
1086 /// Iterate over producing insert and transpose ops until we find a fold.
1087 Value fold();
1088
1089private:
1090 /// Return true if the vector at position `a` is contained within the vector
1091 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1092 /// is a prefix of `b`.
1093 template <typename ContainerA, typename ContainerB>
1094 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1095 return a.size() <= b.size() &&
1096 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1097 }
1098
1099 /// Return true if the vector at position `a` intersects the vector at
1100 /// position `b`. Under insert/extract semantics, this is the same as equality
1101 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1102 /// Comparison is on the common prefix (i.e. zip).
1103 template <typename ContainerA, typename ContainerB>
1104 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1105 for (auto it : llvm::zip(a, b)) {
1106 if (std::get<0>(it) < 0 || std::get<1>(it) < 0)
1107 continue;
1108 if (std::get<0>(it) != std::get<1>(it))
1109 return false;
1110 }
1111 return true;
1112 }
1113
1114 /// Folding is only possible in the absence of an internal permutation in the
1115 /// result vector.
1116 bool canFold() {
1117 return (sentinels ==
1118 makeArrayRef(extractPosition).drop_front(extractedRank));
1119 }
1120
1121 // Helper to get the next defining op of interest.
1122 void updateStateForNextIteration(Value v) {
1123 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1124 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1125 };
1126
1127 // Case 1. If we hit a transpose, just compose the map and iterate.
1128 // Invariant: insert + transpose do not change rank, we can always compose.
1129 LogicalResult handleTransposeOp();
1130
1131 // Case 2: the insert position matches extractPosition exactly, early return.
1132 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1133
1134 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1135 /// portion of the source of the insert.
1136 /// Example:
1137 /// ```
1138 /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1139 /// // extractPosition == [1, 2, 3]
1140 /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
1141 /// // can fold to vector.extract %source[0, 3]
1142 /// %ext = vector.extract %source[3]: vector<5x6>
1143 /// ```
1144 /// To traverse through %source, we need to set the leading dims to 0 and
1145 /// drop the extra leading dims.
1146 /// This method updates the internal state.
1147 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1148
1149 /// Try to fold in place to extract(source, extractPosition) and return the
1150 /// folded result. Return null if folding is not possible (e.g. due to an
1151 /// internal transposition in the result).
1152 Value tryToFoldExtractOpInPlace(Value source);
1153
1154 ExtractOp extractOp;
1155 int64_t vectorRank;
1156 int64_t extractedRank;
1157
1158 InsertOp nextInsertOp;
1159 TransposeOp nextTransposeOp;
1160
1161 /// Sentinel values that encode the internal permutation status of the result.
1162 /// They are set to (-1, ... , -k) at the beginning and appended to
1163 /// `extractPosition`.
1164 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1165 /// ensure that there is no internal transposition.
1166 /// Internal transposition cannot be accounted for with a folding pattern.
1167 // TODO: We could relax the internal transposition with an extra transposition
1168 // operation in a future canonicalizer.
1169 SmallVector<int64_t> sentinels;
1170 SmallVector<int64_t> extractPosition;
1171};
1172} // namespace
1173
1174ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1175 ExtractOp e)
1176 : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
1177 extractedRank(extractOp.getPosition().size()) {
1178 assert(vectorRank >= extractedRank && "extracted pos overflow");
1179 sentinels.reserve(vectorRank - extractedRank);
1180 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1181 sentinels.push_back(-(i + 1));
1182 extractPosition = extractVector<int64_t>(extractOp.getPosition());
1183 llvm::append_range(extractPosition, sentinels);
1184}
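
A concrete instance of this setup (hedged, values made up): for %e = vector.extract %v[0, 1] : vector<2x3x4x5xf32>, vectorRank is 4 and extractedRank is 2, so sentinels becomes [-1, -2] and extractPosition starts as [0, 1, -1, -2]; a fold is accepted later only if the tail of extractPosition still reads exactly [-1, -2] (see canFold above), i.e. no internal transposition remains.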
1185
1186// Case 1. If we hit a transpose, just compose the map and iterate.
1187// Invariant: insert + transpose do not change rank, we can always compose.
1188LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1189 if (!nextTransposeOp)
1190 return failure();
1191 auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
1192 AffineMap m = inversePermutation(
1193 AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1194 extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
1195 return success();
1196}
1197
1198// Case 2: the insert position matches extractPosition exactly, early return.
1199LogicalResult
1200ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1201 Value &res) {
1202 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1203 if (makeArrayRef(insertedPos) !=
1204 llvm::makeArrayRef(extractPosition).take_front(extractedRank))
1205 return failure();
1206 // Case 2.a. early-exit fold.
1207 res = nextInsertOp.getSource();
1208 // Case 2.b. if internal transposition is present, canFold will be false.
1209 return success();
1210}
1211
1212/// Case 3: if inserted position is a prefix of extractPosition,
1213/// extract a portion of the source of the insertion.
1214/// This method updates the internal state.
1215LogicalResult
1216ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1217 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1218 if (!isContainedWithin(insertedPos, extractPosition))
1219 return failure();
1220 // Set leading dims to zero.
1221 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1222 // Drop extra leading dims.
1223 extractPosition.erase(extractPosition.begin(),
1224 extractPosition.begin() + insertedPos.size());
1225 extractedRank = extractPosition.size() - sentinels.size();
1226 // Case 3.a. early-exit fold (break and delegate to post-while path).
1227 res = nextInsertOp.getSource();
1228 // Case 3.b. if internal transposition is present, canFold will be false.
1229 return success();
1230}
1231
1232/// Try to fold in place to extract(source, extractPosition) and return the
1233/// folded result. Return null if folding is not possible (e.g. due to an
1234 /// internal transposition in the result).
1235Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1236 Value source) {
1237 // If we can't fold (either internal transposition, or nothing to fold), bail.
1238 bool nothingToFold = (source == extractOp.getVector());
1239 if (nothingToFold || !canFold())
1240 return Value();
1241 // Otherwise, fold by updating the op inplace and return its result.
1242 OpBuilder b(extractOp.getContext());
1243 extractOp->setAttr(
1244 extractOp.getPositionAttrName(),
1245 b.getI64ArrayAttr(
1246 makeArrayRef(extractPosition).take_front(extractedRank)));
1247 extractOp.getVectorMutable().assign(source);
1248 return extractOp.getResult();
1249}
1250
1251/// Iterate over producing insert and transpose ops until we find a fold.
1252Value ExtractFromInsertTransposeChainState::fold() {
1253 Value valueToExtractFrom = extractOp.getVector();
1254 updateStateForNextIteration(valueToExtractFrom);
1255 while (nextInsertOp || nextTransposeOp) {
1256 // Case 1. If we hit a transpose, just compose the map and iterate.
1257 // Invariant: insert + transpose do not change rank, we can always compose.
1258 if (succeeded(handleTransposeOp())) {
1259 valueToExtractFrom = nextTransposeOp.getVector();
1260 updateStateForNextIteration(valueToExtractFrom);
1261 continue;
1262 }
1263
1264 Value result;
1265 // Case 2: the positions match exactly.
1266 if (succeeded(handleInsertOpWithMatchingPos(result)))
1267 return result;
1268
1269 // Case 3: if the inserted position is a prefix of extractPosition, we can
1270 // just extract a portion of the source of the insert.
1271 if (succeeded(handleInsertOpWithPrefixPos(result)))
1272 return tryToFoldExtractOpInPlace(result);
1273
1274 // Case 4: extractPositionRef intersects insertedPosRef on non-sentinel
1275 // values. This is a more difficult case and we bail.
1276 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1277 if (isContainedWithin(extractPosition, insertedPos) ||
1278 intersectsWhereNonNegative(extractPosition, insertedPos))
1279 return Value();
1280
1281 // Case 5: No intersection, we forward the extract to insertOp.dest().
1282 valueToExtractFrom = nextInsertOp.getDest();
1283 updateStateForNextIteration(valueToExtractFrom);
1284 }
1285 // If after all this we can fold, go for it.
1286 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1287}
1288
1289/// Fold extractOp with scalar result coming from BroadcastOp or SplatOp.
1290static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1291 Operation *defOp = extractOp.getVector().getDefiningOp();
1292 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1293 return Value();
1294 Value source = defOp->getOperand(0);
1295 if (extractOp.getType() == source.getType())
1296 return source;
1297 auto getRank = [](Type type) {
1298 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1299 };
1300 unsigned broadcastSrcRank = getRank(source.getType());
1301 unsigned extractResultRank = getRank(extractOp.getType());
1302 if (extractResultRank >= broadcastSrcRank)
1303 return Value();
1304 // Check that the dimensions of the result haven't been broadcasted.
1305 auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
1306 auto broadcastVecType = source.getType().dyn_cast<VectorType>();
1307 if (extractVecType && broadcastVecType &&
1308 extractVecType.getShape() !=
1309 broadcastVecType.getShape().take_back(extractResultRank))
1310 return Value();
1311 auto extractPos = extractVector<int64_t>(extractOp.getPosition());
1312 unsigned rankDiff = broadcastSrcRank - extractResultRank;
1313 extractPos.erase(extractPos.begin(),
1314 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1315 extractOp.setOperand(source);
1316 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1317 OpBuilder b(extractOp.getContext());
1318 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1319 b.getI64ArrayAttr(extractPos));
1320 return extractOp.getResult();
1321}
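// Editor's sketch (illustrative, not from the original source): when the
// extract only reads broadcast-duplicated leading dimensions, it can index
// into the broadcast source directly. Hypothetical IR:
//   %b = vector.broadcast %src : vector<4x8xf32> to vector<2x4x8xf32>
//   %e = vector.extract %b[1, 2, 3] : vector<2x4x8xf32>
// is rewritten in place to:
//   %e = vector.extract %src[2, 3] : vector<4x8xf32>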
1322
1323// Fold extractOp with source coming from ShapeCast op.
1324static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1325 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1326 if (!shapeCastOp)
1327 return Value();
1328 // Get the nth dimension size starting from lowest dimension.
1329 auto getDimReverse = [](VectorType type, int64_t n) {
1330 return type.getShape().take_back(n + 1).front();
1331 };
1332 int64_t destinationRank =
1333 extractOp.getType().isa<VectorType>()
1334 ? extractOp.getType().cast<VectorType>().getRank()
1335 : 0;
1336 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1337 return Value();
1338 if (destinationRank > 0) {
1339 auto destinationType = extractOp.getResult().getType().cast<VectorType>();
1340 for (int64_t i = 0; i < destinationRank; i++) {
1341 // The lowest dimensions of the destination must match the lowest
1342 // dimensions of the shapecast op source.
1343 // TODO: This case could be supported in a canonicalization pattern.
1344 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1345 getDimReverse(destinationType, i))
1346 return Value();
1347 }
1348 }
1349 // Extract the strides associated with the extract op vector source. Then use
1350 // this to calculate a linearized position for the extract.
1351 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1352 std::reverse(extractedPos.begin(), extractedPos.end());
1353 SmallVector<int64_t, 4> strides;
1354 int64_t stride = 1;
1355 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1356 strides.push_back(stride);
1357 stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
1358 }
1359
1360 int64_t position = linearize(extractedPos, strides);
1361 // Then extract the strides associated with the shapeCast op vector source and
1362 // delinearize the position using those strides.
1363 SmallVector<int64_t, 4> newStrides;
1364 int64_t numDimension =
1365 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1366 stride = 1;
1367 for (int64_t i = 0; i < numDimension; i++) {
1368 newStrides.push_back(stride);
1369 stride *=
1370 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1371 }
1372 std::reverse(newStrides.begin(), newStrides.end());
1373 SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
1374 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1375 OpBuilder b(extractOp.getContext());
1376 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1377 b.getI64ArrayAttr(newPosition));
1378 extractOp.setOperand(shapeCastOp.getSource());
1379 return extractOp.getResult();
1380}
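// Editor's sketch (illustrative, not from the original source): the extract
// position is linearized with the strides of the shape_cast result, then
// delinearized with the strides of its source; e.g. position [5] linearizes
// to 5, which delinearizes to [1, 1] for a 4x4 source. Hypothetical IR:
//   %c = vector.shape_cast %src : vector<4x4xf32> to vector<16xf32>
//   %e = vector.extract %c[5] : vector<16xf32>
// is equivalent to:
//   %e = vector.extract %src[1, 1] : vector<4x4xf32>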
1381
1382/// Fold an ExtractOp from ExtractStridedSliceOp.
1383static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1384 auto extractStridedSliceOp =
1385 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1386 if (!extractStridedSliceOp)
1387 return Value();
1388 // Return if 'extractStridedSliceOp' has non-unit strides.
1389 if (extractStridedSliceOp.hasNonUnitStrides())
1390 return Value();
1391
1392 // Trim offsets for dimensions fully extracted.
1393 auto sliceOffsets =
1394 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1395 while (!sliceOffsets.empty()) {
1396 size_t lastOffset = sliceOffsets.size() - 1;
1397 if (sliceOffsets.back() != 0 ||
1398 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1399 extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
1400 break;
1401 sliceOffsets.pop_back();
1402 }
1403 unsigned destinationRank = 0;
1404 if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
1405 destinationRank = vecType.getRank();
1406 // The dimensions of the result need to be untouched by the
1407 // extractStridedSlice op.
1408 if (destinationRank >
1409 extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
1410 return Value();
1411 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1412 assert(extractedPos.size() >= sliceOffsets.size());
1413 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1414 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1415 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1416 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1417 OpBuilder b(extractOp.getContext());
1418 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1419 b.getI64ArrayAttr(extractedPos));
1420 return extractOp.getResult();
1421}
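// Editor's sketch (illustrative, not from the original source): the slice
// offsets are added to the extract position so the extract can read from the
// slice's operand. Hypothetical IR:
//   %s = vector.extract_strided_slice %v
//          {offsets = [2], sizes = [4], strides = [1]}
//          : vector<8x16xf32> to vector<4x16xf32>
//   %e = vector.extract %s[1] : vector<4x16xf32>
// is equivalent to:
//   %e = vector.extract %v[3] : vector<8x16xf32>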
1422
1423/// Fold extract_op fed from a chain of insertStridedSlice ops.
1424static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
1425 int64_t destinationRank = op.getType().isa<VectorType>()
1426 ? op.getType().cast<VectorType>().getRank()
1427 : 0;
1428 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
1429 while (insertOp) {
1430 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1431 insertOp.getSourceVectorType().getRank();
1432 if (destinationRank > insertOp.getSourceVectorType().getRank())
1433 return Value();
1434 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1435 auto extractOffsets = extractVector<int64_t>(op.getPosition());
1436
1437 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1438 return attr.cast<IntegerAttr>().getInt() != 1;
1439 }))
1440 return Value();
1441 bool disjoint = false;
1442 SmallVector<int64_t, 4> offsetDiffs;
1443 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1444 int64_t start = insertOffsets[dim];
1445 int64_t size =
1446 (dim < insertRankDiff)
1447 ? 1
1448 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1449 int64_t end = start + size;
1450 int64_t offset = extractOffsets[dim];
1451 // Check if the start of the extract offset is in the interval inserted.
1452 if (start <= offset && offset < end) {
1453 if (dim >= insertRankDiff)
1454 offsetDiffs.push_back(offset - start);
1455 continue;
1456 }
1457 disjoint = true;
1458 break;
1459 }
1460 // The extracted element chunk overlaps with the inserted vector.
1461 if (!disjoint) {
1462 // If any of the inner dimensions are only partially inserted we have a
1463 // partial overlap.
1464 int64_t srcRankDiff =
1465 insertOp.getSourceVectorType().getRank() - destinationRank;
1466 for (int64_t i = 0; i < destinationRank; i++) {
1467 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1468 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1469 insertRankDiff))
1470 return Value();
1471 }
1472 op.getVectorMutable().assign(insertOp.getSource());
1473 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1474 OpBuilder b(op.getContext());
1475 op->setAttr(ExtractOp::getPositionAttrStrName(),
1476 b.getI64ArrayAttr(offsetDiffs));
1477 return op.getResult();
1478 }
1479 // If the chunk extracted is disjoint from the chunk inserted, keep
1480 // looking in the insert chain.
1481 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1482 }
1483 return Value();
1484}
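// Editor's sketch (illustrative, not from the original source): when the
// extracted element lies entirely inside the inserted chunk, the extract
// reads from the inserted source at offsets rebased against the insertion
// point. Hypothetical IR:
//   %i = vector.insert_strided_slice %src, %dst
//          {offsets = [2, 4], strides = [1, 1]}
//          : vector<2x4xf32> into vector<8x16xf32>
//   %e = vector.extract %i[3, 5] : vector<8x16xf32>
// is equivalent to:
//   %e = vector.extract %src[1, 1] : vector<2x4xf32>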
1485
1486OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1487 if (getPosition().empty())
1488 return getVector();
1489 if (succeeded(foldExtractOpFromExtractChain(*this)))
1490 return getResult();
1491 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1492 return res;
1493 if (auto res = foldExtractFromBroadcast(*this))
1494 return res;
1495 if (auto res = foldExtractFromShapeCast(*this))
1496 return res;
1497 if (auto val = foldExtractFromExtractStrided(*this))
1498 return val;
1499 if (auto val = foldExtractStridedOpFromInsertChain(*this))
1500 return val;
1501 return OpFoldResult();
1502}
1503
1504namespace {
1505
1506// Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1507class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1508public:
1509 using OpRewritePattern::OpRewritePattern;
1510
1511 LogicalResult matchAndRewrite(ExtractOp extractOp,
1512 PatternRewriter &rewriter) const override {
1513 Operation *defOp = extractOp.getVector().getDefiningOp();
1514 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1515 return failure();
1516
1517 Value source = defOp->getOperand(0);
1518 if (extractOp.getType() == source.getType())
1519 return failure();
1520 auto getRank = [](Type type) {
1521 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1522 };
1523 unsigned broadcastSrcRank = getRank(source.getType());
1524 unsigned extractResultRank = getRank(extractOp.getType());
1525 // We only consider the case where the rank of the source is less than or
1526 // equal to the rank of the extract dst. The other cases are handled in the
1527 // folding patterns.
1528 if (extractResultRank < broadcastSrcRank)
1529 return failure();
1530 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1531 extractOp, extractOp.getType(), source);
1532 return success();
1533 }
1534};
1535
1536// Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1537class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1538public:
1539 using OpRewritePattern::OpRewritePattern;
1540
1541 LogicalResult matchAndRewrite(ExtractOp extractOp,
1542 PatternRewriter &rewriter) const override {
1543 // Return if 'ExtractOp' operand is not defined by a splat vector
1544 // ConstantOp.
1545 Value sourceVector = extractOp.getVector();
1546 Attribute vectorCst;
1547 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1548 return failure();
1549 auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
1550 if (!splat)
1551 return failure();
1552 Attribute newAttr = splat.getSplatValue<Attribute>();
1553 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1554 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1555 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1556 return success();
1557 }
1558};
1559
1560// Pattern to rewrite an ExtractOp(vector<...xT> ConstantOp)[...] -> ConstantOp,
1561// where the position array specifies a scalar element.
1562class ExtractOpScalarVectorConstantFolder final
1563 : public OpRewritePattern<ExtractOp> {
1564public:
1565 using OpRewritePattern::OpRewritePattern;
1566
1567 LogicalResult matchAndRewrite(ExtractOp extractOp,
1568 PatternRewriter &rewriter) const override {
1569 // Return if 'ExtractOp' operand is not defined by a compatible vector
1570 // ConstantOp.
1571 Value sourceVector = extractOp.getVector();
1572 Attribute vectorCst;
1573 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1574 return failure();
1575
1576 auto vecTy = sourceVector.getType().cast<VectorType>();
1577 Type elemTy = vecTy.getElementType();
1578 ArrayAttr positions = extractOp.getPosition();
1579 if (vecTy.isScalable())
1580 return failure();
1581 // Do not allow extracting sub-vectors to limit the size of the generated
1582 // constants.
1583 if (vecTy.getRank() != static_cast<int64_t>(positions.size()))
1584 return failure();
1585 // TODO: Handle more element types, e.g., complex values.
1586 if (!elemTy.isIntOrIndexOrFloat())
1587 return failure();
1588
1589 // The splat case is handled by `ExtractOpSplatConstantFolder`.
1590 auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
1591 if (!dense || dense.isSplat())
1592 return failure();
1593
1594 // Calculate the flattened position.
1595 int64_t elemPosition = 0;
1596 int64_t innerElems = 1;
1597 for (auto [dimSize, positionInDim] :
1598 llvm::reverse(llvm::zip(vecTy.getShape(), positions))) {
1599 int64_t positionVal = positionInDim.cast<IntegerAttr>().getInt();
1600 elemPosition += positionVal * innerElems;
1601 innerElems *= dimSize;
1602 }
1603
1604 Attribute newAttr;
1605 if (vecTy.getElementType().isIntOrIndex()) {
1606 auto values = to_vector(dense.getValues<APInt>());
1607 newAttr = IntegerAttr::get(extractOp.getType(), values[elemPosition]);
1608 } else if (vecTy.getElementType().isa<FloatType>()) {
1609 auto values = to_vector(dense.getValues<APFloat>());
1610 newAttr = FloatAttr::get(extractOp.getType(), values[elemPosition]);
1611 }
1612 assert(newAttr && "Unhandled case");
1613
1614 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1615 return success();
1616 }
1617};
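// Editor's sketch (illustrative, hypothetical values): for a non-splat
// vector<2x3xi32> constant and position [1, 2], the flattened index is
// 1 * 3 + 2 = 5, so the pattern materializes element #5 of the dense
// attribute as a scalar arith.constant.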
1618
1619} // namespace
1620
1621void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1622 MLIRContext *context) {
1623 results.add<ExtractOpSplatConstantFolder, ExtractOpScalarVectorConstantFolder,
1624 ExtractOpFromBroadcast>(context);
1625}
1626
1627static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1628 SmallVectorImpl<int64_t> &results) {
1629 for (auto attr : arrayAttr)
1630 results.push_back(attr.cast<IntegerAttr>().getInt());
1631}
1632
1633//===----------------------------------------------------------------------===//
1634// FMAOp
1635//===----------------------------------------------------------------------===//
1636
1637Optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1638 return llvm::to_vector<4>(getVectorType().getShape());
1639}
1640
1641//===----------------------------------------------------------------------===//
1642// BroadcastOp
1643//===----------------------------------------------------------------------===//
1644
1645BroadcastableToResult
1646mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1647 std::pair<int, int> *mismatchingDims) {
1648 // Broadcast scalar to vector of the same element type.
1649 if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1650 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1651 return BroadcastableToResult::Success;
1652 // From now on, only vectors broadcast.
1653 VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1654 if (!srcVectorType)
1655 return BroadcastableToResult::SourceTypeNotAVector;
1656
1657 int64_t srcRank = srcVectorType.getRank();
1658 int64_t dstRank = dstVectorType.getRank();
1659 if (srcRank > dstRank)
1660 return BroadcastableToResult::SourceRankHigher;
1661 // Source has an exact match or singleton value for all trailing dimensions
1662 // (all leading dimensions are simply duplicated).
1663 int64_t lead = dstRank - srcRank;
1664 for (int64_t r = 0; r < srcRank; ++r) {
1665 int64_t srcDim = srcVectorType.getDimSize(r);
1666 int64_t dstDim = dstVectorType.getDimSize(lead + r);
1667 if (srcDim != 1 && srcDim != dstDim) {
1668 if (mismatchingDims) {
1669 mismatchingDims->first = srcDim;
1670 mismatchingDims->second = dstDim;
1671 }
1672 return BroadcastableToResult::DimensionMismatch;
1673 }
1674 }
1675
1676 return BroadcastableToResult::Success;
1677}
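// Editor's sketch (illustrative, hypothetical types): vector<1x8xf32> is
// broadcastable to vector<4x8xf32> (a unit dim may stretch and leading dims
// may be added), while vector<3x8xf32> is not: 3 vs. 4 reports a
// DimensionMismatch with mismatchingDims = {3, 4}.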
1678
1679LogicalResult BroadcastOp::verify() {
1680 std::pair<int, int> mismatchingDims;
1681 BroadcastableToResult res =
1682 isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
1683 if (res == BroadcastableToResult::Success)
1684 return success();
1685 if (res == BroadcastableToResult::SourceRankHigher)
1686 return emitOpError("source rank higher than destination rank");
1687 if (res == BroadcastableToResult::DimensionMismatch)
1688 return emitOpError("dimension mismatch (")
1689 << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1690 if (res == BroadcastableToResult::SourceTypeNotAVector)
1691 return emitOpError("source type is not a vector");
1692 llvm_unreachable("unexpected vector.broadcast op error")::llvm::llvm_unreachable_internal("unexpected vector.broadcast op error"
, "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 1692)
;
1693}
1694
1695OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1696 if (getSourceType() == getVectorType())
1697 return getSource();
1698 if (!operands[0])
1699 return {};
1700 auto vectorType = getVectorType();
1701 if (operands[0].isa<IntegerAttr, FloatAttr>())
1702 return DenseElementsAttr::get(vectorType, operands[0]);
1703 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1704 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1705 return {};
1706}
1707
1708namespace {
1709
1710// Fold broadcast1(broadcast2(x)) into broadcast1(x).
1711struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1712 using OpRewritePattern::OpRewritePattern;
1713
1714 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1715 PatternRewriter &rewriter) const override {
1716 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
1717 if (!srcBroadcast)
1718 return failure();
1719 rewriter.replaceOpWithNewOp<BroadcastOp>(
1720 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
1721 return success();
1722 }
1723};
1724} // namespace
1725
1726void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1727 MLIRContext *context) {
1728 // BroadcastToShapeCast is not a default canonicalization; it is opt-in by
1729 // calling `populateCastAwayVectorLeadingOneDimPatterns`
1730 results.add<BroadcastFolder>(context);
1731}
1732
1733//===----------------------------------------------------------------------===//
1734// ShuffleOp
1735//===----------------------------------------------------------------------===//
1736
1737void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1738 Value v2, ArrayRef<int64_t> mask) {
1739 build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
1740}
1741
1742LogicalResult ShuffleOp::verify() {
1743 VectorType resultType = getVectorType();
1744 VectorType v1Type = getV1VectorType();
1745 VectorType v2Type = getV2VectorType();
1746 // Verify ranks.
1747 int64_t resRank = resultType.getRank();
1748 int64_t v1Rank = v1Type.getRank();
1749 int64_t v2Rank = v2Type.getRank();
1750 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
1751 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
1752 if (!wellFormed0DCase && !wellFormedNDCase)
1753 return emitOpError("rank mismatch");
1754
1755 // Verify all but leading dimension sizes.
1756 for (int64_t r = 1; r < v1Rank; ++r) {
1757 int64_t resDim = resultType.getDimSize(r);
1758 int64_t v1Dim = v1Type.getDimSize(r);
1759 int64_t v2Dim = v2Type.getDimSize(r);
1760 if (resDim != v1Dim || v1Dim != v2Dim)
1761 return emitOpError("dimension mismatch");
1762 }
1763 // Verify mask length.
1764 auto maskAttr = getMask().getValue();
1765 int64_t maskLength = maskAttr.size();
1766 if (maskLength <= 0)
1767 return emitOpError("invalid mask length");
1768 if (maskLength != resultType.getDimSize(0))
1769 return emitOpError("mask length mismatch");
1770 // Verify all indices.
1771 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
1772 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
1773 for (const auto &en : llvm::enumerate(maskAttr)) {
1774 auto attr = en.value().dyn_cast<IntegerAttr>();
1775 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1776 return emitOpError("mask index #") << (en.index() + 1) << " out of range";
1777 }
1778 return success();
1779}
1780
1781LogicalResult
1782ShuffleOp::inferReturnTypes(MLIRContext *, Optional<Location>,
1783 ValueRange operands, DictionaryAttr attributes,
1784 RegionRange,
1785 SmallVectorImpl<Type> &inferredReturnTypes) {
1786 ShuffleOp::Adaptor op(operands, attributes);
1787 auto v1Type = op.getV1().getType().cast<VectorType>();
1788 auto v1Rank = v1Type.getRank();
1789 // Construct resulting type: leading dimension matches mask
1790 // length, all trailing dimensions match the operands.
1791 SmallVector<int64_t, 4> shape;
1792 shape.reserve(v1Rank);
1793 shape.push_back(std::max<size_t>(1, op.getMask().size()));
1794 // In the 0-D case there is no trailing shape to append.
1795 if (v1Rank > 0)
1796 llvm::append_range(shape, v1Type.getShape().drop_front());
1797 inferredReturnTypes.push_back(
1798 VectorType::get(shape, v1Type.getElementType()));
1799 return success();
1800}
1801
1802static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
1803 uint64_t expected = begin;
1804 return idxArr.size() == width &&
1805 llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
1806 [&expected](auto attr) {
1807 return attr.getZExtValue() == expected++;
1808 });
1809}
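// Editor's sketch (illustrative, hypothetical values): for a mask attribute
// [4, 5], isStepIndexArray(mask, /*begin=*/4, /*width=*/2) holds, so a
// shuffle of <4xi32> and <2xi32> operands with that mask folds to the
// second operand in ShuffleOp::fold below.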
1810
1811OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
1812 VectorType v1Type = getV1VectorType();
1813 // For consistency: 0-D shuffle return type is 1-D; this cannot be a folding
1814 // but must be a canonicalization into a vector.broadcast.
1815 if (v1Type.getRank() == 0)
1816 return {};
1817
1818 // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
1819 if (!v1Type.isScalable() &&
1820 isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
1821 return getV1();
1822 // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
1823 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
1824 isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
1825 getV2VectorType().getDimSize(0)))
1826 return getV2();
1827
1828 Attribute lhs = operands.front(), rhs = operands.back();
1829 if (!lhs || !rhs)
1830 return {};
1831
1832 auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
1833 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
1834 // manipulation.
1835 if (lhsType.getRank() != 1)
1836 return {};
1837 int64_t lhsSize = lhsType.getDimSize(0);
1838
1839 SmallVector<Attribute> results;
1840 auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
1841 auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
1842 for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
1843 int64_t i = index.getZExtValue();
1844 if (i >= lhsSize) {
1845 results.push_back(rhsElements[i - lhsSize]);
1846 } else {
1847 results.push_back(lhsElements[i]);
1848 }
1849 }
1850
1851 return DenseElementsAttr::get(getVectorType(), results);
1852}
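// Editor's sketch (illustrative, not from the original source): with both
// operands constant, the mask selects elements across the concatenation of
// the two vectors. Hypothetical IR:
//   %a = arith.constant dense<[0, 1]> : vector<2xi32>
//   %b = arith.constant dense<[2, 3]> : vector<2xi32>
//   %s = vector.shuffle %a, %b [3, 0] : vector<2xi32>, vector<2xi32>
// folds to dense<[3, 0]> : vector<2xi32>.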
1853
1854namespace {
1855
1856// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
1857// to a broadcast.
1858struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
1859 using OpRewritePattern::OpRewritePattern;
1860
1861 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
1862 PatternRewriter &rewriter) const override {
1863 VectorType v1VectorType = shuffleOp.getV1VectorType();
1864 ArrayAttr mask = shuffleOp.getMask();
1865 if (v1VectorType.getRank() > 0)
1866 return failure();
1867 if (mask.size() != 1)
1868 return failure();
1869 Type resType = VectorType::Builder(v1VectorType).setShape({1});
1870 if (mask[0].cast<IntegerAttr>().getInt() == 0)
1871 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
1872 shuffleOp.getV1());
1873 else
1874 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
1875 shuffleOp.getV2());
1876 return success();
1877 }
1878};
1879
1880/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
1881class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
1882public:
1883 using OpRewritePattern::OpRewritePattern;
1884
1885 LogicalResult matchAndRewrite(ShuffleOp op,
1886 PatternRewriter &rewriter) const override {
1887 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
1888 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
1889
1890 if (!v1Splat || !v2Splat)
1891 return failure();
1892
1893 if (v1Splat.getInput() != v2Splat.getInput())
1894 return failure();
1895
1896 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
1897 return success();
1898 }
1899};
1900
1901} // namespace
1902
1903void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
1904 MLIRContext *context) {
1905 results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
1906}
1907
1908//===----------------------------------------------------------------------===//
1909// InsertElementOp
1910//===----------------------------------------------------------------------===//
1911
1912void InsertElementOp::build(OpBuilder &builder, OperationState &result,
1913 Value source, Value dest) {
1914 build(builder, result, source, dest, {});
1915}
1916
1917LogicalResult InsertElementOp::verify() {
1918 auto dstVectorType = getDestVectorType();
1919 if (dstVectorType.getRank() == 0) {
1920 if (getPosition())
1921 return emitOpError("expected position to be empty with 0-D vector");
1922 return success();
1923 }
1924 if (dstVectorType.getRank() != 1)
1925 return emitOpError("unexpected >1 vector rank");
1926 if (!getPosition())
1927 return emitOpError("expected position for 1-D vector");
1928 return success();
1929}
1930
1931OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
1932 // Skip the 0-D vector here.
1933 if (operands.size() < 3)
1934 return {};
1935
1936 Attribute src = operands[0];
1937 Attribute dst = operands[1];
1938 Attribute pos = operands[2];
1939 if (!src || !dst || !pos)
1940 return {};
1941
1942 auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
1943
1944 SmallVector<Attribute> results(dstElements);
1945
1946 auto attr = pos.dyn_cast<IntegerAttr>();
1947 uint64_t posIdx = attr.getInt();
1948
1949 results[posIdx] = src;
1950
1951 return DenseElementsAttr::get(getDestVectorType(), results);
1952}
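// Editor's sketch (illustrative, hypothetical values): folding an
// insertelement with src = 7, dst = dense<[0, 1, 2]> : vector<3xi32> and
// pos = 1 produces dense<[0, 7, 2]> : vector<3xi32>.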
1953
1954//===----------------------------------------------------------------------===//
1955// InsertOp
1956//===----------------------------------------------------------------------===//
1957
1958void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1959 Value dest, ArrayRef<int64_t> position) {
1960 result.addOperands({source, dest});
1961 auto positionAttr = getVectorSubscriptAttr(builder, position);
1962 result.addTypes(dest.getType());
1963 result.addAttribute(getPositionAttrStrName(), positionAttr);
1964}
1965
1966// Convenience builder which assumes the values are constant indices.
1967void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
1968 Value dest, ValueRange position) {
1969 SmallVector<int64_t, 4> positionConstants =
1970 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1971 return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1972 }));
1973 build(builder, result, source, dest, positionConstants);
1974}
1975
1976LogicalResult InsertOp::verify() {
1977 auto positionAttr = getPosition().getValue();
1978 auto destVectorType = getDestVectorType();
1979 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
1980 return emitOpError(
1981 "expected position attribute of rank smaller than dest vector rank");
1982 auto srcVectorType = getSourceType().dyn_cast<VectorType>();
1983 if (srcVectorType &&
1984 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
1985 static_cast<unsigned>(destVectorType.getRank())))
1986 return emitOpError("expected position attribute rank + source rank to "
1987 "match dest vector rank");
1988 if (!srcVectorType &&
1989 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
1990 return emitOpError(
1991 "expected position attribute rank to match the dest vector rank");
1992 for (const auto &en : llvm::enumerate(positionAttr)) {
1993 auto attr = en.value().dyn_cast<IntegerAttr>();
1994 if (!attr || attr.getInt() < 0 ||
1995 attr.getInt() >= destVectorType.getDimSize(en.index()))
1996 return emitOpError("expected position attribute #")
1997 << (en.index() + 1)
1998 << " to be a non-negative integer smaller than the corresponding "
1999 "dest vector dimension";
2000 }
2001 return success();
2002}
2003
2004namespace {
2005
2006// If insertOp is only inserting unit dimensions it can be transformed to a
2007// broadcast.
2008class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2009public:
2010 using OpRewritePattern::OpRewritePattern;
2011
2012 LogicalResult matchAndRewrite(InsertOp insertOp,
2013 PatternRewriter &rewriter) const override {
2014 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
2015 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2016 srcVecType.getNumElements())
2017 return failure();
2018 rewriter.replaceOpWithNewOp<BroadcastOp>(
2019 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2020 return success();
2021 }
2022};
2023
2024/// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2025class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2026public:
2027 using OpRewritePattern::OpRewritePattern;
2028
2029 LogicalResult matchAndRewrite(InsertOp op,
2030 PatternRewriter &rewriter) const override {
2031 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2032 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2033
2034 if (!srcSplat || !dstSplat)
2035 return failure();
2036
2037 if (srcSplat.getInput() != dstSplat.getInput())
2038 return failure();
2039
2040 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2041 return success();
2042 }
2043};
2044
2045} // namespace
2046
2047void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2048 MLIRContext *context) {
2049 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
2050}
2051
2052// Eliminates insert operations that produce values identical to their source
2053// value. This happens when the source and destination vectors have identical
2054// sizes.
2055OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
2056 if (getPosition().empty())
2057 return getSource();
2058 return {};
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// InsertStridedSliceOp
2063//===----------------------------------------------------------------------===//
2064
2065void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2066 Value source, Value dest,
2067 ArrayRef<int64_t> offsets,
2068 ArrayRef<int64_t> strides) {
2069 result.addOperands({source, dest});
2070 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2071 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2072 result.addTypes(dest.getType());
2073 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2074 result.addAttribute(getStridesAttrStrName(), stridesAttr);
2075}
2076
2077// TODO: Should be moved to Tablegen ConfinedAttr attributes.
2078template <typename OpType>
2079static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2080 ArrayAttr arrayAttr,
2081 ArrayRef<int64_t> shape,
2082 StringRef attrName) {
2083 if (arrayAttr.size() > shape.size())
2084 return op.emitOpError("expected ")
2085 << attrName << " attribute of rank smaller than vector rank";
2086 return success();
2087}
2088
2089// Returns success if all integers in `arrayAttr` are in the interval
2090// [min, max]. If `halfOpen` is true then the admissible interval is
2091// [min, max). Otherwise, the admissible interval is [min, max].
2092template <typename OpType>
2093static LogicalResult
2094isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2095 int64_t max, StringRef attrName,
2096 bool halfOpen = true) {
2097 for (auto attr : arrayAttr) {
2098 auto val = attr.cast<IntegerAttr>().getInt();
2099 auto upper = max;
2100 if (!halfOpen)
2101 upper += 1;
2102 if (val < min || val >= upper)
2103 return op.emitOpError("expected ") << attrName << " to be confined to ["
2104 << min << ", " << upper << ")";
2105 }
2106 return success();
2107}
2108
2109// Returns success if all integers in `arrayAttr` are confined to the
2110// corresponding dimension in `shape`. If `halfOpen` is true then the
2111// admissible interval is [min, max). Otherwise, it is [min, max].
2112template <typename OpType>
2113static LogicalResult
2114isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2115 ArrayRef<int64_t> shape, StringRef attrName,
2116 bool halfOpen = true, int64_t min = 0) {
2117 assert(arrayAttr.size() <= shape.size());
2118 unsigned index = 0;
2119 for (auto it : llvm::zip(arrayAttr, shape)) {
2120 auto val = std::get<0>(it).cast<IntegerAttr>().getInt();
2121 auto max = std::get<1>(it);
2122 if (!halfOpen)
2123 max += 1;
2124 if (val < min || val >= max)
2125 return op.emitOpError("expected ")
2126 << attrName << " dimension " << index << " to be confined to ["
2127 << min << ", " << max << ")";
2128 ++index;
2129 }
2130 return success();
2131}
2132
2133// Returns success if, for each dimension, the sum of the corresponding
2134// integers in `arrayAttr1` and `arrayAttr2` is confined to `shape`. If
2135// `halfOpen` is true the admissible interval is [min, max). Otherwise, it
2135// is [min, max].
2136template <typename OpType>
2137static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2138 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2139 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2140 bool halfOpen = true, int64_t min = 1) {
2141 assert(arrayAttr1.size() <= shape.size());
2142 assert(arrayAttr2.size() <= shape.size());
2143 unsigned index = 0;
2144 for (auto it : llvm::zip(arrayAttr1, arrayAttr2, shape)) {
2145 auto val1 = std::get<0>(it).cast<IntegerAttr>().getInt();
2146 auto val2 = std::get<1>(it).cast<IntegerAttr>().getInt();
2147 auto max = std::get<2>(it);
2148 if (!halfOpen)
2149 max += 1;
2150 if (val1 + val2 < 0 || val1 + val2 >= max)
2151 return op.emitOpError("expected sum(")
2152 << attrName1 << ", " << attrName2 << ") dimension " << index
2153 << " to be confined to [" << min << ", " << max << ")";
2154 ++index;
2155 }
2156 return success();
2157}
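// Editor's sketch (illustrative, hypothetical values): with offsets
// arrayAttr1 = [2], source sizes arrayAttr2 = [4], shape = [8] and
// halfOpen = false, the check requires 2 + 4 <= 8, i.e. the inserted chunk
// must end within the destination dimension.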
2158
2159static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2160 MLIRContext *context) {
2161 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2162 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2163 });
2164 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2165}
2166
2167LogicalResult InsertStridedSliceOp::verify() {
2168 auto sourceVectorType = getSourceVectorType();
2169 auto destVectorType = getDestVectorType();
2170 auto offsets = getOffsetsAttr();
2171 auto strides = getStridesAttr();
2172 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2173 return emitOpError(
2174 "expected offsets of same size as destination vector rank");
2175 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2176 return emitOpError("expected strides of same size as source vector rank");
2177 if (sourceVectorType.getRank() > destVectorType.getRank())
2178 return emitOpError(
2179 "expected source rank to be smaller than destination rank");
2180
2181 auto sourceShape = sourceVectorType.getShape();
2182 auto destShape = destVectorType.getShape();
2183 SmallVector<int64_t, 4> sourceShapeAsDestShape(
2184 destShape.size() - sourceShape.size(), 0);
2185 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2186 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2187 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2188 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2189 offName)) ||
2190 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2191 stridesName,
2192 /*halfOpen=*/false)) ||
2193 failed(isSumOfIntegerArrayAttrConfinedToShape(
2194 *this, offsets,
2195 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2196 offName, "source vector shape",
2197 /*halfOpen=*/false, /*min=*/1)))
2198 return failure();
2199
2200 return success();
2201}
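// Editor's sketch (illustrative, not from the original source): a
// well-formed insertion; offsets match the destination rank, strides match
// the source rank, and offset + source size fits each dimension:
//   %r = vector.insert_strided_slice %src, %dst
//          {offsets = [1, 2], strides = [1]}
//          : vector<4xf32> into vector<2x8xf32>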
2202
2203namespace {
2204/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2205/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
2206class FoldInsertStridedSliceSplat final
2207 : public OpRewritePattern<InsertStridedSliceOp> {
2208public:
2209 using OpRewritePattern::OpRewritePattern;
2210
2211 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2212 PatternRewriter &rewriter) const override {
2213 auto srcSplatOp =
2214 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2215 auto destSplatOp =
2216 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2217
2218 if (!srcSplatOp || !destSplatOp)
2219 return failure();
2220
2221 if (srcSplatOp.getInput() != destSplatOp.getInput())
2222 return failure();
2223
2224 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2225 return success();
2226 }
2227};
2228
2229/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2230/// to dst.
2231class FoldInsertStridedSliceOfExtract final
2232 : public OpRewritePattern<InsertStridedSliceOp> {
2233public:
2234 using OpRewritePattern::OpRewritePattern;
2235
2236 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2237 PatternRewriter &rewriter) const override {
2238 auto extractStridedSliceOp =
2239 insertStridedSliceOp.getSource()
2240 .getDefiningOp<vector::ExtractStridedSliceOp>();
2241
2242 if (!extractStridedSliceOp)
2243 return failure();
2244
2245 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2246 return failure();
2247
2248 // Check if they have the same strides and offsets.
2249 if (extractStridedSliceOp.getStrides() !=
2250 insertStridedSliceOp.getStrides() ||
2251 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2252 return failure();
2253
2254 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2255 return success();
2256 }
2257};
2258
2259} // namespace
2260
2261void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
2262 RewritePatternSet &results, MLIRContext *context) {
2263 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract>(
2264 context);
2265}
2266
2267OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2268 if (getSourceVectorType() == getDestVectorType())
2269 return getSource();
2270 return {};
2271}
2272
2273//===----------------------------------------------------------------------===//
2274// OuterProductOp
2275//===----------------------------------------------------------------------===//
2276
2277/// Build an op without mask, use the type of `acc` as the return type.
2278void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2279 Value lhs, Value rhs, Value acc) {
2280 result.addOperands({lhs, rhs, acc});
2281 result.addTypes(acc.getType());
2282}
2283
2284void OuterProductOp::print(OpAsmPrinter &p) {
2285 p << " " << getLhs() << ", " << getRhs();
2286 if (!getAcc().empty()) {
2287 p << ", " << getAcc();
2288 p.printOptionalAttrDict((*this)->getAttrs());
2289 }
2290 p << " : " << getLhs().getType() << ", " << getRhs().getType();
2291}
2292
2293ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
2294 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
2295 Type tLHS, tRHS;
2296 if (parser.parseOperandList(operandsInfo) ||
2297 parser.parseOptionalAttrDict(result.attributes) ||
2298 parser.parseColonType(tLHS) || parser.parseComma() ||
2299 parser.parseType(tRHS))
2300 return failure();
2301 if (operandsInfo.size() < 2)
2302 return parser.emitError(parser.getNameLoc(),
2303 "expected at least 2 operands");
2304 VectorType vLHS = tLHS.dyn_cast<VectorType>();
2305 VectorType vRHS = tRHS.dyn_cast<VectorType>();
2306 if (!vLHS)
2307 return parser.emitError(parser.getNameLoc(),
2308 "expected vector type for operand #1");
2309 VectorType resType =
2310 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2311 vLHS.getElementType())
2312 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2313
2314 if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
2315 result.attributes.append(
2316 OuterProductOp::getKindAttrStrName(),
2317 CombiningKindAttr::get(result.getContext(),
2318 OuterProductOp::getDefaultKind()));
2319 }
2320
2321 return failure(
2322 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2323 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2324 (operandsInfo.size() > 2 &&
2325 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2326 parser.addTypeToList(resType, result.types));
2327}
2328
2329LogicalResult OuterProductOp::verify() {
2330 Type tRHS = getOperandTypeRHS();
2331 VectorType vLHS = getOperandVectorTypeLHS(),
2332 vRHS = tRHS.dyn_cast<VectorType>(),
2333 vACC = getOperandVectorTypeACC(), vRES = getVectorType();
2334
2335 if (vLHS.getRank() != 1)
2336 return emitOpError("expected 1-d vector for operand #1");
2337
2338 if (vRHS) {
2339 // Proper OUTER operation.
2340 if (vRHS.getRank() != 1)
2341 return emitOpError("expected 1-d vector for operand #2");
2342 if (vRES.getRank() != 2)
2343 return emitOpError("expected 2-d vector result");
2344 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2345 return emitOpError("expected #1 operand dim to match result dim #1");
2346 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2347 return emitOpError("expected #2 operand dim to match result dim #2");
2348 } else {
2349 // An AXPY operation.
2350 if (vRES.getRank() != 1)
2351 return emitOpError("expected 1-d vector result");
2352 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2353 return emitOpError("expected #1 operand dim to match result dim #1");
2354 }
2355
2356 if (vACC && vACC != vRES)
2357 return emitOpError("expected operand #3 of same type as result type");
2358
2359 // Verify supported combining kind.
2360 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
2361 return emitOpError("unsupported outerproduct type");
2362
2363 return success();
2364}
2365
2366//===----------------------------------------------------------------------===//
2367// ReshapeOp
2368//===----------------------------------------------------------------------===//
2369
2370LogicalResult ReshapeOp::verify() {
2371 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
2372 auto inputVectorType = getInputVectorType();
2373 auto outputVectorType = getOutputVectorType();
2374 int64_t inputShapeRank = getNumInputShapeSizes();
2375 int64_t outputShapeRank = getNumOutputShapeSizes();
2376 SmallVector<int64_t, 4> fixedVectorSizes;
2377 getFixedVectorSizes(fixedVectorSizes);
2378 int64_t numFixedVectorSizes = fixedVectorSizes.size();
2379
2380 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2381 return emitError("invalid input shape for vector type ") << inputVectorType;
2382
2383 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2384 return emitError("invalid output shape for vector type ")
2385 << outputVectorType;
2386
2387 // Verify that the 'fixedVectorSizes' match an input/output vector shape
2388 // suffix.
2389 unsigned inputVectorRank = inputVectorType.getRank();
2390 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2391 unsigned index = inputVectorRank - numFixedVectorSizes - i;
2392 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2393 return emitError("fixed vector size must match input vector for dim ")
2394 << i;
2395 }
2396
2397 unsigned outputVectorRank = outputVectorType.getRank();
2398 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2399 unsigned index = outputVectorRank - numFixedVectorSizes - i;
2400 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2401 return emitError("fixed vector size must match output vector for dim ")
2402 << i;
2403 }
2404
2405 // If all shape operands are produced by constant ops, verify that product
2406 // of dimensions for input/output shape match.
2407 auto isDefByConstant = [](Value operand) {
2408 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2409 };
2410 if (llvm::all_of(getInputShape(), isDefByConstant) &&
2411 llvm::all_of(getOutputShape(), isDefByConstant)) {
2412 int64_t numInputElements = 1;
2413 for (auto operand : getInputShape())
2414 numInputElements *=
2415 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2416 int64_t numOutputElements = 1;
2417 for (auto operand : getOutputShape())
2418 numOutputElements *=
2419 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2420 if (numInputElements != numOutputElements)
2421 return emitError("product of input and output shape sizes must match");
2422 }
2423 return success();
2424}
2425
2426void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2427 populateFromInt64AttrArray(getFixedVectorSizes(), results);
2428}
2429
2430//===----------------------------------------------------------------------===//
2431// ExtractStridedSliceOp
2432//===----------------------------------------------------------------------===//
2433
2434// Inference works as follows:
2435// 1. Add 'sizes' from prefix of dims in 'offsets'.
2436// 2. Add sizes from 'vectorType' for remaining dims.
2437static Type inferStridedSliceOpResultType(VectorType vectorType,
2438 ArrayAttr offsets, ArrayAttr sizes,
2439 ArrayAttr strides) {
2440 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2441 SmallVector<int64_t, 4> shape;
2442 shape.reserve(vectorType.getRank());
2443 unsigned idx = 0;
2444 for (unsigned e = offsets.size(); idx < e; ++idx)
2445 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2446 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2447 shape.push_back(vectorType.getShape()[idx]);
2448
2449 return VectorType::get(shape, vectorType.getElementType());
2450}
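// Editor's sketch (illustrative, hypothetical values): for a vector<8x16xf32>
// operand with offsets = [2], sizes = [4], strides = [1], the leading dim is
// taken from `sizes` and the trailing dims from the operand type, so the
// inferred result type is vector<4x16xf32>.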
2451
2452void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2453 Value source, ArrayRef<int64_t> offsets,
2454 ArrayRef<int64_t> sizes,
2455 ArrayRef<int64_t> strides) {
2456 result.addOperands(source);
2457 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2458 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2459 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2460 result.addTypes(
2461 inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2462 offsetsAttr, sizesAttr, stridesAttr));
2463 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2464 result.addAttribute(getSizesAttrStrName(), sizesAttr);
2465 result.addAttribute(getStridesAttrStrName(), stridesAttr);
2466}
2467
2468LogicalResult ExtractStridedSliceOp::verify() {
2469 auto type = getVectorType();
2470 auto offsets = getOffsetsAttr();
2471 auto sizes = getSizesAttr();
2472 auto strides = getStridesAttr();
2473 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2474 return emitOpError(
2475 "expected offsets, sizes and strides attributes of same size");
2476
2477 auto shape = type.getShape();
2478 auto offName = getOffsetsAttrName();
2479 auto sizesName = getSizesAttrName();
2480 auto stridesName = getStridesAttrName();
2481 if (failed(
2482 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
2483 failed(
2484 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
2485 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
2486 stridesName)) ||
2487 failed(
2488 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
2489 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
2490 /*halfOpen=*/false,
2491 /*min=*/1)) ||
2492 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2493 stridesName,
2494 /*halfOpen=*/false)) ||
2495 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
2496 shape, offName, sizesName,
2497 /*halfOpen=*/false)))
2498 return failure();
2499
2500 auto resultType =
2501 inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
2502 if (getResult().getType() != resultType)
2503 return emitOpError("expected result type to be ") << resultType;
2504
2505 return success();
2506}
2507
2508// When the source of ExtractStrided comes from a chain of InsertStrided ops,
2509// try to use the source of the InsertStrided ops if we can detect that the
2510// extracted vector is a subset of one of the inserted vectors.
2511static LogicalResult
2512foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2513 // Helper to extract integer out of ArrayAttr.
2514 auto getElement = [](ArrayAttr array, int idx) {
2515 return array[idx].cast<IntegerAttr>().getInt();
2516 };
2517 ArrayAttr extractOffsets = op.getOffsets();
2518 ArrayAttr extractStrides = op.getStrides();
2519 ArrayAttr extractSizes = op.getSizes();
2520 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
2521 while (insertOp) {
2522 if (op.getVectorType().getRank() !=
2523 insertOp.getSourceVectorType().getRank())
2524 return failure();
2525 ArrayAttr insertOffsets = insertOp.getOffsets();
2526 ArrayAttr insertStrides = insertOp.getStrides();
2527 // If the rank of the extract is greater than the rank of the insert, we are
2528 // likely extracting a partial chunk of the inserted vector.
2529 if (extractOffsets.size() > insertOffsets.size())
2530 return failure();
2531 bool partialOverlap = false;
2532 bool disjoint = false;
2533 SmallVector<int64_t, 4> offsetDiffs;
2534 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2535 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2536 return failure();
2537 int64_t start = getElement(insertOffsets, dim);
2538 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2539 int64_t offset = getElement(extractOffsets, dim);
2540 int64_t size = getElement(extractSizes, dim);
2541 // Check if the start of the extract offset is in the interval inserted.
2542 if (start <= offset && offset < end) {
2543 // If the extract interval overlaps but is not fully included, we may
2544 // have a partial overlap that will prevent any folding.
2545 if (offset + size > end)
2546 partialOverlap = true;
2547 offsetDiffs.push_back(offset - start);
2548 continue;
2549 }
2550 disjoint = true;
2551 break;
2552 }
2553 // The extract element chunk is a subset of the insert element.
2554 if (!disjoint && !partialOverlap) {
2555 op.setOperand(insertOp.getSource());
2556 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2557 OpBuilder b(op.getContext());
2558 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
2559 b.getI64ArrayAttr(offsetDiffs));
2560 return success();
2561 }
2562 // If the chunk extracted is disjoint from the chunk inserted, keep looking
2563 // in the insert chain.
2564 if (disjoint)
2565 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2566 else {
2567 // The extracted vector partially overlap the inserted vector, we cannot
2568 // fold.
2569 return failure();
2570 }
2571 }
2572 return failure();
2573}
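// Editor's sketch (illustrative, not from the original source): a slice that
// is fully contained in the inserted chunk folds to a slice of the source,
// with offsets rebased against the insertion point. Hypothetical IR:
//   %i = vector.insert_strided_slice %src, %dst
//          {offsets = [2], strides = [1]} : vector<4xf32> into vector<8xf32>
//   %e = vector.extract_strided_slice %i
//          {offsets = [3], sizes = [2], strides = [1]}
//          : vector<8xf32> to vector<2xf32>
// is rewritten to extract {offsets = [1], sizes = [2]} from %src.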
2574
2575OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2576 if (getVectorType() == getResult().getType())
2577 return getVector();
2578 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2579 return getResult();
2580 return {};
2581}
2582
2583void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2584 populateFromInt64AttrArray(getOffsets(), results);
2585}
2586
2587namespace {
2588
2589// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2590// ConstantMaskOp.
2591class StridedSliceConstantMaskFolder final
2592 : public OpRewritePattern<ExtractStridedSliceOp> {
2593public:
2594 using OpRewritePattern::OpRewritePattern;
2595
2596 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2597 PatternRewriter &rewriter) const override {
2598 // Return if 'extractStridedSliceOp' operand is not defined by a
2599 // ConstantMaskOp.
2600 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
2601 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2602 if (!constantMaskOp)
2603 return failure();
2604 // Return if 'extractStridedSliceOp' has non-unit strides.
2605 if (extractStridedSliceOp.hasNonUnitStrides())
2606 return failure();
2607 // Gather constant mask dimension sizes.
2608 SmallVector<int64_t, 4> maskDimSizes;
2609 populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
2610 // Gather strided slice offsets and sizes.
2611 SmallVector<int64_t, 4> sliceOffsets;
2612 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
2613 sliceOffsets);
2614 SmallVector<int64_t, 4> sliceSizes;
2615 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
2616
2617 // Compute slice of vector mask region.
2618 SmallVector<int64_t, 4> sliceMaskDimSizes;
2619 assert(sliceOffsets.size() == maskDimSizes.size());
2620 for (auto it : llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2621 int64_t maskDimSize = std::get<0>(it);
2622 int64_t sliceOffset = std::get<1>(it);
2623 int64_t sliceSize = std::get<2>(it);
2624 int64_t sliceMaskDimSize = std::max(
2625 static_cast<int64_t>(0),
2626 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2627 sliceMaskDimSizes.push_back(sliceMaskDimSize);
2628 }
2629 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2630 // region is a conjunction of mask dim intervals).
2631 if (llvm::is_contained(sliceMaskDimSizes, 0))
2632 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2633
2634 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2635 // region.
2636 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2637 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2638 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2639 return success();
2640 }
2641};
2642
2643 // Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
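// For example (hypothetical IR, assuming standard syntax):
//   %c = arith.constant dense<1.0> : vector<8xf32>
//   %s = vector.extract_strided_slice %c
//       {offsets = [2], sizes = [4], strides = [1]}
//       : vector<8xf32> to vector<4xf32>
// folds to:
//   %s = arith.constant dense<1.0> : vector<4xf32>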
2644class StridedSliceConstantFolder final
2645 : public OpRewritePattern<ExtractStridedSliceOp> {
2646public:
2647 using OpRewritePattern::OpRewritePattern;
2648
2649 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2650 PatternRewriter &rewriter) const override {
2651 // Return if 'extractStridedSliceOp' operand is not defined by a
2652 // ConstantOp.
2653 auto constantOp =
2654 extractStridedSliceOp.getVector().getDefiningOp<arith::ConstantOp>();
2655 if (!constantOp)
2656 return failure();
2657 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
2658 if (!dense)
2659 return failure();
2660 auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
2661 dense.getSplatValue<Attribute>());
2662 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
2663 newAttr);
2664 return success();
2665 }
2666};
2667
2668// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
2669 // BroadcastOp(ExtractStridedSliceOp).
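// For example (hypothetical IR): when the slice only touches the broadcast
// dimensions, the pattern re-broadcasts the untouched source directly:
//   %b = vector.broadcast %src : vector<2x4xf32> to vector<8x2x4xf32>
//   %s = vector.extract_strided_slice %b
//       {offsets = [1, 0, 0], sizes = [4, 2, 4], strides = [1, 1, 1]}
//       : vector<8x2x4xf32> to vector<4x2x4xf32>
// becomes:
//   %s = vector.broadcast %src : vector<2x4xf32> to vector<4x2x4xf32>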
2670class StridedSliceBroadcast final
2671 : public OpRewritePattern<ExtractStridedSliceOp> {
2672public:
2673 using OpRewritePattern::OpRewritePattern;
2674
2675 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2676 PatternRewriter &rewriter) const override {
2677 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
2678 if (!broadcast)
2679 return failure();
2680 auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
2681 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
2682 auto dstVecType = op.getType().cast<VectorType>();
2683 unsigned dstRank = dstVecType.getRank();
2684 unsigned rankDiff = dstRank - srcRank;
2685 // Check if the innermost dimensions of the broadcast source are the same as
2686 // the corresponding dimensions of the extract destination. If so, we can
2687 // just use a broadcast, as the original dimensions are untouched.
2688 bool lowerDimMatch = true;
2689 for (unsigned i = 0; i < srcRank; i++) {
2690 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
2691 lowerDimMatch = false;
2692 break;
2693 }
2694 }
2695 Value source = broadcast.getSource();
2696 // If the inner dimensions don't match, we need to extract from the source
2697 // of the original broadcast and then broadcast the extracted value. We also
2698 // need to handle degenerate cases where the source is effectively just a
2699 // single scalar.
2700 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
2701 if (!lowerDimMatch && !isScalarSrc) {
2702 source = rewriter.create<ExtractStridedSliceOp>(
2703 op->getLoc(), source,
2704 getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
2705 getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
2706 getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
2707 }
2708 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
2709 return success();
2710 }
2711};
2712
2713/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
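/// For example (hypothetical IR):
///   %v = vector.splat %x : vector<8x4xf32>
///   %s = vector.extract_strided_slice %v {offsets = [0, 0], sizes = [2, 4],
///       strides = [1, 1]} : vector<8x4xf32> to vector<2x4xf32>
/// folds to:
///   %s = vector.splat %x : vector<2x4xf32>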
2714class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
2715public:
2716 using OpRewritePattern::OpRewritePattern;
2717
2718 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
2719 PatternRewriter &rewriter) const override {
2720 auto splat = op.getVector().getDefiningOp<SplatOp>();
2721 if (!splat)
2722 return failure();
2723 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
2724 return success();
2725 }
2726};
2727
2728} // namespace
2729
2730void ExtractStridedSliceOp::getCanonicalizationPatterns(
2731 RewritePatternSet &results, MLIRContext *context) {
2732 // Pattern to rewrite a ExtractStridedSliceOp(ConstantMaskOp) ->
2733 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
2734 results.add<StridedSliceConstantMaskFolder, StridedSliceConstantFolder,
2735 StridedSliceBroadcast, StridedSliceSplat>(context);
2736}
2737
2738//===----------------------------------------------------------------------===//
2739// TransferReadOp
2740//===----------------------------------------------------------------------===//
2741
2742/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
2743void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2744 VectorType vectorType, Value source,
2745 ValueRange indices, AffineMapAttr permutationMapAttr,
2746 /*optional*/ ArrayAttr inBoundsAttr) {
2747 Type elemType = source.getType().cast<ShapedType>().getElementType();
2748 Value padding = builder.create<arith::ConstantOp>(
2749 result.location, elemType, builder.getZeroAttr(elemType));
2750 build(builder, result, vectorType, source, indices, permutationMapAttr,
2751 padding, /*mask=*/Value(), inBoundsAttr);
2752}
2753
2754 /// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
2755void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2756 VectorType vectorType, Value source,
2757 ValueRange indices, AffineMap permutationMap,
2758 Optional<ArrayRef<bool>> inBounds) {
2759 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2760 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2761 ? builder.getBoolArrayAttr(inBounds.value())
2762 : ArrayAttr();
2763 build(builder, result, vectorType, source, indices, permutationMapAttr,
2764 inBoundsAttr);
2765}
2766
2767/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
2768void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2769 VectorType vectorType, Value source,
2770 ValueRange indices, Value padding,
2771 Optional<ArrayRef<bool>> inBounds) {
2772 AffineMap permutationMap = getTransferMinorIdentityMap(
2773 source.getType().cast<ShapedType>(), vectorType);
2774 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
2775 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
2776 ? builder.getBoolArrayAttr(inBounds.value())
2777 : ArrayAttr();
2778 build(builder, result, vectorType, source, indices, permutationMapAttr,
2779 padding,
2780 /*mask=*/Value(), inBoundsAttr);
2781}
2782
2783/// 4. Builder that sets padding to zero and permutation map to
2784/// 'getMinorIdentityMap'.
2785void TransferReadOp::build(OpBuilder &builder, OperationState &result,
2786 VectorType vectorType, Value source,
2787 ValueRange indices,
2788 Optional<ArrayRef<bool>> inBounds) {
2789 Type elemType = source.getType().cast<ShapedType>().getElementType();
2790 Value padding = builder.create<arith::ConstantOp>(
2791 result.location, elemType, builder.getZeroAttr(elemType));
2792 build(builder, result, vectorType, source, indices, padding, inBounds);
2793}
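// A minimal usage sketch of builder 4 (hypothetical names; assumes an
// OpBuilder `b`, a Location `loc`, a memref-typed Value `src`, and index
// Values `i` and `j` are in scope):
//   auto vecTy = VectorType::get({4, 8}, b.getF32Type());
//   Value v = b.create<vector::TransferReadOp>(
//       loc, vecTy, src, ValueRange{i, j}, /*inBounds=*/llvm::None);
// This zero-pads out-of-bounds elements and infers the minor identity
// permutation map from `src` and `vecTy`.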
2794
2795template <typename EmitFun>
2796static LogicalResult verifyPermutationMap(AffineMap permutationMap,
2797 EmitFun emitOpError) {
2798 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
2799 for (auto expr : permutationMap.getResults()) {
2800 auto dim = expr.dyn_cast<AffineDimExpr>();
2801 auto zero = expr.dyn_cast<AffineConstantExpr>();
2802 if (zero) {
2803 if (zero.getValue() != 0) {
2804 return emitOpError(
2805 "requires a projected permutation_map (at most one dim or the zero "
2806 "constant can appear in each result)");
2807 }
2808 continue;
2809 }
2810 if (!dim) {
2811 return emitOpError("requires a projected permutation_map (at most one "
2812 "dim or the zero constant can appear in each result)");
2813 }
2814 if (seen[dim.getPosition()]) {
2815 return emitOpError(
2816 "requires a permutation_map that is a permutation (found one dim "
2817 "used more than once)");
2818 }
2819 seen[dim.getPosition()] = true;
2820 }
2821 return success();
2822}
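// Examples of maps checked above (an illustrative sketch):
//   affine_map<(d0, d1, d2) -> (d2, 0, d0)>  // accepted: distinct dims or 0.
//   affine_map<(d0, d1) -> (d0, d0)>         // rejected: d0 used twice.
//   affine_map<(d0, d1) -> (d0 + d1)>        // rejected: not a dim or zero.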
2823
2824static LogicalResult
2825verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
2826 VectorType vectorType, VectorType maskType,
2827 AffineMap permutationMap, ArrayAttr inBounds) {
2828 if (op->hasAttr("masked")) {
2829 return op->emitOpError("masked attribute has been removed. "
2830 "Use in_bounds instead.");
2831 }
2832
2833 if (!shapedType.isa<MemRefType, RankedTensorType>())
2834 return op->emitOpError(
2835 "requires source to be a memref or ranked tensor type");
2836
2837 auto elementType = shapedType.getElementType();
2838 DataLayout dataLayout = DataLayout::closest(op);
2839 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
2840 // Memref or tensor has vector element type.
2841 unsigned sourceVecSize =
2842 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
2843 vectorElementType.getShape().back();
2844 unsigned resultVecSize =
2845 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
2846 vectorType.getShape().back();
2847 if (resultVecSize % sourceVecSize != 0)
2848 return op->emitOpError(
2849 "requires the bitwidth of the minor 1-D vector to be an integral "
2850 "multiple of the bitwidth of the minor 1-D vector of the source");
2851
2852 unsigned sourceVecEltRank = vectorElementType.getRank();
2853 unsigned resultVecRank = vectorType.getRank();
2854 if (sourceVecEltRank > resultVecRank)
2855 return op->emitOpError(
2856 "requires source vector element and vector result ranks to match.");
2857 unsigned rankOffset = resultVecRank - sourceVecEltRank;
2858 // Check that permutation map results match 'rankOffset' of vector type.
2859 if (permutationMap.getNumResults() != rankOffset)
2860 return op->emitOpError("requires a permutation_map with result dims of "
2861 "the same rank as the vector type");
2862
2863 if (maskType)
2864 return op->emitOpError("does not support masks with vector element type");
2865 } else {
2866 // Memref or tensor has scalar element type.
2867 unsigned minorSize =
2868 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
2869 unsigned resultVecSize =
2870 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
2871 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
2872 return op->emitOpError(
2873 "requires the bitwidth of the minor 1-D vector to be an integral "
2874 "multiple of the bitwidth of the source element type");
2875
2876 // Check that permutation map results match rank of vector type.
2877 if (permutationMap.getNumResults() != vectorType.getRank())
2878 return op->emitOpError("requires a permutation_map with result dims of "
2879 "the same rank as the vector type");
2880
2881 VectorType expectedMaskType =
2882 vector::detail::transferMaskType(vectorType, permutationMap);
2883 if (maskType && expectedMaskType != maskType)
2884 return op->emitOpError("expects mask type consistent with permutation "
2885 "map: ")
2886 << maskType;
2887 }
2888
2889 if (permutationMap.getNumSymbols() != 0)
2890 return op->emitOpError("requires permutation_map without symbols");
2891
2892 if (permutationMap.getNumInputs() != shapedType.getRank())
2893 return op->emitOpError("requires a permutation_map with input dims of the "
2894 "same rank as the source type");
2895
2896 if (inBounds) {
2897 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
2898 return op->emitOpError("expects the optional in_bounds attr of same rank "
2899 "as permutation_map results: ")
2900 << AffineMapAttr::get(permutationMap)
2901 << " vs inBounds of size: " << inBounds.size();
2902 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
2903 if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
2904 !inBounds.getValue()[i].cast<BoolAttr>().getValue())
2905 return op->emitOpError("requires broadcast dimensions to be in-bounds");
2906 }
2907
2908 return success();
2909}
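// A numeric sketch of the minor-bitwidth rule above: a source with element
// type vector<4xi8> (32 bits in the minor dimension) read into a vector<8xi8>
// result (64 bits) passes because 64 % 32 == 0, whereas a vector<3xi8>
// result (24 bits) would be rejected.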
2910
2911static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
2912 SmallVector<StringRef, 3> elidedAttrs;
2913 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
2914 if (op.permutation_map().isMinorIdentity())
2915 elidedAttrs.push_back(op.getPermutationMapAttrStrName());
2916 bool elideInBounds = true;
2917 if (auto inBounds = op.in_bounds()) {
2918 for (auto attr : *inBounds) {
2919 if (attr.template cast<BoolAttr>().getValue()) {
2920 elideInBounds = false;
2921 break;
2922 }
2923 }
2924 }
2925 if (elideInBounds)
2926 elidedAttrs.push_back(op.getInBoundsAttrStrName());
2927 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
2928}
2929
2930void TransferReadOp::print(OpAsmPrinter &p) {
2931 p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
2932 if (getMask())
2933 p << ", " << getMask();
2934 printTransferAttrs(p, *this);
2935 p << " : " << getShapedType() << ", " << getVectorType();
2936}
2937
2938ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
2939 auto &builder = parser.getBuilder();
2940 SMLoc typesLoc;
2941 OpAsmParser::UnresolvedOperand sourceInfo;
2942 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
2943 OpAsmParser::UnresolvedOperand paddingInfo;
2944 SmallVector<Type, 2> types;
2945 OpAsmParser::UnresolvedOperand maskInfo;
2946 // Parsing with support for paddingValue.
2947 if (parser.parseOperand(sourceInfo) ||
2948 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
2949 parser.parseComma() || parser.parseOperand(paddingInfo))
2950 return failure();
2951 ParseResult hasMask = parser.parseOptionalComma();
2952 if (hasMask.succeeded()) {
2953 if (parser.parseOperand(maskInfo))
2954 return failure();
2955 }
2956 if (parser.parseOptionalAttrDict(result.attributes) ||
2957 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
2958 return failure();
2959 if (types.size() != 2)
2960 return parser.emitError(typesLoc, "requires two types");
2961 auto indexType = builder.getIndexType();
2962 auto shapedType = types[0].dyn_cast<ShapedType>();
2963 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
2964 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
2965 VectorType vectorType = types[1].dyn_cast<VectorType>();
2966 if (!vectorType)
2967 return parser.emitError(typesLoc, "requires vector type");
2968 auto permutationAttrName = TransferReadOp::getPermutationMapAttrStrName();
2969 Attribute mapAttr = result.attributes.get(permutationAttrName);
2970 if (!mapAttr) {
2971 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
2972 // Update `mapAttr` that is used later to determine mask type.
2973 mapAttr = AffineMapAttr::get(permMap);
2974 result.attributes.set(permutationAttrName, mapAttr);
2975 }
2976 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
2977 parser.resolveOperands(indexInfo, indexType, result.operands) ||
2978 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
2979 result.operands))
2980 return failure();
2981 if (hasMask.succeeded()) {
2982 if (shapedType.getElementType().dyn_cast<VectorType>())
2983 return parser.emitError(
2984 maskInfo.location, "does not support masks with vector element type");
2985 auto map = mapAttr.dyn_cast<AffineMapAttr>().getValue();
2986 // Instead of adding the mask type as an op type, compute it based on the
2987 // vector type and the permutation map (to keep the type signature small).
2988 auto maskType = mlir::vector::detail::transferMaskType(vectorType, map);
2989 if (parser.resolveOperand(maskInfo, maskType, result.operands))
2990 return failure();
2991 }
2992 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
2993 builder.getDenseI32ArrayAttr(
2994 {1, static_cast<int32_t>(indexInfo.size()), 1,
2995 static_cast<int32_t>(hasMask.succeeded())}));
2996 return parser.addTypeToList(vectorType, result.types);
2997}
2998
2999LogicalResult TransferReadOp::verify() {
3000 // Consistency of elemental types in source and vector.
3001 ShapedType shapedType = getShapedType();
3002 VectorType vectorType = getVectorType();
3003 VectorType maskType = getMaskType();
3004 auto paddingType = getPadding().getType();
3005 auto permutationMap = getPermutationMap();
3006 auto sourceElementType = shapedType.getElementType();
3007
3008 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3009 return emitOpError("requires ") << shapedType.getRank() << " indices";
3010
3011 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3012 shapedType, vectorType, maskType, permutationMap,
3013 getInBounds() ? *getInBounds() : ArrayAttr())))
3014 return failure();
3015
3016 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
3017 // Source has vector element type.
3018 // Check that 'sourceVectorElementType' and 'paddingType' types match.
3019 if (sourceVectorElementType != paddingType)
3020 return emitOpError(
3021 "requires source element type and padding type to match.");
3022
3023 } else {
3024 // Check that 'paddingType' is valid to store in a vector type.
3025 if (!VectorType::isValidElementType(paddingType))
3026 return emitOpError("requires valid padding vector elemental type");
3027
3028 // Check that padding type and vector element types match.
3029 if (paddingType != sourceElementType)
3030 return emitOpError(
3031 "requires formal padding and source of the same elemental type");
3032 }
3033
3034 return verifyPermutationMap(permutationMap,
3035 [&](Twine t) { return emitOpError(t); });
3036}
3037
3038/// This is a common class used for patterns of the form
3039/// ```
3040/// someop(memrefcast) -> someop
3041/// ```
3042/// It folds the source of the memref.cast into the root operation directly.
3043static LogicalResult foldMemRefCast(Operation *op) {
3044 bool folded = false;
3045 for (OpOperand &operand : op->getOpOperands()) {
3046 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
3047 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
3048 operand.set(castOp.getOperand());
3049 folded = true;
3050 }
3051 }
3052 return success(folded);
3053}
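// For example (hypothetical IR): given
//   %0 = memref.cast %A : memref<4x4xf32> to memref<?x?xf32>
// a consuming vector.transfer_read %0[...] is folded to read from %A
// directly, since the cast only erases static shape information.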
3054
3055static LogicalResult foldTensorCast(Operation *op) {
3056 bool folded = false;
3057 for (OpOperand &operand : op->getOpOperands()) {
3058 auto castOp = operand.get().getDefiningOp<tensor::CastOp>();
3059 if (castOp && tensor::canFoldIntoConsumerOp(castOp)) {
3060 operand.set(castOp.getOperand());
3061 folded = true;
3062 }
3063 }
3064 return success(folded);
3065}
3066
3067template <typename TransferOp>
3068static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3069 // TODO: support more aggressive createOrFold on:
3070 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
3071 if (op.getShapedType().isDynamicDim(indicesIdx))
3072 return false;
3073 Value index = op.getIndices()[indicesIdx];
3074 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
3075 if (!cstOp)
3076 return false;
3077
3078 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3079 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3080
3081 return cstOp.value() + vectorSize <= sourceSize;
3082}
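// E.g. (an illustrative sketch): reading a vector<4xf32> at constant index 5
// along a memref<8xf32> dimension is not provably in-bounds (5 + 4 > 8),
// while constant index 4 is (4 + 4 <= 8).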
3083
3084template <typename TransferOp>
3085static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
3086 // TODO: support 0-d corner case.
3087 // TODO: Be less conservative.
3088 if (op.getTransferRank() == 0)
3089 return failure();
3090 AffineMap permutationMap = op.getPermutationMap();
3091 bool changed = false;
3092 SmallVector<bool, 4> newInBounds;
3093 newInBounds.reserve(op.getTransferRank());
3094 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
3095 // Already marked as in-bounds, nothing to see here.
3096 if (op.isDimInBounds(i)) {
3097 newInBounds.push_back(true);
3098 continue;
3099 }
3100 // Currently out-of-bounds, check whether we can statically determine it is
3101 // inBounds.
3102 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
3103 assert(dimExpr && "Broadcast dims must be in-bounds");
3104 auto inBounds =
3105 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
3106 newInBounds.push_back(inBounds);
3107 // We commit the pattern if it makes more dimensions in-bounds.
3108 changed |= inBounds;
3109 }
3110 if (!changed)
3111 return failure();
3112 // OpBuilder is only used as a helper to build a BoolArrayAttr.
3113 OpBuilder b(op.getContext());
3114 op->setAttr(TransferOp::getInBoundsAttrStrName(),
3115 b.getBoolArrayAttr(newInBounds));
3116 return success();
3117}
3118
3119/// ```
3120/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3121/// : vector<1x4xf32>, tensor<4x4xf32>
3122/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
3123/// : tensor<4x4xf32>, vector<1x4xf32>
3124/// ```
3125/// -> Folds into
3126/// ```
3127/// %v0
3128/// ```
3129static Value foldRAW(TransferReadOp readOp) {
3130 if (!readOp.getShapedType().isa<RankedTensorType>())
3131 return {};
3132 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3133 while (defWrite) {
3134 if (checkSameValueRAW(defWrite, readOp))
3135 return defWrite.getVector();
3136 if (!isDisjointTransferIndices(
3137 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3138 cast<VectorTransferOpInterface>(readOp.getOperation())))
3139 break;
3140 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3141 }
3142 return {};
3143}
3144
3145OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
3146 if (Value vec = foldRAW(*this))
3147 return vec;
3148 /// transfer_read(memrefcast) -> transfer_read
3149 if (succeeded(foldTransferInBoundsAttribute(*this)))
3150 return getResult();
3151 if (succeeded(foldMemRefCast(*this)))
3152 return getResult();
3153 if (succeeded(foldTensorCast(*this)))
3154 return getResult();
3155 return OpFoldResult();
3156}
3157
3158Optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
3159 return llvm::to_vector<4>(getVectorType().getShape());
3160}
3161
3162void TransferReadOp::getEffects(
3163 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3164 &effects) {
3165 if (getShapedType().isa<MemRefType>())
3166 effects.emplace_back(MemoryEffects::Read::get(), getSource(),
3167 SideEffects::DefaultResource::get());
3168}
3169
3170 /// Returns true if all rank reductions in the given `extractOp` happen in
3171 /// leading dimensions earlier than the last `trailingRank` dimensions.
3172static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
3173 unsigned trailingRank) {
3174 // If no ranks are reduced at all, it's a degenerate case; always true.
3175 if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
3176 return true;
3177
3178 RankedTensorType inferredType = extractOp.inferResultType(
3179 extractOp.getSourceType(), extractOp.getMixedOffsets(),
3180 extractOp.getMixedSizes(), extractOp.getMixedStrides());
3181 return extractOp.getType().getShape().take_back(trailingRank) ==
3182 inferredType.getShape().take_back(trailingRank);
3183}
3184
3185namespace {
3186/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
3187///
3188/// ```
3189/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
3190/// : tensor<?x?xf32> to tensor<?x?xf32>
3191/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
3192/// : tensor<?x?xf32>, vector<4x5xf32>
3193/// ```
3194/// is rewritten to:
3195/// ```
3196/// %p0 = arith.addi %a, %e : index
3197/// %p1 = arith.addi %b, %f : index
3198/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
3199/// : tensor<?x?xf32>, vector<4x5xf32>
3200/// ```
3201struct FoldExtractSliceIntoTransferRead
3202 : public OpRewritePattern<TransferReadOp> {
3203public:
3204 using OpRewritePattern::OpRewritePattern;
3205
3206 LogicalResult matchAndRewrite(TransferReadOp xferOp,
3207 PatternRewriter &rewriter) const override {
3208 // TODO: support 0-d corner case.
3209 if (xferOp.getTransferRank() == 0)
3210 return failure();
3211 if (xferOp.hasOutOfBoundsDim())
3212 return failure();
3213 if (!xferOp.getPermutationMap().isMinorIdentity())
3214 return failure();
3215 if (xferOp.getMask())
3216 return failure();
3217 auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3218 if (!extractOp)
3219 return failure();
3220 if (!extractOp.hasUnitStride())
3221 return failure();
3222
3223 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3224 // dims are exactly the leading dims. I.e. the following is illegal:
3225 // ```
3226 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
3227 // tensor<2x1x4xf32> to tensor<2x4xf32>
3228 // %1 = vector.transfer_read %0[0,0], %cst :
3229 // tensor<2x4xf32>, vector<2x4xf32>
3230 // ```
3231 //
3232 // Cannot fold into:
3233 // ```
3234 // %0 = vector.transfer_read %t[0,0,0], %cst :
3235 // tensor<2x1x4xf32>, vector<2x4xf32>
3236 // ```
3237 // For this, check the trailing `vectorRank` dims of the extract_slice
3238 // result tensor match the trailing dims of the inferred result tensor.
3239 if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
3240 return failure();
3241
3242 int64_t rankReduced =
3243 extractOp.getSourceType().getRank() - extractOp.getType().getRank();
3244
3245 SmallVector<Value> newIndices;
3246 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
3247 // indices first.
3248 for (int64_t i = 0; i < rankReduced; ++i) {
3249 OpFoldResult offset = extractOp.getMixedOffsets()[i];
3250 newIndices.push_back(getValueOrCreateConstantIndexOp(
3251 rewriter, extractOp.getLoc(), offset));
3252 }
3253 for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
3254 OpFoldResult offset =
3255 extractOp.getMixedOffsets()[it.index() + rankReduced];
3256 newIndices.push_back(rewriter.create<arith::AddIOp>(
3257 xferOp->getLoc(), it.value(),
3258 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3259 offset)));
3260 }
3261 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3262 rewriter.replaceOpWithNewOp<TransferReadOp>(
3263 xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
3264 xferOp.getPadding(), ArrayRef<bool>{inBounds});
3265
3266 return success();
3267 }
3268};
3269
3270 /// Store-to-load forwarding for transfer operations with permutation maps.
3271/// Even if the permutation maps are different we can still propagate the store
3272/// into the load if the size of the dimensions read and written match. Then we
3273/// can replace the transfer_read + transfer_write by vector.broadcast and
3274/// vector.transpose.
3275/// Example:
3276/// ```
3277/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3278/// {in_bounds = [true, true],
3279/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3280/// vector<4x1xf32>, tensor<4x4x4xf32>
3281/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3282/// {in_bounds = [true, true, true, true],
3283/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3284/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3285/// ```
3286/// To:
3287/// ```
3288/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3289/// %r = vector.transpose %0, [3, 0, 2, 1] :
3290/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3291/// ```
3292struct TransferReadAfterWriteToBroadcast
3293 : public OpRewritePattern<TransferReadOp> {
3294 using OpRewritePattern::OpRewritePattern;
3295
3296 LogicalResult matchAndRewrite(TransferReadOp readOp,
3297 PatternRewriter &rewriter) const override {
3298 if (readOp.hasOutOfBoundsDim() ||
3299 !readOp.getShapedType().isa<RankedTensorType>())
3300 return failure();
3301 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3302 if (!defWrite)
3303 return failure();
3304
3305 SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3306 Value vec;
3307 if (readOp.getIndices() == defWrite.getIndices() &&
3308 readOp.getMask() == defWrite.getMask()) {
3309 SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3310 // TODO: If the writeDim is a superset of the read dims we could do an
3311 // extract_strided_slice.
3312 if (writeDims == readDims)
3313 vec = defWrite.getVector();
3314 }
3315 // TODO: loop through the chain of transfer_write if we can prove that they
3316 // don't overlap with the transfer_read. This requires improving
3317 // `isDisjointTransferIndices` helper.
3318 if (!vec)
3319 return failure();
3320 SmallVector<unsigned> permutation;
3321 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3322 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3323 AffineMap map = readMap.compose(writeMap);
3324 if (map.getNumResults() == 0)
3325 return failure();
3326 // Calculate the permutation to apply to go from the stored vector to the
3327 // read vector.
3328 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3329 return failure();
3330
3331 Location loc = readOp.getLoc();
3332 // Calculate the broadcast shape by applying the reverse permutation to the
3333 // final shape we want.
3334 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3335 SmallVector<int64_t> broadcastShape(destShape.size());
3336 for (const auto &pos : llvm::enumerate(permutation))
3337 broadcastShape[pos.value()] = destShape[pos.index()];
3338 VectorType broadcastedType = VectorType::get(
3339 broadcastShape, defWrite.getVectorType().getElementType());
3340 vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3341 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3342 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3343 transposePerm);
3344 return success();
3345 }
3346};
3347} // namespace
3348
3349void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3350 MLIRContext *context) {
3351 results
3352 .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3353 context);
3354}
3355
3356//===----------------------------------------------------------------------===//
3357// TransferWriteOp
3358//===----------------------------------------------------------------------===//
3359
3360/// 1. Builder with type inference.
3361void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3362 Value vector, Value dest, ValueRange indices,
3363 AffineMapAttr permutationMapAttr,
3364 /*optional*/ Value mask,
3365 /*optional*/ ArrayAttr inBoundsAttr) {
3366 Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3367 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3368 mask, inBoundsAttr);
3369}
3370
3371/// 2. Builder with type inference that sets an empty mask (variant with attrs).
3372void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3373 Value vector, Value dest, ValueRange indices,
3374 AffineMapAttr permutationMapAttr,
3375 /*optional*/ ArrayAttr inBoundsAttr) {
3376 build(builder, result, vector, dest, indices, permutationMapAttr,
3377 /*mask=*/Value(), inBoundsAttr);
3378}
3379
3380/// 3. Builder with type inference that sets an empty mask (variant without
3381 /// attrs).
3382void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3383 Value vector, Value dest, ValueRange indices,
3384 AffineMap permutationMap,
3385 Optional<ArrayRef<bool>> inBounds) {
3386 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3387 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3388 ? builder.getBoolArrayAttr(inBounds.value())
3389 : ArrayAttr();
3390 build(builder, result, vector, dest, indices, permutationMapAttr,
3391 /*mask=*/Value(), inBoundsAttr);
3392}
3393
3394/// 4. Builder with type inference that sets an empty mask and sets permutation
3395/// map to 'getMinorIdentityMap'.
3396void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3397 Value vector, Value dest, ValueRange indices,
3398 Optional<ArrayRef<bool>> inBounds) {
3399 auto vectorType = vector.getType().cast<VectorType>();
3400 AffineMap permutationMap = getTransferMinorIdentityMap(
3401 dest.getType().cast<ShapedType>(), vectorType);
3402 build(builder, result, vector, dest, indices, permutationMap, inBounds);
3403}
3404
3405ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3406 OperationState &result) {
3407 auto &builder = parser.getBuilder();
3408 SMLoc typesLoc;
3409 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
3410 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3411 SmallVector<Type, 2> types;
3412 OpAsmParser::UnresolvedOperand maskInfo;
3413 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3414 parser.parseOperand(sourceInfo) ||
3415 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3416 return failure();
3417 ParseResult hasMask = parser.parseOptionalComma();
3418 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3419 return failure();
3420 if (parser.parseOptionalAttrDict(result.attributes) ||
3421 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3422 return failure();
3423 if (types.size() != 2)
3424 return parser.emitError(typesLoc, "requires two types");
3425 auto indexType = builder.getIndexType();
3426 VectorType vectorType = types[0].dyn_cast<VectorType>();
3427 if (!vectorType)
3428 return parser.emitError(typesLoc, "requires vector type");
3429 ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3430 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3431 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3432 auto permutationAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3433 auto attr = result.attributes.get(permutationAttrName);
3434 if (!attr) {
3435 auto permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3436 result.attributes.set(permutationAttrName, AffineMapAttr::get(permMap));
3437 }
3438 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3439 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3440 parser.resolveOperands(indexInfo, indexType, result.operands))
3441 return failure();
3442 if (hasMask.succeeded()) {
3443 if (shapedType.getElementType().dyn_cast<VectorType>())
3444 return parser.emitError(
3445 maskInfo.location, "does not support masks with vector element type");
3446 auto maskType = VectorType::get(vectorType.getShape(), builder.getI1Type());
3447 if (parser.resolveOperand(maskInfo, maskType, result.operands))
3448 return failure();
3449 }
3450 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
3451 builder.getDenseI32ArrayAttr(
3452 {1, 1, static_cast<int32_t>(indexInfo.size()),
3453 static_cast<int32_t>(hasMask.succeeded())}));
3454 return failure(shapedType.isa<RankedTensorType>() &&
3455 parser.addTypeToList(shapedType, result.types));
3456}
3457
3458void TransferWriteOp::print(OpAsmPrinter &p) {
3459 p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
3460 if (getMask())
3461 p << ", " << getMask();
3462 printTransferAttrs(p, *this);
3463 p << " : " << getVectorType() << ", " << getShapedType();
3464}
3465
3466LogicalResult TransferWriteOp::verify() {
3467 // Consistency of elemental types in shape and vector.
3468 ShapedType shapedType = getShapedType();
3469 VectorType vectorType = getVectorType();
3470 VectorType maskType = getMaskType();
3471 auto permutationMap = getPermutationMap();
3472
3473 if (llvm::size(getIndices()) != shapedType.getRank())
3474 return emitOpError("requires ") << shapedType.getRank() << " indices";
3475
3476 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3477 // as the semantics is unclear. This can be revisited later if necessary.
3478 if (hasBroadcastDim())
3479 return emitOpError("should not have broadcast dimensions");
3480
3481 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3482 shapedType, vectorType, maskType, permutationMap,
3483 getInBounds() ? *getInBounds() : ArrayAttr())))
3484 return failure();
3485
3486 return verifyPermutationMap(permutationMap,
3487 [&](Twine t) { return emitOpError(t); });
3488}
3489
3490/// Fold:
3491/// ```
3492/// %t1 = ...
3493/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3494/// tensor<static_sizesxf32>, vector<static_sizesxf32>
3495/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3496/// vector<static_sizesxf32>, tensor<static_sizesxf32>
3497/// ```
3498///
3499/// into:
3500///
3501/// ```
3502/// %t0
3503/// ```
3504///
3505/// The producer of t1 may or may not be DCE'd depending on whether it is a
3506/// block argument or has side effects.
3507static LogicalResult foldReadInitWrite(TransferWriteOp write,
3508 ArrayRef<Attribute>,
3509 SmallVectorImpl<OpFoldResult> &results) {
3510 // TODO: support 0-d corner case.
3511 if (write.getTransferRank() == 0)
3512 return failure();
3513 auto rankedTensorType =
3514 write.getSource().getType().dyn_cast<RankedTensorType>();
3515 // If not operating on tensors, bail.
3516 if (!rankedTensorType)
3517 return failure();
3518 // If no read, bail.
3519 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3520 if (!read)
3521 return failure();
3522 // TODO: support 0-d corner case.
3523 if (read.getTransferRank() == 0)
3524 return failure();
3525 // For now, only accept minor identity maps (future: compositions that are minor identity).
3526 if (!read.getPermutationMap().isMinorIdentity() ||
3527 !write.getPermutationMap().isMinorIdentity())
3528 return failure();
3529 // Bail on mismatching ranks.
3530 if (read.getTransferRank() != write.getTransferRank())
3531 return failure();
3532 // Bail on potential out-of-bounds accesses.
3533 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3534 return failure();
3535 // Tensor types must be the same.
3536 if (read.getSource().getType() != rankedTensorType)
3537 return failure();
3538 // Vector types must be the same.
3539 if (read.getVectorType() != write.getVectorType())
3540 return failure();
3541 // Vector and Tensor shapes must match.
3542 if (read.getVectorType().getShape() != rankedTensorType.getShape())
3543 return failure();
3544 // If any index is nonzero.
3545 auto isNotConstantZero = [](Value v) {
3546 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3547 return !cstOp || cstOp.value() != 0;
3548 };
3549 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
3550 llvm::any_of(write.getIndices(), isNotConstantZero))
3551 return failure();
3552 // Success.
3553 results.push_back(read.getSource());
3554 return success();
3555}
3556
3557static bool checkSameValueWAR(vector::TransferReadOp read,
3558 vector::TransferWriteOp write) {
3559 return read.getSource() == write.getSource() &&
3560 read.getIndices() == write.getIndices() &&
3561 read.getPermutationMap() == write.getPermutationMap() &&
3562 read.getVectorType() == write.getVectorType() && !read.getMask() &&
3563 !write.getMask();
3564}
3565/// Fold transfer_write write after read:
3566/// ```
3567/// %t0 = ...
3568/// %v = vector.transfer_read %t0[%c0...] :
3569/// tensor<static_sizesxf32>, vector<static_sizesxf32>
3570/// %t1 = vector.transfer_write %v, %t0[%c0...] :
3571/// vector<static_sizesxf32>, tensor<static_sizesxf32>
3572/// ```
3573///
3574/// into:
3575///
3576/// ```
3577/// %t0
3578/// ```
3579static LogicalResult foldWAR(TransferWriteOp write,
3580 SmallVectorImpl<OpFoldResult> &results) {
3581 if (!write.getSource().getType().isa<RankedTensorType>())
3582 return failure();
3583 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3584 if (!read)
3585 return failure();
3586
3587 if (!checkSameValueWAR(read, write))
3588 return failure();
3589 results.push_back(read.getSource());
3590 return success();
3591}
3592
3593LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
3594 SmallVectorImpl<OpFoldResult> &results) {
3595 if (succeeded(foldReadInitWrite(*this, operands, results)))
3596 return success();
3597 if (succeeded(foldWAR(*this, results)))
3598 return success();
3599 if (succeeded(foldTransferInBoundsAttribute(*this)))
3600 return success();
3601 return foldMemRefCast(*this);
3602}
3603
3604Optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
3605 return llvm::to_vector<4>(getVectorType().getShape());
3606}
3607
3608void TransferWriteOp::getEffects(
3609 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3610 &effects) {
3611 if (getShapedType().isa<MemRefType>())
3612 effects.emplace_back(MemoryEffects::Write::get(), getSource(),
3613 SideEffects::DefaultResource::get());
3614}
3615
3616namespace {
3617 /// Remove a dead transfer write from the SSA chain so that it can be
3618 /// eliminated by DCE.
3619/// ```
3620/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3621/// : vector<1x4xf32>, tensor<4x4xf32>
3622/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
3623/// : vector<1x4xf32>, tensor<4x4xf32>
3624/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3625/// : vector<1x4xf32>, tensor<4x4xf32>
3626/// ```
3627///
3628/// into:
3629///
3630/// ```
3631/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3632/// : vector<1x4xf32>, tensor<4x4xf32>
3633/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
3634/// : vector<1x4xf32>, tensor<4x4xf32>
3635/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
3636/// : vector<1x4xf32>, tensor<4x4xf32>
3637/// ```
3638///
3639/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
3640/// any other uses.
3641class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
3642public:
3643 using OpRewritePattern::OpRewritePattern;
3644 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
3645 PatternRewriter &rewriter) const override {
3646 if (!writeOp.getShapedType().isa<RankedTensorType>())
3647 return failure();
3648 vector::TransferWriteOp writeToModify = writeOp;
3649
3650 auto defWrite =
3651 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3652 while (defWrite) {
3653 if (checkSameValueWAW(writeOp, defWrite)) {
3654 writeToModify.getSourceMutable().assign(defWrite.getSource());
3655 return success();
3656 }
3657 if (!isDisjointTransferIndices(
3658 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3659 cast<VectorTransferOpInterface>(writeOp.getOperation())))
3660 break;
3661 // If the previous write op doesn't have any other use, we can safely look
3662 // at the previous store to see if it can be removed.
3663 if (!defWrite->hasOneUse())
3664 break;
3665 writeToModify = defWrite;
3666 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3667 }
3668 return failure();
3669 }
3670};
3671
3672/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
3673/// could directly write to the insert_slice's destination. E.g.:
3674///
3675/// ```
3676/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
3677/// : vector<4x5xf32>, tensor<4x5xf32>
3678/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
3679/// : tensor<4x5xf32> into tensor<?x?xf32>
3680/// ```
3681/// is rewritten to:
3682/// ```
3683/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
3684/// : vector<4x5xf32>, tensor<?x?xf32>
3685/// ```
3686struct FoldInsertSliceIntoTransferWrite
3687 : public OpRewritePattern<tensor::InsertSliceOp> {
3688public:
3689 using OpRewritePattern::OpRewritePattern;
3690
3691 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3692 PatternRewriter &rewriter) const override {
3693 if (!insertOp.hasUnitStride())
3694 return failure();
3695
3696 auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
3697 if (!xferOp)
3698 return failure();
3699 // TODO: support 0-d corner case.
3700 if (xferOp.getTransferRank() == 0)
3701 return failure();
3702
3703 if (xferOp.hasOutOfBoundsDim())
3704 return failure();
3705 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
3706 return failure();
3707 if (xferOp.getMask())
3708 return failure();
3709 // Fold only if the TransferWriteOp completely overwrites the `source` with
3710 // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
3711 // content is the data of the vector.
3712 if (!llvm::equal(xferOp.getVectorType().getShape(),
3713 xferOp.getShapedType().getShape()))
3714 return failure();
3715 if (!xferOp.getPermutationMap().isIdentity())
3716 return failure();
3717
3718 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3719 // dims are exactly the leading dims. I.e. the following is illegal:
3720 // ```
3721 // %0 = vector.transfer_write %v, %t[0,0], %cst :
3722 // vector<2x4xf32>, tensor<2x4xf32>
3723 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
3724 // tensor<2x4xf32> into tensor<2x1x4xf32>
3725 // ```
3726 //
3727 // Cannot fold into:
3728 // ```
3729 // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
3730 // vector<2x4xf32>, tensor<2x1x4xf32>
3731 // ```
3732 // For this, check the trailing `vectorRank` dims of the insert_slice result
3733 // tensor match the trailing dims of the inferred result tensor.
3734 int64_t rankReduced =
3735 insertOp.getType().getRank() - insertOp.getSourceType().getRank();
3736 int64_t vectorRank = xferOp.getVectorType().getRank();
3737 RankedTensorType inferredSourceTensorType =
3738 tensor::ExtractSliceOp::inferResultType(
3739 insertOp.getType(), insertOp.getMixedOffsets(),
3740 insertOp.getMixedSizes(), insertOp.getMixedStrides());
3741 auto actualSourceTensorShape = insertOp.getSourceType().getShape();
3742 if (rankReduced > 0 &&
3743 actualSourceTensorShape.take_back(vectorRank) !=
3744 inferredSourceTensorType.getShape().take_back(vectorRank))
3745 return failure();
3746
3747 SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
3748 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
3749 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3750 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
3751 insertOp.getDest(), indices,
3752 ArrayRef<bool>{inBounds});
3753 return success();
3754 }
3755};
3756
3757/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
3758/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
3759/// overwritten and inserted into another tensor. After this rewrite, the
3760/// operations bufferize in-place since all of them work on the same slice.
3761///
3762/// For example:
3763/// ```mlir
3764/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
3765/// : vector<8x16xf32>, tensor<8x16xf32>
3766/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
3767/// : tensor<8x16xf32> to tensor<?x?xf32>
3768/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3769/// : tensor<?x?xf32> into tensor<27x37xf32>
3770/// ```
3771/// folds to
3772/// ```mlir
3773/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3774/// : tensor<27x37xf32> to tensor<?x?xf32>
3775/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
3776/// : vector<8x16xf32>, tensor<?x?xf32>
3777/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
3778/// : tensor<?x?xf32> into tensor<27x37xf32>
3779/// ```
3780struct SwapExtractSliceOfTransferWrite
3781 : public OpRewritePattern<tensor::InsertSliceOp> {
3782public:
3783 using OpRewritePattern::OpRewritePattern;
3784
3785 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
3786 PatternRewriter &rewriter) const override {
3787 if (!insertOp.hasUnitStride())
3788 return failure();
3789 auto extractOp =
3790 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3791 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
3792 return failure();
3793 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
3794 if (!transferOp || !transferOp->hasOneUse())
3795 return failure();
3796
3797 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
3798 // rank-reducing.
3799 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
3800 return rewriter.notifyMatchFailure(insertOp,
3801 "use-def chain is rank-reducing");
3802 }
3803
3804 // Fail if tensor::ExtractSliceOp has non-zero offset.
3805 if (!extractOp.hasZeroOffset()) {
3806 return rewriter.notifyMatchFailure(insertOp,
3807 "ExtractSliceOp has non-zero offset");
3808 }
3809
3810 // Fail if vector::TransferWriteOp has non-zero offset.
3811 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
3812 return getConstantIntValue(value) == static_cast<int64_t>(0);
3813 })) {
3814 return rewriter.notifyMatchFailure(insertOp,
3815 "TranferWriteOp has non-zero offset");
3816 }
3817
3818 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
3819 for (const auto &it :
3820 llvm::zip(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
3821 if (!isEqualConstantIntOrValue(std::get<0>(it), std::get<1>(it))) {
3822 return rewriter.notifyMatchFailure(
3823 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
3824 }
3825 }
3826
3827 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
3828 assert(transferOp.getVectorType().hasStaticShape() &&
3829 "expected vector to have a static shape");
3830 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
3831 SmallVector<int64_t> resultShape = applyPermutationMap(
3832 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
3833 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
3834 return rewriter.notifyMatchFailure(
3835 insertOp, "TransferWriteOp may not write the full tensor.");
3836 }
3837
3838 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
3839 SmallVector<int64_t> newResultShape = applyPermutationMap(
3840 transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
3841 SmallVector<bool> newInBounds;
3842 for (const auto &en : enumerate(newResultShape))
3843 newInBounds.push_back(en.value() == vectorShape[en.index()]);
3844 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
3845 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
3846 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
3847 insertOp.getMixedStrides());
3848 auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
3849 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
3850 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
3851 rewriter.getBoolArrayAttr(newInBounds));
3852 rewriter.updateRootInPlace(insertOp, [&]() {
3853 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
3854 });
3855 return success();
3856 }
3857};
3858
3859} // namespace
3860
3861void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
3862 MLIRContext *context) {
3863 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
3864 SwapExtractSliceOfTransferWrite>(context);
3865}
3866
3867//===----------------------------------------------------------------------===//
3868// LoadOp
3869//===----------------------------------------------------------------------===//
3870
3871static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
3872 MemRefType memRefTy) {
3873 if (!isLastMemrefDimUnitStride(memRefTy))
3874 return op->emitOpError("most minor memref dim must have unit stride");
3875 return success();
3876}
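// For example (hypothetical types): memref<4x8xf32> with the default identity
// layout passes, while a layout giving the most minor dimension a stride of 2
// (e.g. memref<4x8xf32, strided<[16, 2]>>) is rejected.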
3877
3878LogicalResult vector::LoadOp::verify() {
3879 VectorType resVecTy = getVectorType();
3880 MemRefType memRefTy = getMemRefType();
3881
3882 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
3883 return failure();
3884
3885 // Checks for vector memrefs.
3886 Type memElemTy = memRefTy.getElementType();
3887 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3888 if (memVecTy != resVecTy)
3889 return emitOpError("base memref and result vector types should match");
3890 memElemTy = memVecTy.getElementType();
3891 }
3892
3893 if (resVecTy.getElementType() != memElemTy)
3894 return emitOpError("base and result element types should match");
3895 if (llvm::size(getIndices()) != memRefTy.getRank())
3896 return emitOpError("requires ") << memRefTy.getRank() << " indices";
3897 return success();
3898}
3899
3900OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
3901 if (succeeded(foldMemRefCast(*this)))
3902 return getResult();
3903 return OpFoldResult();
3904}
3905
3906//===----------------------------------------------------------------------===//
3907// StoreOp
3908//===----------------------------------------------------------------------===//
3909
3910LogicalResult vector::StoreOp::verify() {
3911 VectorType valueVecTy = getVectorType();
3912 MemRefType memRefTy = getMemRefType();
3913
3914 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
3915 return failure();
3916
3917 // Checks for vector memrefs.
3918 Type memElemTy = memRefTy.getElementType();
3919 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
3920 if (memVecTy != valueVecTy)
3921 return emitOpError(
3922 "base memref and valueToStore vector types should match");
3923 memElemTy = memVecTy.getElementType();
3924 }
3925
3926 if (valueVecTy.getElementType() != memElemTy)
3927 return emitOpError("base and valueToStore element type should match");
3928 if (llvm::size(getIndices()) != memRefTy.getRank())
3929 return emitOpError("requires ") << memRefTy.getRank() << " indices";
3930 return success();
3931}
3932
3933LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
3934 SmallVectorImpl<OpFoldResult> &results) {
3935 return foldMemRefCast(*this);
3936}
3937
3938//===----------------------------------------------------------------------===//
3939// MaskedLoadOp
3940//===----------------------------------------------------------------------===//
3941
3942LogicalResult MaskedLoadOp::verify() {
3943 VectorType maskVType = getMaskVectorType();
3944 VectorType passVType = getPassThruVectorType();
3945 VectorType resVType = getVectorType();
3946 MemRefType memType = getMemRefType();
3947
3948 if (resVType.getElementType() != memType.getElementType())
3949 return emitOpError("base and result element type should match");
3950 if (llvm::size(getIndices()) != memType.getRank())
3951 return emitOpError("requires ") << memType.getRank() << " indices";
3952 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
3953 return emitOpError("expected result dim to match mask dim");
3954 if (resVType != passVType)
3955 return emitOpError("expected pass_thru of same type as result type");
3956 return success();
3957}
3958
3959namespace {
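/// Canonicalize masked loads whose mask is known at compile time. A sketch of
/// the two foldable cases (hypothetical IR):
///   all-true mask:  vector.maskedload folds to a plain vector.load from the
///                   same base and indices;
///   all-false mask: the op folds away to its pass_thru operand.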
3960class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
3961public:
3962 using OpRewritePattern::OpRewritePattern;
3963 LogicalResult matchAndRewrite(MaskedLoadOp load,
3964 PatternRewriter &rewriter) const override {
3965 switch (getMaskFormat(load.getMask())) {
3966 case MaskFormat::AllTrue:
3967 rewriter.replaceOpWithNewOp<vector::LoadOp>(
3968 load, load.getType(), load.getBase(), load.getIndices());
3969 return success();
3970 case MaskFormat::AllFalse:
3971 rewriter.replaceOp(load, load.getPassThru());
3972 return success();
3973 case MaskFormat::Unknown:
3974 return failure();
3975 }
3976 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
3977 }
3978};
3979} // namespace
3980
3981void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3982 MLIRContext *context) {
3983 results.add<MaskedLoadFolder>(context);
3984}
3985
3986OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
3987 if (succeeded(foldMemRefCast(*this)))
3988 return getResult();
3989 return OpFoldResult();
3990}
3991
3992//===----------------------------------------------------------------------===//
3993// MaskedStoreOp
3994//===----------------------------------------------------------------------===//
3995
3996LogicalResult MaskedStoreOp::verify() {
3997 VectorType maskVType = getMaskVectorType();
3998 VectorType valueVType = getVectorType();
3999 MemRefType memType = getMemRefType();
4000
4001 if (valueVType.getElementType() != memType.getElementType())
4002 return emitOpError("base and valueToStore element type should match");
4003 if (llvm::size(getIndices()) != memType.getRank())
4004 return emitOpError("requires ") << memType.getRank() << " indices";
4005 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4006 return emitOpError("expected valueToStore dim to match mask dim");
4007 return success();
4008}
4009
4010namespace {
4011class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4012public:
4013 using OpRewritePattern::OpRewritePattern;
4014 LogicalResult matchAndRewrite(MaskedStoreOp store,
4015 PatternRewriter &rewriter) const override {
4016 switch (getMaskFormat(store.getMask())) {
4017 case MaskFormat::AllTrue:
4018 rewriter.replaceOpWithNewOp<vector::StoreOp>(
4019 store, store.getValueToStore(), store.getBase(), store.getIndices());
4020 return success();
4021 case MaskFormat::AllFalse:
4022 rewriter.eraseOp(store);
4023 return success();
4024 case MaskFormat::Unknown:
4025 return failure();
4026 }
4027 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4028 }
4029};
4030} // namespace
4031
4032void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4033 MLIRContext *context) {
4034 results.add<MaskedStoreFolder>(context);
4035}
4036
4037LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
4038 SmallVectorImpl<OpFoldResult> &results) {
4039 return foldMemRefCast(*this);
4040}
4041
4042//===----------------------------------------------------------------------===//
4043// GatherOp
4044//===----------------------------------------------------------------------===//
4045
4046LogicalResult GatherOp::verify() {
4047 VectorType indVType = getIndexVectorType();
4048 VectorType maskVType = getMaskVectorType();
4049 VectorType resVType = getVectorType();
4050 ShapedType baseType = getBaseType();
4051
4052 if (!baseType.isa<MemRefType, RankedTensorType>())
4053 return emitOpError("requires base to be a memref or ranked tensor type");
4054
4055 if (resVType.getElementType() != baseType.getElementType())
4056 return emitOpError("base and result element type should match");
4057 if (llvm::size(getIndices()) != baseType.getRank())
4058 return emitOpError("requires ") << baseType.getRank() << " indices";
4059 if (resVType.getShape() != indVType.getShape())
4060 return emitOpError("expected result dim to match indices dim");
4061 if (resVType.getShape() != maskVType.getShape())
4062 return emitOpError("expected result dim to match mask dim");
4063 if (resVType != getPassThruVectorType())
4064 return emitOpError("expected pass_thru of same type as result type");
4065 return success();
4066}
4067
4068namespace {
4069class GatherFolder final : public OpRewritePattern<GatherOp> {
4070public:
4071 using OpRewritePattern::OpRewritePattern;
4072 LogicalResult matchAndRewrite(GatherOp gather,
4073 PatternRewriter &rewriter) const override {
4074 switch (getMaskFormat(gather.getMask())) {
4075 case MaskFormat::AllTrue:
4076 return failure(); // no unmasked equivalent
4077 case MaskFormat::AllFalse:
4078 rewriter.replaceOp(gather, gather.getPassThru());
4079 return success();
4080 case MaskFormat::Unknown:
4081 return failure();
4082 }
4083 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4084 }
4085};
4086} // namespace
4087
4088void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4089 MLIRContext *context) {
4090 results.add<GatherFolder>(context);
4091}
4092
4093//===----------------------------------------------------------------------===//
4094// ScatterOp
4095//===----------------------------------------------------------------------===//
4096
4097LogicalResult ScatterOp::verify() {
4098 VectorType indVType = getIndexVectorType();
4099 VectorType maskVType = getMaskVectorType();
4100 VectorType valueVType = getVectorType();
4101 MemRefType memType = getMemRefType();
4102
4103 if (valueVType.getElementType() != memType.getElementType())
4104 return emitOpError("base and valueToStore element type should match");
4105 if (llvm::size(getIndices()) != memType.getRank())
4106 return emitOpError("requires ") << memType.getRank() << " indices";
4107 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4108 return emitOpError("expected valueToStore dim to match indices dim");
4109 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4110 return emitOpError("expected valueToStore dim to match mask dim");
4111 return success();
4112}
4113
4114namespace {
4115class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4116public:
4117 using OpRewritePattern::OpRewritePattern;
4118 LogicalResult matchAndRewrite(ScatterOp scatter,
4119 PatternRewriter &rewriter) const override {
4120 switch (getMaskFormat(scatter.getMask())) {
4121 case MaskFormat::AllTrue:
4122 return failure(); // no unmasked equivalent
4123 case MaskFormat::AllFalse:
4124 rewriter.eraseOp(scatter);
4125 return success();
4126 case MaskFormat::Unknown:
4127 return failure();
4128 }
4129 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4130 }
4131};
4132} // namespace
4133
4134void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4135 MLIRContext *context) {
4136 results.add<ScatterFolder>(context);
4137}
4138
4139//===----------------------------------------------------------------------===//
4140// ExpandLoadOp
4141//===----------------------------------------------------------------------===//
4142
4143LogicalResult ExpandLoadOp::verify() {
4144 VectorType maskVType = getMaskVectorType();
4145 VectorType passVType = getPassThruVectorType();
4146 VectorType resVType = getVectorType();
4147 MemRefType memType = getMemRefType();
4148
4149 if (resVType.getElementType() != memType.getElementType())
4150 return emitOpError("base and result element type should match");
4151 if (llvm::size(getIndices()) != memType.getRank())
4152 return emitOpError("requires ") << memType.getRank() << " indices";
4153 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4154 return emitOpError("expected result dim to match mask dim");
4155 if (resVType != passVType)
4156 return emitOpError("expected pass_thru of same type as result type");
4157 return success();
4158}
4159
4160namespace {
4161class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4162public:
4163 using OpRewritePattern::OpRewritePattern;
4164 LogicalResult matchAndRewrite(ExpandLoadOp expand,
4165 PatternRewriter &rewriter) const override {
4166 switch (getMaskFormat(expand.getMask())) {
4167 case MaskFormat::AllTrue:
4168 rewriter.replaceOpWithNewOp<vector::LoadOp>(
4169 expand, expand.getType(), expand.getBase(), expand.getIndices());
4170 return success();
4171 case MaskFormat::AllFalse:
4172 rewriter.replaceOp(expand, expand.getPassThru());
4173 return success();
4174 case MaskFormat::Unknown:
4175 return failure();
4176 }
4177 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4178 }
4179};
4180} // namespace
4181
4182void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4183 MLIRContext *context) {
4184 results.add<ExpandLoadFolder>(context);
4185}
4186
4187//===----------------------------------------------------------------------===//
4188// CompressStoreOp
4189//===----------------------------------------------------------------------===//
4190
4191LogicalResult CompressStoreOp::verify() {
4192 VectorType maskVType = getMaskVectorType();
4193 VectorType valueVType = getVectorType();
4194 MemRefType memType = getMemRefType();
4195
4196 if (valueVType.getElementType() != memType.getElementType())
4197 return emitOpError("base and valueToStore element type should match");
4198 if (llvm::size(getIndices()) != memType.getRank())
4199 return emitOpError("requires ") << memType.getRank() << " indices";
4200 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4201 return emitOpError("expected valueToStore dim to match mask dim");
4202 return success();
4203}
4204
4205namespace {
4206class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4207public:
4208 using OpRewritePattern::OpRewritePattern;
4209 LogicalResult matchAndRewrite(CompressStoreOp compress,
4210 PatternRewriter &rewriter) const override {
4211 switch (getMaskFormat(compress.getMask())) {
4212 case MaskFormat::AllTrue:
4213 rewriter.replaceOpWithNewOp<vector::StoreOp>(
4214 compress, compress.getValueToStore(), compress.getBase(),
4215 compress.getIndices());
4216 return success();
4217 case MaskFormat::AllFalse:
4218 rewriter.eraseOp(compress);
4219 return success();
4220 case MaskFormat::Unknown:
4221 return failure();
4222 }
4223 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
4224 }
4225};
4226} // namespace
4227
4228void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4229 MLIRContext *context) {
4230 results.add<CompressStoreFolder>(context);
4231}
4232
4233//===----------------------------------------------------------------------===//
4234// ShapeCastOp
4235//===----------------------------------------------------------------------===//
4236
4237/// Returns true if each element of 'a' is equal to the product of a contiguous
4238/// sequence of the elements of 'b'. Returns false otherwise.
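/// For example, isValidShapeCast({4, 30}, {4, 5, 6}) is true since 4 == 4 and
/// 30 == 5 * 6, whereas isValidShapeCast({4, 30}, {5, 4, 6}) is false: no
/// contiguous run of {5, 4, 6} multiplies to 4.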
4239static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
4240 unsigned rankA = a.size();
4241 unsigned rankB = b.size();
4242 assert(rankA < rankB);
4243
4244 unsigned i = 0;
4245 unsigned j = 0;
4246 while (i < rankA && j < rankB) {
4247 int64_t dimA = a[i];
4248 int64_t dimB = 1;
4249 while (dimB < dimA && j < rankB)
4250 dimB *= b[j++];
4251 if (dimA != dimB)
4252 break;
4253 ++i;
4254
4255 // Handle the case when trailing dimensions are of size 1.
4256 // Include them in the contiguous sequence.
4257 auto isOne = [](int64_t v) { return v == 1; };
4258 if (i < rankA && llvm::all_of(a.slice(i), isOne))
4259 i = rankA;
4260 if (j < rankB && llvm::all_of(b.slice(j), isOne))
4261 j = rankB;
4262 }
4263
4264 return i == rankA && j == rankB;
4265}
4266
4267static LogicalResult verifyVectorShapeCast(Operation *op,
4268 VectorType sourceVectorType,
4269 VectorType resultVectorType) {
4270 // Check that element type is the same.
4271 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
4272 return op->emitOpError("source/result vectors must have same element type");
4273 auto sourceShape = sourceVectorType.getShape();
4274 auto resultShape = resultVectorType.getShape();
4275
4276 // Check that product of source dim sizes matches product of result dim sizes.
4277 int64_t sourceDimProduct = std::accumulate(
4278 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
4279 int64_t resultDimProduct = std::accumulate(
4280 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
4281 if (sourceDimProduct != resultDimProduct)
4282 return op->emitOpError("source/result number of elements must match");
4283
4284 // Check the expanding/contracting rank cases.
4285 unsigned sourceRank = sourceVectorType.getRank();
4286 unsigned resultRank = resultVectorType.getRank();
4287 if (sourceRank < resultRank) {
4288 if (!isValidShapeCast(sourceShape, resultShape))
4289 return op->emitOpError("invalid shape cast");
4290 } else if (sourceRank > resultRank) {
4291 if (!isValidShapeCast(resultShape, sourceShape))
4292 return op->emitOpError("invalid shape cast");
4293 }
4294 return success();
4295}
4296
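// A well-formed cast, for illustration:
//   %1 = vector.shape_cast %0 : vector<4x30xf32> to vector<4x5x6xf32>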
4297LogicalResult ShapeCastOp::verify() {
4298 auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
4299 auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
4300
4301 // Check if source/result are of vector type.
4302 if (sourceVectorType && resultVectorType)
4303 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
4304
4305 return success();
4306}
4307
4308OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4309 // No-op shape cast.
4310 if (getSource().getType() == getResult().getType())
4311 return getSource();
4312
4313 // Canceling shape casts.
4314 if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
4315 if (getResult().getType() == otherOp.getSource().getType())
4316 return otherOp.getSource();
4317
4318 // Only allows valid transitive folding.
4319 VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
4320 VectorType resultType = getResult().getType().cast<VectorType>();
4321 if (srcType.getRank() < resultType.getRank()) {
4322 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
4323 return {};
4324 } else if (srcType.getRank() > resultType.getRank()) {
4325 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
4326 return {};
4327 } else {
4328 return {};
4329 }
4330
4331 setOperand(otherOp.getSource());
4332 return getResult();
4333 }
4334
4335 // Canceling broadcast and shape cast ops.
4336 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4337 if (bcastOp.getSourceType() == getType())
4338 return bcastOp.getSource();
4339 }
4340
4341 return {};
4342}
4343
4344namespace {
4345// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
4346class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
4347public:
4348 using OpRewritePattern::OpRewritePattern;
4349
4350 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4351 PatternRewriter &rewriter) const override {
4352 auto constantOp =
4353 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
4354 if (!constantOp)
4355 return failure();
4356 // Only handle splat for now.
4357 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
4358 if (!dense)
4359 return failure();
4360 auto newAttr =
4361 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
4362 dense.getSplatValue<Attribute>());
4363 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
4364 return success();
4365 }
4366};
4367
4368/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
4369/// This only applies when the shape of the broadcast source is a suffix of the
4370/// shape of the result (i.e. when broadcast without reshape is expressive
4371/// enough to capture the result in a single op).
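/// E.g. a shape_cast of (broadcast vector<3xf32> to vector<2x4x3xf32>) to
/// vector<8x3xf32> folds to a single broadcast to vector<8x3xf32>, since
/// (3) is a suffix of (8, 3).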
4372class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4373public:
4374 using OpRewritePattern::OpRewritePattern;
4375
4376 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4377 PatternRewriter &rewriter) const override {
4378 auto broadcastOp =
4379 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4380 if (!broadcastOp)
4381 return failure();
4382
4383 auto broadcastSourceVectorType =
4384 broadcastOp.getSourceType().dyn_cast<VectorType>();
4385 auto broadcastSourceShape = broadcastSourceVectorType
4386 ? broadcastSourceVectorType.getShape()
4387 : ArrayRef<int64_t>{};
4388 auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4389
4390 // Bail if `broadcastSourceShape` is not a suffix of the result.
4391 bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4392 broadcastSourceShape.size()));
4393 if (!isSuffix)
4394 return failure();
4395
4396 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4397 shapeCastOp, shapeCastOp.getResultVectorType(),
4398 broadcastOp.getSource());
4399 return success();
4400 }
4401};
4402
4403} // namespace
4404
4405void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
4406 MLIRContext *context) {
4407 results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
4408}
4409
4410//===----------------------------------------------------------------------===//
4411// VectorBitCastOp
4412//===----------------------------------------------------------------------===//
4413
4414LogicalResult BitCastOp::verify() {
4415 auto sourceVectorType = getSourceVectorType();
4416 auto resultVectorType = getResultVectorType();
4417
4418 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4419 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4420 return emitOpError("dimension size mismatch at: ") << i;
4421 }
4422
4423 DataLayout dataLayout = DataLayout::closest(*this);
4424 auto sourceElementBits =
4425 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
4426 auto resultElementBits =
4427 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
4428
4429 if (sourceVectorType.getRank() == 0) {
4430 if (sourceElementBits != resultElementBits)
4431 return emitOpError("source/result bitwidth of the 0-D vector element "
4432 "types must be equal");
4433 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
4434 resultElementBits * resultVectorType.getShape().back()) {
4435 return emitOpError(
4436 "source/result bitwidth of the minor 1-D vectors must be equal");
4437 }
4438
4439 return success();
4440}
4441
4442OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
4443 // Nop cast.
4444 if (getSource().getType() == getResult().getType())
4445 return getSource();
4446
4447 // Canceling bitcasts.
4448 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
4449 if (getResult().getType() == otherOp.getSource().getType())
4450 return otherOp.getSource();
4451
4452 setOperand(otherOp.getSource());
4453 return getResult();
4454 }
4455
4456 Attribute sourceConstant = operands.front();
4457 if (!sourceConstant)
4458 return {};
4459
4460 Type srcElemType = getSourceVectorType().getElementType();
4461 Type dstElemType = getResultVectorType().getElementType();
4462
4463 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
4464 if (floatPack.isSplat()) {
4465 auto splat = floatPack.getSplatValue<FloatAttr>();
4466
4467 // Casting fp16 into fp32.
4468 if (srcElemType.isF16() && dstElemType.isF32()) {
4469 uint32_t bits = static_cast<uint32_t>(
4470 splat.getValue().bitcastToAPInt().getZExtValue());
4471 // Duplicate the 16-bit pattern.
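// (The result element is twice as wide, so each f32 lane of the new
// splat covers two adjacent f16 lanes; replicating the pattern into
// both halves preserves the splat bit-for-bit.)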
4472 bits = (bits << 16) | (bits & 0xffff);
4473 APInt intBits(32, bits);
4474 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4475 return DenseElementsAttr::get(getResultVectorType(), floatBits);
4476 }
4477 }
4478 }
4479
4480 return {};
4481}
4482
4483//===----------------------------------------------------------------------===//
4484// TypeCastOp
4485//===----------------------------------------------------------------------===//
4486
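// Returns the memref shape concatenated with the shape of its vector element
// type, if any; e.g. memref<4x5xvector<6xf32>> yields {4, 5, 6}.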
4487static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4488 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4489 SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4490 memRefType.getShape().end());
4491 if (vectorType)
4492 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4493 return res;
4494}
4495
4496/// Build the canonical memRefType with a single vector.
4497/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4498void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4499 Value source) {
4500 result.addOperands(source);
4501 MemRefType memRefType = source.getType().cast<MemRefType>();
4502 VectorType vectorType =
4503 VectorType::get(extractShape(memRefType),
4504 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4505 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4506 memRefType.getMemorySpace()));
4507}
4508
4509LogicalResult TypeCastOp::verify() {
4510 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
4511 if (!canonicalType.getLayout().isIdentity())
4512 return emitOpError("expects operand to be a memref with identity layout");
4513 if (!getResultMemRefType().getLayout().isIdentity())
4514 return emitOpError("expects result to be a memref with identity layout");
4515 if (getResultMemRefType().getMemorySpace() !=
4516 getMemRefType().getMemorySpace())
4517 return emitOpError("expects result in same memory space");
4518
4519 auto sourceType = getMemRefType();
4520 auto resultType = getResultMemRefType();
4521 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4522 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4523 return emitOpError(
4524 "expects result and operand with same underlying scalar type: ")
4525 << resultType;
4526 if (extractShape(sourceType) != extractShape(resultType))
4527 return emitOpError(
4528 "expects concatenated result and operand shapes to be equal: ")
4529 << resultType;
4530 return success();
4531}
4532
4533//===----------------------------------------------------------------------===//
4534// TransposeOp
4535//===----------------------------------------------------------------------===//
4536
4537void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4538 Value vector, ArrayRef<int64_t> transp) {
4539 VectorType vt = vector.getType().cast<VectorType>();
4540 SmallVector<int64_t, 4> transposedShape(vt.getRank());
4541 for (unsigned i = 0; i < transp.size(); ++i)
4542 transposedShape[i] = vt.getShape()[transp[i]];
4543
4544 result.addOperands(vector);
4545 result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4546 result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
4547}
4548
4549OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4550 // Eliminate splat constant transpose ops.
4551 if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
4552 if (attr.isSplat())
4553 return attr.reshape(getResultType());
4554
4555 // Eliminate identity transpose ops. This happens when the dimensions of the
4556 // input vector remain in their original order after the transpose operation.
4557 SmallVector<int64_t, 4> transp;
4558 getTransp(transp);
4559
4560 // Check if the permutation of the dimensions contains sequential values:
4561 // {0, 1, 2, ...}.
4562 for (int64_t i = 0, e = transp.size(); i < e; i++) {
4563 if (transp[i] != i)
4564 return {};
4565 }
4566
4567 return getVector();
4568}
4569
4570LogicalResult vector::TransposeOp::verify() {
4571 VectorType vectorType = getVectorType();
4572 VectorType resultType = getResultType();
4573 int64_t rank = resultType.getRank();
4574 if (vectorType.getRank() != rank)
4575 return emitOpError("vector result rank mismatch: ") << rank;
4576 // Verify transposition array.
4577 auto transpAttr = getTransp().getValue();
4578 int64_t size = transpAttr.size();
4579 if (rank != size)
4580 return emitOpError("transposition length mismatch: ") << size;
4581 SmallVector<bool, 8> seen(rank, false);
4582 for (const auto &ta : llvm::enumerate(transpAttr)) {
4583 int64_t i = ta.value().cast<IntegerAttr>().getInt();
4584 if (i < 0 || i >= rank)
4585 return emitOpError("transposition index out of range: ") << i;
4586 if (seen[i])
4587 return emitOpError("duplicate position index: ") << i;
4588 seen[i] = true;
4589 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
4590 return emitOpError("dimension size mismatch at: ") << i;
4591 }
4592 return success();
4593}
4594
4595Optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
4596 return llvm::to_vector<4>(getResultType().getShape());
4597}
4598
4599namespace {
4600
4601// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
4602class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
4603public:
4604 using OpRewritePattern::OpRewritePattern;
4605
4606 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4607 PatternRewriter &rewriter) const override {
4608 // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
4609 auto getPermutation = [](vector::TransposeOp transpose) {
4610 SmallVector<int64_t, 4> permutation;
4611 transpose.getTransp(permutation);
4612 return permutation;
4613 };
4614
4615 // Composes two permutations: result[i] = permutation1[permutation2[i]].
4616 auto composePermutations = [](ArrayRef<int64_t> permutation1,
4617 ArrayRef<int64_t> permutation2) {
4618 SmallVector<int64_t, 4> result;
4619 for (auto index : permutation2)
4620 result.push_back(permutation1[index]);
4621 return result;
4622 };
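// Worked example: composing a parent permutation {1, 2, 0} with this op's
// permutation {2, 0, 1} yields {0, 1, 2} (result[i] = parent[this[i]]),
// i.e. the two transposes cancel.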
4623
4624 // Return if the input of 'transposeOp' is not defined by another transpose.
4625 vector::TransposeOp parentTransposeOp =
4626 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
4627 if (!parentTransposeOp)
4628 return failure();
4629
4630 SmallVector<int64_t, 4> permutation = composePermutations(
4631 getPermutation(parentTransposeOp), getPermutation(transposeOp));
4632 // Replace 'transposeOp' with a new transpose operation.
4633 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
4634 transposeOp, transposeOp.getResult().getType(),
4635 parentTransposeOp.getVector(),
4636 vector::getVectorSubscriptAttr(rewriter, permutation));
4637 return success();
4638 }
4639};
4640
4641// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
4642struct FoldTransposedScalarBroadcast final
4643 : public OpRewritePattern<vector::TransposeOp> {
4644 using OpRewritePattern::OpRewritePattern;
4645
4646 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
4647 PatternRewriter &rewriter) const override {
4648 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
4649 if (!bcastOp)
4650 return failure();
4651
4652 auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
4653 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
4654 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4655 transposeOp, transposeOp.getResultType(), bcastOp.getSource());
4656 return success();
4657 }
4658
4659 return failure();
4660 }
4661};
4662
4663// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
4664class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
4665public:
4666 using OpRewritePattern::OpRewritePattern;
4667
4668 LogicalResult matchAndRewrite(TransposeOp transposeOp,
4669 PatternRewriter &rewriter) const override {
4670 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
4671 if (!splatOp)
4672 return failure();
4673
4674 rewriter.replaceOpWithNewOp<vector::SplatOp>(
4675 transposeOp, transposeOp.getResultType(), splatOp.getInput());
4676 return success();
4677 }
4678};
4679
4680} // namespace
4681
4682void vector::TransposeOp::getCanonicalizationPatterns(
4683 RewritePatternSet &results, MLIRContext *context) {
4684 results
4685 .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
4686 context);
4687}
4688
4689void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
4690 populateFromInt64AttrArray(getTransp(), results);
4691}
4692
4693//===----------------------------------------------------------------------===//
4694// ConstantMaskOp
4695//===----------------------------------------------------------------------===//
4696
4697LogicalResult ConstantMaskOp::verify() {
4698 auto resultType = getResult().getType().cast<VectorType>();
4699 // Check the corner case of 0-D vectors first.
4700 if (resultType.getRank() == 0) {
4701 if (getMaskDimSizes().size() != 1)
4702 return emitError("array attr must have length 1 for 0-D vectors");
4703 auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
4704 if (dim != 0 && dim != 1)
4705 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
4706 return success();
4707 }
4708
4709 // Verify that array attr size matches the rank of the vector result.
4710 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
4711 return emitOpError(
4712 "must specify array attr of size equal vector result rank");
4713 // Verify that each array attr element is in bounds of the corresponding
4714 // vector result dimension size.
4715 auto resultShape = resultType.getShape();
4716 SmallVector<int64_t, 4> maskDimSizes;
4717 for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
4718 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
4719 if (attrValue < 0 || attrValue > resultShape[it.index()])
4720 return emitOpError(
4721 "array attr of size out of bounds of vector result dimension size");
4722 maskDimSizes.push_back(attrValue);
4723 }
4724 // Verify that if one mask dim size is zero, they all should be zero (because
4725 // the mask region is a conjunction of each mask dimension interval).
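// E.g. mask dim sizes [2, 0] would denote an empty mask region, which must
// instead be written canonically as [0, 0].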
4726 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
4727 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
4728 if (anyZeros && !allZeros)
4729 return emitOpError("expected all mask dim sizes to be zeros, "
4730 "as a result of conjunction with zero mask dim");
4731 // Verify that if the mask type is scalable, all mask dim sizes are zero:
4732 // constant scalable masks can only be defined for the "none set" or "all
4733 // set" cases, and there is no VLA way to define an "all set" case for
4734 // `vector.constant_mask`. In the future, a convention could be established
4735 // to decide if a specific dimension value could be considered as "all set".
4736 if (resultType.isScalable() &&
4737 getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
4738 return emitOpError("expected mask dim sizes for scalable masks to be 0");
4739 return success();
4740}
4741
4742//===----------------------------------------------------------------------===//
4743// CreateMaskOp
4744//===----------------------------------------------------------------------===//
4745
4746LogicalResult CreateMaskOp::verify() {
4747 auto vectorType = getResult().getType().cast<VectorType>();
4748 // Verify that an operand was specified for each result vector dimension.
4749 if (vectorType.getRank() == 0) {
4750 if (getNumOperands() != 1)
4751 return emitOpError(
4752 "must specify exactly one operand for 0-D create_mask");
4753 } else if (getNumOperands() !=
4754 getResult().getType().cast<VectorType>().getRank()) {
4755 return emitOpError(
4756 "must specify an operand for each result vector dimension");
4757 }
4758 return success();
4759}
4760
4761namespace {
4762
4763// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
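// For illustration, with constant operands %c3 = 3 and %c2 = 2:
//   %m = vector.create_mask %c3, %c2 : vector<4x3xi1>
// rewrites to:
//   %m = vector.constant_mask [3, 2] : vector<4x3xi1>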
4764class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
4765public:
4766 using OpRewritePattern::OpRewritePattern;
4767
4768 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
4769 PatternRewriter &rewriter) const override {
4770 // Return if any of 'createMaskOp' operands are not defined by a constant.
4771 auto isNotDefByConstant = [](Value operand) {
4772 return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
4773 };
4774 if (llvm::any_of(createMaskOp.operands(), isNotDefByConstant))
4775 return failure();
4776
4777 // CreateMaskOp for scalable vectors can be folded only if all dimensions
4778 // are negative or zero.
4779 if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
4780 if (vType.isScalable())
4781 for (auto opDim : createMaskOp.getOperands()) {
4782 APInt intVal;
4783 if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
4784 intVal.isStrictlyPositive())
4785 return failure();
4786 }
4787 }
4788
4789 // Gather constant mask dimension sizes.
4790 SmallVector<int64_t, 4> maskDimSizes;
4791 for (auto it : llvm::zip(createMaskOp.operands(),
4792 createMaskOp.getType().getShape())) {
4793 auto *defOp = std::get<0>(it).getDefiningOp();
4794 int64_t maxDimSize = std::get<1>(it);
4795 int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
4796 dimSize = std::min(dimSize, maxDimSize);
4797 // If one of dim sizes is zero, set all dims to zero.
4798 if (dimSize <= 0) {
4799 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
4800 break;
4801 }
4802 maskDimSizes.push_back(dimSize);
4803 }
4804 // Replace 'createMaskOp' with ConstantMaskOp.
4805 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
4806 createMaskOp, createMaskOp.getResult().getType(),
4807 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
4808 return success();
4809 }
4810};
4811
4812} // namespace
4813
4814void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
4815 MLIRContext *context) {
4816 results.add<CreateMaskFolder>(context);
4817}
4818
4819//===----------------------------------------------------------------------===//
4820// ScanOp
4821//===----------------------------------------------------------------------===//
4822
4823LogicalResult ScanOp::verify() {
4824 VectorType srcType = getSourceType();
4825 VectorType initialType = getInitialValueType();
4826 // Check reduction dimension < rank.
4827 int64_t srcRank = srcType.getRank();
4828 int64_t reductionDim = getReductionDim();
4829 if (reductionDim >= srcRank)
4830 return emitOpError("reduction dimension ")
4831 << reductionDim << " has to be less than " << srcRank;
4832
4833 // Check that rank(initial_value) = rank(src) - 1.
4834 int64_t initialValueRank = initialType.getRank();
4835 if (initialValueRank != srcRank - 1)
4836 return emitOpError("initial value rank ")
4837 << initialValueRank << " has to be equal to " << srcRank - 1;
4838
4839 // Check shapes of initial value and src.
4840 ArrayRef<int64_t> srcShape = srcType.getShape();
4841 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
4842 SmallVector<int64_t> expectedShape;
4843 for (int i = 0; i < srcRank; i++) {
4844 if (i != reductionDim)
4845 expectedShape.push_back(srcShape[i]);
4846 }
4847 if (llvm::any_of(llvm::zip(initialValueShapes, expectedShape),
4848 [](std::tuple<int64_t, int64_t> s) {
4849 return std::get<0>(s) != std::get<1>(s);
4850 })) {
4851 return emitOpError("incompatible input/initial value shapes");
4852 }
4853
4854 // Verify supported reduction kind.
4855 Type eltType = getDestType().getElementType();
4856 if (!isSupportedCombiningKind(getKind(), eltType))
4857 return emitOpError("unsupported reduction type ")
4858 << eltType << " for kind '" << stringifyCombiningKind(getKind())
4859 << "'";
4860
4861 return success();
4862}
4863
4864void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
4865 RewritePatternSet &patterns, PatternBenefit benefit) {
4866 patterns
4867 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
4868 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
4869 StridedSliceConstantMaskFolder, TransposeFolder>(
4870 patterns.getContext(), benefit);
4871}
4872
4873//===----------------------------------------------------------------------===//
4874// SplatOp
4875//===----------------------------------------------------------------------===//
4876
4877OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
4878 auto constOperand = operands.front();
4879 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
4880 return {};
4881
4882 // SplatElementsAttr::get treats a single value for the second arg as a splat.
4883 return SplatElementsAttr::get(getType(), {constOperand});
4884}
4885
4886//===----------------------------------------------------------------------===//
4887// WarpExecuteOnLane0Op
4888//===----------------------------------------------------------------------===//
4889
4890void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
4891 p << "(" << getLaneid() << ")";
4892
4893 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
4894 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
4895 p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
4896
4897 if (!getArgs().empty())
4898 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
4899 if (!getResults().empty())
4900 p << " -> (" << getResults().getTypes() << ')';
4901 p << " ";
4902 p.printRegion(getRegion(),
4903 /*printEntryBlockArgs=*/true,
4904 /*printBlockTerminators=*/!getResults().empty());
4905 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
4906}
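// Printed form, for illustration (operand names hypothetical):
//   vector.warp_execute_on_lane_0(%laneid)[32] args(%v : vector<4xf32>)
//       -> (vector<4xf32>) { ... }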
4907
4908ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
4909 OperationState &result) {
4910 // Create the region.
4911 result.regions.reserve(1);
4912 Region *warpRegion = result.addRegion();
4913
4914 auto &builder = parser.getBuilder();
4915 OpAsmParser::UnresolvedOperand laneId;
4916
4917 // Parse predicate operand.
4918 if (parser.parseLParen() ||
(1) Taking false branch
4919 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
4920 parser.parseRParen())
4921 return failure();
4922
4923 int64_t warpSize;
(2) 'warpSize' declared without an initial value
4924 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
(3) Calling 'AsmParser::parseInteger'
(11) Returning from 'AsmParser::parseInteger'
(12) Taking false branch
4925 parser.parseRSquare())
4926 return failure();
4927 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
4928 builder.getContext())),
4929 builder.getI64IntegerAttr(warpSize));
(13) 1st function call argument is an uninitialized value
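// Explanation of the path: the analyzer's model of 'AsmParser::parseInteger'
// (events (3)-(11)) admits a success return that never assigns the
// out-parameter, so 'warpSize' can still be uninitialized when passed to
// 'builder.getI64IntegerAttr' here. Initializing the variable at its
// declaration (e.g. 'int64_t warpSize = 0;') would remove this path.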
4930
4931 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
4932 return failure();
4933
4934 llvm::SMLoc inputsOperandsLoc;
4935 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
4936 SmallVector<Type> inputTypes;
4937 if (succeeded(parser.parseOptionalKeyword("args"))) {
4938 if (parser.parseLParen())
4939 return failure();
4940
4941 inputsOperandsLoc = parser.getCurrentLocation();
4942 if (parser.parseOperandList(inputsOperands) ||
4943 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
4944 return failure();
4945 }
4946 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
4947 result.operands))
4948 return failure();
4949
4950 // Parse optional results type list.
4951 if (parser.parseOptionalArrowTypeList(result.types))
4952 return failure();
4953 // Parse the region.
4954 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
4955 /*argTypes=*/{}))
4956 return failure();
4957 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
4958
4959 // Parse the optional attribute list.
4960 if (parser.parseOptionalAttrDict(result.attributes))
4961 return failure();
4962 return success();
4963}
4964
4965void WarpExecuteOnLane0Op::getSuccessorRegions(
4966 Optional<unsigned> index, ArrayRef<Attribute> operands,
4967 SmallVectorImpl<RegionSuccessor> &regions) {
4968 if (index) {
4969 regions.push_back(RegionSuccessor(getResults()));
4970 return;
4971 }
4972
4973 // The warp region is always executed
4974 regions.push_back(RegionSuccessor(&getWarpRegion()));
4975}
4976
4977void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
4978 TypeRange resultTypes, Value laneId,
4979 int64_t warpSize) {
4980 build(builder, result, resultTypes, laneId, warpSize,
4981 /*operands=*/llvm::None, /*argTypes=*/llvm::None);
4982}
4983
4984void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
4985 TypeRange resultTypes, Value laneId,
4986 int64_t warpSize, ValueRange args,
4987 TypeRange blockArgTypes) {
4988 result.addOperands(laneId);
4989 result.addAttribute(getAttributeNames()[0],
4990 builder.getI64IntegerAttr(warpSize));
4991 result.addTypes(resultTypes);
4992 result.addOperands(args);
4993 assert(args.size() == blockArgTypes.size());
4994 OpBuilder::InsertionGuard guard(builder);
4995 Region *warpRegion = result.addRegion();
4996 Block *block = builder.createBlock(warpRegion);
4997 for (auto it : llvm::zip(blockArgTypes, args))
4998 block->addArgument(std::get<0>(it), std::get<1>(it).getLoc());
4999}
5000
5001/// Helper that checks whether the distributed vector type is consistent with
5002/// the expanded type and the distribution (warp) size.
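/// For example, with warpSize = 32 an expanded vector<64x4xf32> may be
/// distributed as vector<2x4xf32>: exactly one dimension is divided by the
/// warp size.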
5003static LogicalResult verifyDistributedType(Type expanded, Type distributed,
5004 int64_t warpSize, Operation *op) {
5005 // If the types match, there is no distribution.
5006 if (expanded == distributed)
5007 return success();
5008 auto expandedVecType = expanded.dyn_cast<VectorType>();
5009 auto distributedVecType = distributed.dyn_cast<VectorType>();
5010 if (!expandedVecType || !distributedVecType)
5011 return op->emitOpError("expected vector type for distributed operands.");
5012 if (expandedVecType.getRank() != distributedVecType.getRank() ||
5013 expandedVecType.getElementType() != distributedVecType.getElementType())
5014 return op->emitOpError(
5015 "expected distributed vectors to have same rank and element type.");
5016 bool foundDistributedDim = false;
5017 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
5018 if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
5019 continue;
5020 if (expandedVecType.getDimSize(i) ==
5021 distributedVecType.getDimSize(i) * warpSize) {
5022 if (foundDistributedDim)
5023 return op->emitOpError()
5024 << "expected only one dimension to be distributed from "
5025 << expandedVecType << " to " << distributedVecType;
5026 foundDistributedDim = true;
5027 continue;
5028 }
5029 return op->emitOpError() << "incompatible distribution dimensions from "
5030 << expandedVecType << " to " << distributedVecType;
5031 }
5032 return success();
5033}
5034
5035LogicalResult WarpExecuteOnLane0Op::verify() {
5036 if (getArgs().size() != getWarpRegion().getNumArguments())
5037 return emitOpError(
5038 "expected same number op arguments and block arguments.");
5039 auto yield =
5040 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
5041 if (yield.getNumOperands() != getNumResults())
5042 return emitOpError(
5043 "expected same number of yield operands and return values.");
5044 int64_t warpSize = getWarpSize();
5045 for (auto it : llvm::zip(getWarpRegion().getArguments(), getArgs())) {
5046 if (failed(verifyDistributedType(std::get<0>(it).getType(),
5047 std::get<1>(it).getType(), warpSize,
5048 getOperation())))
5049 return failure();
5050 }
5051 for (auto it : llvm::zip(yield.getOperands(), getResults())) {
5052 if (failed(verifyDistributedType(std::get<0>(it).getType(),
5053 std::get<1>(it).getType(), warpSize,
5054 getOperation())))
5055 return failure();
5056 }
5057 return success();
5058}
5059
5060bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
5061 return succeeded(
5062 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
5063}
5064
5065Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
5066 CombiningKind kind, Value v1, Value v2) {
5067 Type t1 = getElementTypeOrSelf(v1.getType());
5068 Type t2 = getElementTypeOrSelf(v2.getType());
5069 switch (kind) {
5070 case CombiningKind::ADD:
5071 if (t1.isIntOrIndex() && t2.isIntOrIndex())
5072 return b.createOrFold<arith::AddIOp>(loc, v1, v2);
5073 else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5074 return b.createOrFold<arith::AddFOp>(loc, v1, v2);
5075 llvm_unreachable("invalid value types for ADD reduction");
5076 case CombiningKind::AND:
5077 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5078 return b.createOrFold<arith::AndIOp>(loc, v1, v2);
5079 case CombiningKind::MAXF:
5080 assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5081 "expected float values");
5082 return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
5083 case CombiningKind::MINF:
5084 assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5085 "expected float values");
5086 return b.createOrFold<arith::MinFOp>(loc, v1, v2);
5087 case CombiningKind::MAXSI:
5088 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5089 return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
5090 case CombiningKind::MINSI:
5091 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5092 return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
5093 case CombiningKind::MAXUI:
5094 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5095 return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
5096 case CombiningKind::MINUI:
5097 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5098 return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
5099 case CombiningKind::MUL:
5100 if (t1.isIntOrIndex() && t2.isIntOrIndex())
5101 return b.createOrFold<arith::MulIOp>(loc, v1, v2);
5102 else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5103 return b.createOrFold<arith::MulFOp>(loc, v1, v2);
5104 llvm_unreachable("invalid value types for MUL reduction");
5105 case CombiningKind::OR:
5106 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5107 return b.createOrFold<arith::OrIOp>(loc, v1, v2);
5108 case CombiningKind::XOR:
5109 assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5110 return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
5111 };
5112 llvm_unreachable("unknown CombiningKind");
5113}
5114
5115//===----------------------------------------------------------------------===//
5116// TableGen'd op method definitions
5117//===----------------------------------------------------------------------===//
5118
5119#define GET_ATTRDEF_CLASSES
5120#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
5121
5122#define GET_OP_CLASSES
5123#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

/build/llvm-toolchain-snapshot-16~++20221003111214+1fa2019828ca/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These classes are used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defining derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>());
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
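A hedged sketch of calling these overloads from custom printer code; `MyAttrType` is a hypothetical attribute class with a `print(AsmPrinter &)` method, so the first overload is selected, while an attribute without one would fall through to the generic `<<` overload above:

void printExampleAttr(AsmPrinter &p, MyAttrType myAttr) {
  // Prints the alias if one exists, otherwise calls myAttr.print(p).
  p.printStrippedAttrOrType(myAttr);
}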
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179  /// with '@'. The reference is surrounded with double quotes and escaped if
180  /// it has any special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
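For illustration of the quoting rule above, a short sketch with the expected output in comments:

p.printSymbolName("foo");      // prints: @foo
p.printSymbolName("foo bar");  // prints: @"foo bar" (quoted: contains a space)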
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
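As a rough sketch of the output of the two helpers above (`op` is a hypothetical Operation *): `printArrowTypeList` on a single non-function type prints ` -> i32`, multiple types are parenthesized, and `printFunctionalType` combines the parenthesized input list with the arrow list:

// Sketch: with operand types {i32, i32} and result type {i32} this prints
//   (i32, i32) -> i32
p.printFunctionalType(op->getOperandTypes(), op->getResultTypes());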
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217  /// virtual methods of this class must be overridden.
218 AsmPrinter() {}
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asm printer hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a loc(...) specifier if printing debug info is enabled.
331 virtual void printOptionalLocationSpecifier(Location loc) = 0;
332
333 /// Print a newline and indent the printer to the start of the current
334 /// operation.
335 virtual void printNewline() = 0;
336
337 /// Print a block argument in the usual format of:
338 /// %ssaName : type {attr1=42} loc("here")
339 /// where location printing is controlled by the standard internal option.
340  /// You may pass omitType=true to skip printing the type, and pass an empty
341  /// attribute list if you do not need attributes printed.
342 virtual void printRegionArgument(BlockArgument arg,
343 ArrayRef<NamedAttribute> argAttrs = {},
344 bool omitType = false) = 0;
345
346 /// Print implementations for various things an operation contains.
347 virtual void printOperand(Value value) = 0;
348 virtual void printOperand(Value value, raw_ostream &os) = 0;
349
350 /// Print a comma separated list of operands.
351 template <typename ContainerType>
352 void printOperands(const ContainerType &container) {
353 printOperands(container.begin(), container.end());
354 }
355
356 /// Print a comma separated list of operands.
357 template <typename IteratorType>
358 void printOperands(IteratorType it, IteratorType end) {
359 llvm::interleaveComma(llvm::make_range(it, end), getStream(),
360 [this](Value value) { printOperand(value); });
361 }
362
363 /// Print the given successor.
364 virtual void printSuccessor(Block *successor) = 0;
365
366 /// Print the successor and its operands.
367 virtual void printSuccessorAndUseList(Block *successor,
368 ValueRange succOperands) = 0;
369
370 /// If the specified operation has attributes, print out an attribute
371 /// dictionary with their values. elidedAttrs allows the client to ignore
372 /// specific well known attributes, commonly used if the attribute value is
373 /// printed some other way (like as a fixed operand).
374 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
375 ArrayRef<StringRef> elidedAttrs = {}) = 0;
376
377 /// If the specified operation has attributes, print out an attribute
378 /// dictionary prefixed with 'attributes'.
379 virtual void
380 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
381 ArrayRef<StringRef> elidedAttrs = {}) = 0;
382
383 /// Print the entire operation with the default generic assembly form.
384  /// If `printOpName` is true (the default), the operation name is printed;
385  /// otherwise it is omitted and printing starts with the operand list.
386 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
387
388 /// Prints a region.
389  /// If 'printEntryBlockArgs' is false, the arguments of the
390  /// block are not printed. If 'printBlockTerminators' is false, the terminator
391  /// operation of the block is not printed. If 'printEmptyBlock' is true, then
392  /// the block header is printed even if the block is empty.
393 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
394 bool printBlockTerminators = true,
395 bool printEmptyBlock = false) = 0;
396
397 /// Renumber the arguments for the specified region to the same names as the
398 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
399 /// operations. If any entry in namesToUse is null, the corresponding
400 /// argument name is left alone.
401 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
402
403 /// Prints an affine map of SSA ids, where SSA id names are used in place
404 /// of dims/symbols.
405 /// Operand values must come from single-result sources, and be valid
406 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
407 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
408 ValueRange operands) = 0;
409
410 /// Prints an affine expression of SSA ids with SSA id names used instead of
411 /// dims and symbols.
412 /// Operand values must come from single-result sources, and be valid
413 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
414 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
415 ValueRange symOperands) = 0;
416
417 /// Print the complete type of an operation in functional form.
418 void printFunctionalType(Operation *op);
419 using AsmPrinter::printFunctionalType;
420};
421
422// Make the implementations convenient to use.
423inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
424 p.printOperand(value);
425 return p;
426}
427
428template <typename T,
429 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
430 !std::is_convertible<T &, Value &>::value,
431 T> * = nullptr>
432inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
433 p.printOperands(values);
434 return p;
435}
436
437inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
438 p.printSuccessor(value);
439 return p;
440}
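A minimal sketch of a custom print method built from the hooks above; `MyOp` is hypothetical, and in practice such printers are often generated from declarative ODS assembly formats:

// Prints: " %operands {attr-dict} : (operand types) -> (result types)".
// The op name itself is printed by the framework before print() is called.
void MyOp::print(OpAsmPrinter &p) {
  p << ' ' << getOperation()->getOperands();
  p.printOptionalAttrDict(getOperation()->getAttrs());
  p << " : ";
  p.printFunctionalType(getOperation());
}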
441
442//===----------------------------------------------------------------------===//
443// AsmParser
444//===----------------------------------------------------------------------===//
445
446/// This base class exposes generic asm parser hooks, usable across the various
447/// derived parsers.
448class AsmParser {
449public:
450 AsmParser() = default;
451 virtual ~AsmParser();
452
453 MLIRContext *getContext() const;
454
455 /// Return the location of the original name token.
456 virtual SMLoc getNameLoc() const = 0;
457
458 //===--------------------------------------------------------------------===//
459 // Utilities
460 //===--------------------------------------------------------------------===//
461
462 /// Emit a diagnostic at the specified location and return failure.
463 virtual InFlightDiagnostic emitError(SMLoc loc,
464 const Twine &message = {}) = 0;
465
466  /// Return a builder which provides useful access to the MLIRContext and
467  /// global objects such as types and attributes.
468 virtual Builder &getBuilder() const = 0;
469
470 /// Get the location of the next token and store it into the argument. This
471 /// always succeeds.
472 virtual SMLoc getCurrentLocation() = 0;
473 ParseResult getCurrentLocation(SMLoc *loc) {
474 *loc = getCurrentLocation();
475 return success();
476 }
477
478 /// Re-encode the given source location as an MLIR location and return it.
479 /// Note: This method should only be used when a `Location` is necessary, as
480 /// the encoding process is not efficient.
481 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
482
483 //===--------------------------------------------------------------------===//
484 // Token Parsing
485 //===--------------------------------------------------------------------===//
486
487 /// Parse a '->' token.
488 virtual ParseResult parseArrow() = 0;
489
490  /// Parse a '->' token if present.
491 virtual ParseResult parseOptionalArrow() = 0;
492
493 /// Parse a `{` token.
494 virtual ParseResult parseLBrace() = 0;
495
496 /// Parse a `{` token if present.
497 virtual ParseResult parseOptionalLBrace() = 0;
498
499 /// Parse a `}` token.
500 virtual ParseResult parseRBrace() = 0;
501
502 /// Parse a `}` token if present.
503 virtual ParseResult parseOptionalRBrace() = 0;
504
505 /// Parse a `:` token.
506 virtual ParseResult parseColon() = 0;
507
508 /// Parse a `:` token if present.
509 virtual ParseResult parseOptionalColon() = 0;
510
511 /// Parse a `,` token.
512 virtual ParseResult parseComma() = 0;
513
514 /// Parse a `,` token if present.
515 virtual ParseResult parseOptionalComma() = 0;
516
517 /// Parse a `=` token.
518 virtual ParseResult parseEqual() = 0;
519
520 /// Parse a `=` token if present.
521 virtual ParseResult parseOptionalEqual() = 0;
522
523 /// Parse a '<' token.
524 virtual ParseResult parseLess() = 0;
525
526 /// Parse a '<' token if present.
527 virtual ParseResult parseOptionalLess() = 0;
528
529 /// Parse a '>' token.
530 virtual ParseResult parseGreater() = 0;
531
532 /// Parse a '>' token if present.
533 virtual ParseResult parseOptionalGreater() = 0;
534
535 /// Parse a '?' token.
536 virtual ParseResult parseQuestion() = 0;
537
538 /// Parse a '?' token if present.
539 virtual ParseResult parseOptionalQuestion() = 0;
540
541 /// Parse a '+' token.
542 virtual ParseResult parsePlus() = 0;
543
544 /// Parse a '+' token if present.
545 virtual ParseResult parseOptionalPlus() = 0;
546
547 /// Parse a '*' token.
548 virtual ParseResult parseStar() = 0;
549
550 /// Parse a '*' token if present.
551 virtual ParseResult parseOptionalStar() = 0;
552
553 /// Parse a '|' token.
554 virtual ParseResult parseVerticalBar() = 0;
555
556 /// Parse a '|' token if present.
557 virtual ParseResult parseOptionalVerticalBar() = 0;
558
559 /// Parse a quoted string token.
560 ParseResult parseString(std::string *string) {
561 auto loc = getCurrentLocation();
562 if (parseOptionalString(string))
563 return emitError(loc, "expected string");
564 return success();
565 }
566
567 /// Parse a quoted string token if present.
568 virtual ParseResult parseOptionalString(std::string *string) = 0;
569
570 /// Parse a `(` token.
571 virtual ParseResult parseLParen() = 0;
572
573 /// Parse a `(` token if present.
574 virtual ParseResult parseOptionalLParen() = 0;
575
576 /// Parse a `)` token.
577 virtual ParseResult parseRParen() = 0;
578
579 /// Parse a `)` token if present.
580 virtual ParseResult parseOptionalRParen() = 0;
581
582 /// Parse a `[` token.
583 virtual ParseResult parseLSquare() = 0;
584
585 /// Parse a `[` token if present.
586 virtual ParseResult parseOptionalLSquare() = 0;
587
588 /// Parse a `]` token.
589 virtual ParseResult parseRSquare() = 0;
590
591 /// Parse a `]` token if present.
592 virtual ParseResult parseOptionalRSquare() = 0;
593
594 /// Parse a `...` token.
595 virtual ParseResult parseEllipsis() = 0;
596
597  /// Parse a `...` token if present.
598 virtual ParseResult parseOptionalEllipsis() = 0;
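A hedged sketch combining the token hooks above to parse an optional `<...>` suffix; `parseOptionalAngleSuffix` is a hypothetical helper, not part of this header:

ParseResult parseOptionalAngleSuffix(AsmParser &parser) {
  // The Optional variants succeed only when the token is actually present
  // and emit no diagnostic otherwise.
  if (parser.parseOptionalLess())
    return success(); // no '<', nothing to do
  // Once '<' has been consumed, the rest of the suffix is mandatory.
  if (parser.parseEllipsis() || parser.parseGreater())
    return failure();
  return success();
}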
599
600 /// Parse a floating point value from the stream.
601 virtual ParseResult parseFloat(double &result) = 0;
602
603 /// Parse an integer value from the stream.
604 template <typename IntT>
605 ParseResult parseInteger(IntT &result) {
606 auto loc = getCurrentLocation();
607 OptionalParseResult parseResult = parseOptionalInteger(result);
4. Calling 'AsmParser::parseOptionalInteger'
8. Returning from 'AsmParser::parseOptionalInteger'
608 if (!parseResult.has_value())
9. Taking true branch
609 return emitError(loc, "expected integer value");
10. Returning without writing to 'result'
610 return *parseResult;
611 }
612
613 /// Parse an optional integer value from the stream.
614 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
615
616 template <typename IntT>
617 OptionalParseResult parseOptionalInteger(IntT &result) {
618 auto loc = getCurrentLocation();
619
620 // Parse the unsigned variant.
621 APInt uintResult;
622 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
623 if (!parseResult.has_value() || failed(*parseResult))
5. Assuming the condition is true
6. Taking true branch
624 return parseResult;
7. Returning without writing to 'result'
625
626 // Try to convert to the provided integer type. sextOrTrunc is correct even
627 // for unsigned types because parseOptionalInteger ensures the sign bit is
628 // zero for non-negated integers.
629 result =
630        (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
631 if (APInt(uintResult.getBitWidth(), result) != uintResult)
632 return emitError(loc, "integer value too large");
633 return success();
634 }
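Note how this connects to the warning this report is about: on the failure path, both parseInteger and parseOptionalInteger return without ever writing to `result`, so a caller that reads the output variable before checking the ParseResult reads an uninitialized value. A hedged sketch of the hazard and the safe pattern (`use` is a hypothetical consumer):

// Hazardous: if no integer token is present, `count` is never written, and
// the later read is exactly the uninitialized use the analyzer reports.
int64_t count;
(void)parser.parseInteger(count);
use(count);

// Safe: check the ParseResult (or pre-initialize the variable) first.
int64_t safeCount = 0;
if (failed(parser.parseInteger(safeCount)))
  return failure();
use(safeCount);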
635
636 /// These are the supported delimiters around operand lists and region
637 /// argument lists, used by parseOperandList.
638 enum class Delimiter {
639 /// Zero or more operands with no delimiters.
640 None,
641 /// Parens surrounding zero or more operands.
642 Paren,
643 /// Square brackets surrounding zero or more operands.
644 Square,
645 /// <> brackets surrounding zero or more operands.
646 LessGreater,
647 /// {} brackets surrounding zero or more operands.
648 Braces,
649 /// Parens supporting zero or more operands, or nothing.
650 OptionalParen,
651    /// Square brackets supporting zero or more operands, or nothing.
652    OptionalSquare,
653    /// <> brackets supporting zero or more operands, or nothing.
654    OptionalLessGreater,
655 /// {} brackets surrounding zero or more operands, or nothing.
656 OptionalBraces,
657 };
658
659 /// Parse a list of comma-separated items with an optional delimiter. If a
660 /// delimiter is provided, then an empty list is allowed. If not, then at
661 /// least one element will be parsed.
662 ///
663 /// contextMessage is an optional message appended to "expected '('" sorts of
664  /// diagnostics when parsing the delimiters.
665 virtual ParseResult
666 parseCommaSeparatedList(Delimiter delimiter,
667 function_ref<ParseResult()> parseElementFn,
668 StringRef contextMessage = StringRef()) = 0;
669
670 /// Parse a comma separated list of elements that must have at least one entry
671 /// in it.
672 ParseResult
673 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
674 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
675 }
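A sketch of driving the delimited-list parser from a custom op parser; the `values` vector and the surrounding function are hypothetical:

SmallVector<int64_t> values;
auto parseElement = [&]() -> ParseResult {
  int64_t value = 0;
  if (parser.parseInteger(value))
    return failure();
  values.push_back(value);
  return success();
};
// Parses "[1, 2, 3]"-style lists; an empty "[]" is allowed because a
// delimiter was provided.
if (parser.parseCommaSeparatedList(AsmParser::Delimiter::Square, parseElement,
                                   " in element list"))
  return failure();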
676
677 //===--------------------------------------------------------------------===//
678 // Keyword Parsing
679 //===--------------------------------------------------------------------===//
680
681  /// This is a StringSwitch-like class that is useful for parsing expected
682  /// keywords. On construction, it invokes `parseKeyword` and processes each
683  /// of the provided case statements until a match is hit. The provided
684  /// `ResultT` must be assignable from `failure()`.
685 template <typename ResultT = ParseResult>
686 class KeywordSwitch {
687 public:
688 KeywordSwitch(AsmParser &parser)
689 : parser(parser), loc(parser.getCurrentLocation()) {
690 if (failed(parser.parseKeywordOrCompletion(&keyword)))
691 result = failure();
692 }
693
694 /// Case that uses the provided value when true.
695 KeywordSwitch &Case(StringLiteral str, ResultT value) {
696 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
697 }
698 KeywordSwitch &Default(ResultT value) {
699 return Default([&](StringRef, SMLoc) { return std::move(value); });
700 }
701 /// Case that invokes the provided functor when true. The parameters passed
702 /// to the functor are the keyword, and the location of the keyword (in case
703 /// any errors need to be emitted).
704 template <typename FnT>
705 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
706 Case(StringLiteral str, FnT &&fn) {
707 if (result)
708 return *this;
709
710 // If the word was empty, record this as a completion.
711 if (keyword.empty())
712 parser.codeCompleteExpectedTokens(str);
713 else if (keyword == str)
714 result.emplace(std::move(fn(keyword, loc)));
715 return *this;
716 }
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Default(FnT &&fn) {
720 if (!result)
721 result.emplace(fn(keyword, loc));
722 return *this;
723 }
724
725    /// Returns true if this switch already has a value.
726 bool hasValue() const { return result.has_value(); }
727
728 /// Return the result of the switch.
729 [[nodiscard]] operator ResultT() {
730 if (!result)
731 return parser.emitError(loc, "unexpected keyword: ") << keyword;
732 return std::move(*result);
733 }
734
735 private:
736 /// The parser used to construct this switch.
737 AsmParser &parser;
738
739 /// The location of the keyword, used to emit errors as necessary.
740 SMLoc loc;
741
742 /// The parsed keyword itself.
743 StringRef keyword;
744
745 /// The result of the switch statement or none if currently unknown.
746 Optional<ResultT> result;
747 };
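A sketch of using KeywordSwitch to map a keyword onto a result value; `MyMode` is a hypothetical enum class, and FailureOr satisfies the "assignable from failure()" requirement above:

FailureOr<MyMode> mode = AsmParser::KeywordSwitch<FailureOr<MyMode>>(parser)
                             .Case("none", MyMode::None)
                             .Case("all", MyMode::All);
// With no Default case, the conversion operator above emits
// "unexpected keyword: ..." for an unmatched keyword.
if (failed(mode))
  return failure();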
748
749 /// Parse a given keyword.
750 ParseResult parseKeyword(StringRef keyword) {
751 return parseKeyword(keyword, "");
752 }
753 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
754
755 /// Parse a keyword into 'keyword'.
756 ParseResult parseKeyword(StringRef *keyword) {
757 auto loc = getCurrentLocation();
758 if (parseOptionalKeyword(keyword))
759 return emitError(loc, "expected valid keyword");
760 return success();
761 }
762
763 /// Parse the given keyword if present.
764 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
765
766 /// Parse a keyword, if present, into 'keyword'.
767 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
768
769  /// Parse a keyword, if present, and if it is one of the 'allowedValues',
770  /// into 'keyword'.
771 virtual ParseResult
772 parseOptionalKeyword(StringRef *keyword,
773 ArrayRef<StringRef> allowedValues) = 0;
774
775 /// Parse a keyword or a quoted string.
776 ParseResult parseKeywordOrString(std::string *result) {
777 if (failed(parseOptionalKeywordOrString(result)))
778 return emitError(getCurrentLocation())
779 << "expected valid keyword or string";
780 return success();
781 }
782
783 /// Parse an optional keyword or string.
784 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
785
786 //===--------------------------------------------------------------------===//
787 // Attribute/Type Parsing
788 //===--------------------------------------------------------------------===//
789
790 /// Invoke the `getChecked` method of the given Attribute or Type class, using
791 /// the provided location to emit errors in the case of failure. Note that
792 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
793 /// context parameter.
794 template <typename T, typename... ParamsT>
795 auto getChecked(SMLoc loc, ParamsT &&...params) {
796 return T::getChecked([&] { return emitError(loc); },
797 std::forward<ParamsT>(params)...);
798 }
799 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800 /// errors.
801 template <typename T, typename... ParamsT>
802 auto getChecked(ParamsT &&...params) {
803 return T::getChecked([&] { return emitError(getNameLoc()); },
804 std::forward<ParamsT>(params)...);
805 }
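A sketch of getChecked in a hand-written type parser; `MyType`, `parseMyType`, and `width` are hypothetical. The lambda built by getChecked routes verifier diagnostics to the given source location instead of asserting:

MyType parseMyType(AsmParser &parser, SMLoc loc, unsigned width) {
  // Returns a null MyType and emits an error at `loc` if verification fails.
  // Note the context is passed explicitly, per the comment above.
  return parser.getChecked<MyType>(loc, parser.getContext(), width);
}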
806
807 //===--------------------------------------------------------------------===//
808 // Attribute Parsing
809 //===--------------------------------------------------------------------===//
810
811 /// Parse an arbitrary attribute of a given type and return it in result.
812 virtual ParseRes