Bug Summary

File: /build/source/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Warning: line 5547, column 23
1st function call argument is an uninitialized value
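
A minimal reproduction of this diagnostic class (illustrative only; the
flagged call at line 5547 lies outside the excerpt below) looks like:

    #include <cstdio>

    static void use(int v) { std::printf("%d\n", v); }

    int main() {
      int x;            // never written on the feasible path
      bool cond = false;
      if (cond)
        x = 42;
      use(x);           // analyzer: 1st function call argument is an
                        // uninitialized value
      return 0;
    }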

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name VectorOps.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16 -I tools/mlir/lib/Dialect/Vector/IR -I /build/source/mlir/lib/Dialect/Vector/IR -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1671487667 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-12-20-010714-16201-1 -x c++ /build/source/mlir/lib/Dialect/Vector/IR/VectorOps.cpp

/build/source/mlir/lib/Dialect/Vector/IR/VectorOps.cpp

1//===- VectorOps.cpp - MLIR Vector Dialect Operations ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements convenience types for working with super-vectorization
10// operations, in particular super-vector loads and stores.
11//
12//===----------------------------------------------------------------------===//
13
14#include "mlir/Dialect/Vector/IR/VectorOps.h"
15
16#include "mlir/Dialect/Arith/IR/Arith.h"
17#include "mlir/Dialect/Arith/Utils/Utils.h"
18#include "mlir/Dialect/MemRef/IR/MemRef.h"
19#include "mlir/Dialect/Tensor/IR/Tensor.h"
20#include "mlir/Dialect/Utils/IndexingUtils.h"
21#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
22#include "mlir/IR/AffineExpr.h"
23#include "mlir/IR/AffineMap.h"
24#include "mlir/IR/BlockAndValueMapping.h"
25#include "mlir/IR/Builders.h"
26#include "mlir/IR/BuiltinAttributes.h"
27#include "mlir/IR/BuiltinOps.h"
28#include "mlir/IR/BuiltinTypes.h"
29#include "mlir/IR/DialectImplementation.h"
30#include "mlir/IR/OpImplementation.h"
31#include "mlir/IR/PatternMatch.h"
32#include "mlir/IR/TypeUtilities.h"
33#include "mlir/Support/LLVM.h"
34#include "llvm/ADT/ArrayRef.h"
35#include "llvm/ADT/STLExtras.h"
36#include "llvm/ADT/SmallVector.h"
37#include "llvm/ADT/StringSet.h"
38#include "llvm/ADT/TypeSwitch.h"
39#include "llvm/ADT/bit.h"
40
41#include <cassert>
42#include <cstdint>
43#include <numeric>
44
45#include "mlir/Dialect/Vector/IR/VectorOpsDialect.cpp.inc"
46// Pull in all enum type and utility function definitions.
47#include "mlir/Dialect/Vector/IR/VectorOpsEnums.cpp.inc"
48
49using namespace mlir;
50using namespace mlir::vector;
51
52/// Helper enum to classify mask value.
53enum class MaskFormat {
54 AllTrue = 0,
55 AllFalse = 1,
56 Unknown = 2,
57};
58
59/// Helper method to classify a mask value. Currently, the method
60/// looks "under the hood" of a constant value with dense attributes
61/// and a constant mask operation (since the client may be called at
62/// various stages during progressive lowering).
63static MaskFormat getMaskFormat(Value mask) {
64 if (auto c = mask.getDefiningOp<arith::ConstantOp>()) {
65 // Inspect constant dense values. We count up for bits that
66 // are set, count down for bits that are cleared, and bail
67 // when a mix is detected.
68 if (auto denseElts = c.getValue().dyn_cast<DenseIntElementsAttr>()) {
69 int64_t val = 0;
70 for (bool b : denseElts.getValues<bool>())
71 if (b && val >= 0)
72 val++;
73 else if (!b && val <= 0)
74 val--;
75 else
76 return MaskFormat::Unknown;
77 if (val > 0)
78 return MaskFormat::AllTrue;
79 if (val < 0)
80 return MaskFormat::AllFalse;
81 }
82 } else if (auto m = mask.getDefiningOp<ConstantMaskOp>()) {
83 // Inspect constant mask index. If the index exceeds the
84 // dimension size, all bits are set. If the index is zero
85 // or less, no bits are set.
86 ArrayAttr masks = m.getMaskDimSizes();
87 auto shape = m.getType().getShape();
88 bool allTrue = true;
89 bool allFalse = true;
90 for (auto [maskIdx, dimSize] : llvm::zip_equal(masks, shape)) {
91 int64_t i = maskIdx.cast<IntegerAttr>().getInt();
92 if (i < dimSize)
93 allTrue = false;
94 if (i > 0)
95 allFalse = false;
96 }
97 if (allTrue)
98 return MaskFormat::AllTrue;
99 if (allFalse)
100 return MaskFormat::AllFalse;
101 }
102 return MaskFormat::Unknown;
103}
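
The counting trick above is easy to misread: `val` only ever moves away from
zero in one direction, so the first set bit after a cleared bit (or vice
versa) bails out immediately. A standalone C++ sketch of the same logic
(illustrative only; `classify` and the plain std::vector signature are
hypothetical, not MLIR API):

    #include <vector>

    enum class MaskFormat { AllTrue, AllFalse, Unknown };

    // Count up for set bits, down for cleared bits; a mix flips the sign
    // precondition and is detected on the spot.
    static MaskFormat classify(const std::vector<bool> &bits) {
      long val = 0;
      for (bool b : bits) {
        if (b && val >= 0)
          ++val;
        else if (!b && val <= 0)
          --val;
        else
          return MaskFormat::Unknown;
      }
      if (val > 0) return MaskFormat::AllTrue;
      if (val < 0) return MaskFormat::AllFalse;
      return MaskFormat::Unknown; // empty mask
    }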
104
105/// Default callback to build a region with a 'vector.yield' terminator with no
106/// arguments.
107void mlir::vector::buildTerminatedBody(OpBuilder &builder, Location loc) {
108 builder.create<vector::YieldOp>(loc);
109}
110
111// Helper for verifying combining kinds in contractions and reductions.
112static bool isSupportedCombiningKind(CombiningKind combiningKind,
113 Type elementType) {
114 switch (combiningKind) {
115 case CombiningKind::ADD:
116 case CombiningKind::MUL:
117 return elementType.isIntOrIndexOrFloat();
118 case CombiningKind::MINUI:
119 case CombiningKind::MINSI:
120 case CombiningKind::MAXUI:
121 case CombiningKind::MAXSI:
122 case CombiningKind::AND:
123 case CombiningKind::OR:
124 case CombiningKind::XOR:
125 return elementType.isIntOrIndex();
126 case CombiningKind::MINF:
127 case CombiningKind::MAXF:
128 return elementType.isa<FloatType>();
129 }
130 return false;
131}
132
133/// Return true if the last dimension of the MemRefType has unit stride. Also
134/// return true for memrefs with no strides.
135bool mlir::vector::isLastMemrefDimUnitStride(MemRefType type) {
136 int64_t offset;
137 SmallVector<int64_t> strides;
138 auto successStrides = getStridesAndOffset(type, strides, offset);
139 return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
140}
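
For context, with the identity (row-major) layout the strides are suffix
products of the shape, so the last stride is always 1 and the check above
succeeds. A minimal sketch of that computation, assuming static shapes only
(unlike getStridesAndOffset, which also handles dynamic strides and offsets;
`rowMajorStrides` is a hypothetical helper):

    #include <cstdint>
    #include <vector>

    // Row-major strides: stride[i] = product of shape[i+1 .. n-1].
    static std::vector<int64_t>
    rowMajorStrides(const std::vector<int64_t> &shape) {
      std::vector<int64_t> strides(shape.size(), 1);
      for (int64_t i = static_cast<int64_t>(shape.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * shape[i + 1];
      return strides; // e.g. shape {2, 3, 4} -> strides {12, 4, 1}
    }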
141
142AffineMap mlir::vector::getTransferMinorIdentityMap(ShapedType shapedType,
143 VectorType vectorType) {
144 int64_t elementVectorRank = 0;
145 VectorType elementVectorType =
146 shapedType.getElementType().dyn_cast<VectorType>();
147 if (elementVectorType)
148 elementVectorRank += elementVectorType.getRank();
149 // 0-d transfers are to/from tensor<t>/memref<t> and vector<1xt>.
150 // TODO: replace once we have 0-d vectors.
151 if (shapedType.getRank() == 0 &&
152 vectorType.getShape() == ArrayRef<int64_t>{1})
153 return AffineMap::get(
154 /*numDims=*/0, /*numSymbols=*/0,
155 getAffineConstantExpr(0, shapedType.getContext()));
156 return AffineMap::getMinorIdentityMap(
157 shapedType.getRank(), vectorType.getRank() - elementVectorRank,
158 shapedType.getContext());
159}
160
161bool mlir::vector::checkSameValueRAW(vector::TransferWriteOp defWrite,
162 vector::TransferReadOp read) {
163 return !defWrite.hasOutOfBoundsDim() && !defWrite.getMask() &&
164 !read.getMask() && defWrite.getIndices() == read.getIndices() &&
165 defWrite.getVectorType() == read.getVectorType() &&
166 defWrite.getPermutationMap() == read.getPermutationMap();
167}
168
169bool mlir::vector::checkSameValueWAW(vector::TransferWriteOp write,
170 vector::TransferWriteOp priorWrite) {
171 return priorWrite.getIndices() == write.getIndices() &&
172 priorWrite.getMask() == write.getMask() &&
173 priorWrite.getVectorType() == write.getVectorType() &&
174 priorWrite.getPermutationMap() == write.getPermutationMap();
175}
176
177bool mlir::vector::isDisjointTransferIndices(
178 VectorTransferOpInterface transferA, VectorTransferOpInterface transferB) {
179 // For simplicity, only look at transfers of the same type.
180 if (transferA.getVectorType() != transferB.getVectorType())
181 return false;
182 unsigned rankOffset = transferA.getLeadingShapedRank();
183 for (unsigned i = 0, e = transferA.indices().size(); i < e; i++) {
184 auto indexA = transferA.indices()[i].getDefiningOp<arith::ConstantOp>();
185 auto indexB = transferB.indices()[i].getDefiningOp<arith::ConstantOp>();
186 // If any of the indices are dynamic we cannot prove anything.
187 if (!indexA || !indexB)
188 continue;
189
190 if (i < rankOffset) {
191 // For leading dimensions, if we can prove that the indices are
192 // different, we know we are accessing disjoint slices.
193 if (indexA.getValue().cast<IntegerAttr>().getInt() !=
194 indexB.getValue().cast<IntegerAttr>().getInt())
195 return true;
196 } else {
197 // For this dimension, we slice a part of the memref, so we need to make
198 // sure the intervals accessed don't overlap.
199 int64_t distance =
200 std::abs(indexA.getValue().cast<IntegerAttr>().getInt() -
201 indexB.getValue().cast<IntegerAttr>().getInt());
202 if (distance >= transferA.getVectorType().getDimSize(i - rankOffset))
203 return true;
204 }
205 }
206 return false;
207}
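
The non-leading-dimension case reduces to interval disjointness: two windows
of `len` contiguous elements at constant start indices overlap unless the
starts are at least `len` apart. A self-contained sketch of that test
(hypothetical helper, not MLIR API):

    #include <cstdint>

    // Windows [a, a+len) and [b, b+len) are disjoint iff |a - b| >= len.
    static bool windowsDisjoint(int64_t a, int64_t b, int64_t len) {
      int64_t distance = a > b ? a - b : b - a;
      return distance >= len;
    }
    // windowsDisjoint(0, 4, 4) == true  -- [0,4) vs [4,8)
    // windowsDisjoint(0, 3, 4) == false -- [0,4) vs [3,7) overlap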
208
209bool mlir::vector::isDisjointTransferSet(VectorTransferOpInterface transferA,
210 VectorTransferOpInterface transferB) {
211 if (transferA.source() != transferB.source())
212 return false;
213 return isDisjointTransferIndices(transferA, transferB);
214}
215
216// Helper to iterate over n-D vector slice elements. Calculate the next
217// `position` in the n-D vector of size `shape`, applying an offset `offsets`.
218// Modifies the `position` in place. Returns a failure when `position` becomes
219// the end position.
220static LogicalResult incSlicePosition(MutableArrayRef<int64_t> position,
221 ArrayRef<int64_t> shape,
222 ArrayRef<int64_t> offsets) {
223 for (auto [posInDim, dimSize, offsetInDim] :
224 llvm::reverse(llvm::zip_equal(position, shape, offsets))) {
225 ++posInDim;
226 if (posInDim < dimSize + offsetInDim)
227 return success();
228
229 // Carry the overflow to the next loop iteration.
230 posInDim = offsetInDim;
231 }
232
233 return failure();
234}
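
incSlicePosition behaves like an odometer: the last dimension ticks fastest
and a wrap in one dimension carries into the one before it. The same loop as
a standalone C++ sketch (hypothetical `increment` helper over plain
std::vector):

    #include <cstdint>
    #include <vector>

    // Advance `pos` by one within [offsets[i], offsets[i] + shape[i]) per
    // dimension; return false once the position wraps past the end.
    static bool increment(std::vector<int64_t> &pos,
                          const std::vector<int64_t> &shape,
                          const std::vector<int64_t> &offsets) {
      for (int64_t i = static_cast<int64_t>(pos.size()) - 1; i >= 0; --i) {
        if (++pos[i] < shape[i] + offsets[i])
          return true;        // no carry needed
        pos[i] = offsets[i];  // wrap this digit, carry to the next one
      }
      return false;           // wrapped past the last position
    }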
235
236//===----------------------------------------------------------------------===//
237// CombiningKindAttr
238//===----------------------------------------------------------------------===//
239
240namespace mlir {
241namespace vector {
242namespace detail {
243struct BitmaskEnumStorage : public AttributeStorage {
244 using KeyTy = uint64_t;
245
246 BitmaskEnumStorage(KeyTy val) : value(val) {}
247
248 bool operator==(const KeyTy &key) const { return value == key; }
249
250 static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
251 const KeyTy &key) {
252 return new (allocator.allocate<BitmaskEnumStorage>())
253 BitmaskEnumStorage(key);
254 }
255
256 KeyTy value = 0;
257};
258} // namespace detail
259} // namespace vector
260} // namespace mlir
261
262//===----------------------------------------------------------------------===//
263// VectorDialect
264//===----------------------------------------------------------------------===//
265
266void VectorDialect::initialize() {
267 addAttributes<
268#define GET_ATTRDEF_LIST
269#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
270 >();
271
272 addOperations<
273#define GET_OP_LIST
274#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"
275 >();
276}
277
278/// Materialize a single constant operation from a given attribute value with
279/// the desired resultant type.
280Operation *VectorDialect::materializeConstant(OpBuilder &builder,
281 Attribute value, Type type,
282 Location loc) {
283 return builder.create<arith::ConstantOp>(loc, type, value);
284}
285
286IntegerType vector::getVectorSubscriptType(Builder &builder) {
287 return builder.getIntegerType(64);
288}
289
290ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
291 ArrayRef<int64_t> values) {
292 return builder.getI64ArrayAttr(values);
293}
294
295//===----------------------------------------------------------------------===//
296// MultiDimReductionOp
297//===----------------------------------------------------------------------===//
298
299void vector::MultiDimReductionOp::build(OpBuilder &builder,
300 OperationState &result, Value source,
301 Value acc, ArrayRef<bool> reductionMask,
302 CombiningKind kind) {
303 SmallVector<int64_t> reductionDims;
304 for (const auto &en : llvm::enumerate(reductionMask))
305 if (en.value())
306 reductionDims.push_back(en.index());
307 build(builder, result, kind, source, acc,
308 builder.getI64ArrayAttr(reductionDims));
309}
310
311OpFoldResult MultiDimReductionOp::fold(ArrayRef<Attribute> operands) {
312 // Single parallel dim, this is a noop.
313 if (getSourceVectorType().getRank() == 1 && !isReducedDim(0))
314 return getSource();
315 return {};
316}
317
318std::optional<SmallVector<int64_t, 4>>
319MultiDimReductionOp::getShapeForUnroll() {
320 return llvm::to_vector<4>(getSourceVectorType().getShape());
321}
322
323LogicalResult MultiDimReductionOp::verify() {
324 SmallVector<int64_t> targetShape;
325 Type inferredReturnType;
326 for (auto it : llvm::enumerate(getSourceVectorType().getShape()))
327 if (!llvm::any_of(getReductionDims().getValue(), [&](Attribute attr) {
328 return attr.cast<IntegerAttr>().getValue() == it.index();
329 }))
330 targetShape.push_back(it.value());
331 // TODO: update to also allow 0-d vectors when available.
332 if (targetShape.empty())
333 inferredReturnType = getSourceVectorType().getElementType();
334 else
335 inferredReturnType =
336 VectorType::get(targetShape, getSourceVectorType().getElementType());
337 if (getType() != inferredReturnType)
338 return emitOpError() << "destination type " << getType()
339 << " is incompatible with source type "
340 << getSourceVectorType();
341
342 return success();
343}
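
The inferred return type above simply drops the reduced dimensions from the
source shape: e.g. a vector<4x2x8xf32> reduced over dimension 1 yields
vector<4x8xf32>, and reducing every dimension yields a scalar. A minimal
sketch of the shape inference (hypothetical helper, plain C++):

    #include <cstdint>
    #include <set>
    #include <vector>

    // Keep exactly the non-reduced dims; an empty result means scalar.
    static std::vector<int64_t>
    inferredShape(const std::vector<int64_t> &src,
                  const std::set<int64_t> &reducedDims) {
      std::vector<int64_t> out;
      for (int64_t i = 0, e = static_cast<int64_t>(src.size()); i < e; ++i)
        if (!reducedDims.count(i))
          out.push_back(src[i]);
      return out;
    }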
344
345namespace {
346// Only unit dimensions that are being reduced are folded. If the dimension is
347// unit, but not reduced, it is not folded, thereby keeping the output type the
348// same. If not all reduced dimensions have unit size, this
349// transformation does nothing. This is just a generalization of
350// ElideSingleElementReduction for ReduceOp.
351struct ElideUnitDimsInMultiDimReduction
352 : public OpRewritePattern<MultiDimReductionOp> {
353 using OpRewritePattern::OpRewritePattern;
354
355 LogicalResult matchAndRewrite(MultiDimReductionOp reductionOp,
356 PatternRewriter &rewriter) const override {
357 ArrayRef<int64_t> shape = reductionOp.getSourceVectorType().getShape();
358 for (const auto &dim : enumerate(shape)) {
359 if (reductionOp.isReducedDim(dim.index()) && dim.value() != 1)
360 return failure();
361 }
362 Location loc = reductionOp.getLoc();
363 Value acc = reductionOp.getAcc();
364 Value cast;
365 if (reductionOp.getDestType().isa<VectorType>()) {
366 cast = rewriter.create<vector::ShapeCastOp>(
367 loc, reductionOp.getDestType(), reductionOp.getSource());
368 } else {
369 // This means we are reducing all the dimensions, and all reduction
370 // dimensions are of size 1. So a simple extraction would do.
371 cast = rewriter.create<vector::ExtractOp>(
372 loc, reductionOp.getDestType(), reductionOp.getSource(),
373 rewriter.getI64ArrayAttr(SmallVector<int64_t>(shape.size(), 0)));
374 }
375
376 Value result = vector::makeArithReduction(rewriter, loc,
377 reductionOp.getKind(), acc, cast);
378 rewriter.replaceOp(reductionOp, result);
379 return success();
380 }
381};
382} // namespace
383
384void MultiDimReductionOp::getCanonicalizationPatterns(
385 RewritePatternSet &results, MLIRContext *context) {
386 results.add<ElideUnitDimsInMultiDimReduction>(context);
387}
388
389//===----------------------------------------------------------------------===//
390// ReductionOp
391//===----------------------------------------------------------------------===//
392
393void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
394 CombiningKind kind, Value vector) {
395 build(builder, result, kind, vector, /*acc=*/Value());
396}
397
398void vector::ReductionOp::build(OpBuilder &builder, OperationState &result,
399 CombiningKind kind, Value vector, Value acc) {
400 build(builder, result, vector.getType().cast<VectorType>().getElementType(),
401 kind, vector, acc);
402}
403
404LogicalResult ReductionOp::verify() {
405 // Verify for 0-D and 1-D vector.
406 int64_t rank = getVectorType().getRank();
407 if (rank > 1)
408 return emitOpError("unsupported reduction rank: ") << rank;
409
410 // Verify supported reduction kind.
411 Type eltType = getDest().getType();
412 if (!isSupportedCombiningKind(getKind(), eltType))
413 return emitOpError("unsupported reduction type '")
414 << eltType << "' for kind '" << stringifyCombiningKind(getKind())
415 << "'";
416
417 return success();
418}
419
420ParseResult ReductionOp::parse(OpAsmParser &parser, OperationState &result) {
421 SmallVector<OpAsmParser::UnresolvedOperand, 2> operandsInfo;
422 Type redType;
423 Type resType;
424 CombiningKindAttr kindAttr;
425 if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, "kind",
426 result.attributes) ||
427 parser.parseComma() || parser.parseOperandList(operandsInfo) ||
428 parser.parseColonType(redType) ||
429 parser.parseKeywordType("into", resType) ||
430 (!operandsInfo.empty() &&
431 parser.resolveOperand(operandsInfo[0], redType, result.operands)) ||
432 (operandsInfo.size() > 1 &&
433 parser.resolveOperand(operandsInfo[1], resType, result.operands)) ||
434 parser.addTypeToList(resType, result.types))
435 return failure();
436 if (operandsInfo.empty() || operandsInfo.size() > 2)
437 return parser.emitError(parser.getNameLoc(),
438 "unsupported number of operands");
439 return success();
440}
441
442void ReductionOp::print(OpAsmPrinter &p) {
443 p << " ";
444 getKindAttr().print(p);
445 p << ", " << getVector();
446 if (getAcc())
447 p << ", " << getAcc();
448 p << " : " << getVector().getType() << " into " << getDest().getType();
449}
450
451// MaskableOpInterface methods.
452
453/// Returns the mask type expected by this operation.
454Type ReductionOp::getExpectedMaskType() {
455 auto vecType = getVectorType();
456 return vecType.cloneWith(std::nullopt,
457 IntegerType::get(vecType.getContext(), /*width=*/1));
458}
459
460Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
461 OpBuilder &builder, Location loc,
462 Value vector) {
463 switch (op) {
464 case arith::AtomicRMWKind::addf:
465 case arith::AtomicRMWKind::addi:
466 return builder.create<vector::ReductionOp>(vector.getLoc(),
467 CombiningKind::ADD, vector);
468 case arith::AtomicRMWKind::mulf:
469 case arith::AtomicRMWKind::muli:
470 return builder.create<vector::ReductionOp>(vector.getLoc(),
471 CombiningKind::MUL, vector);
472 case arith::AtomicRMWKind::minf:
473 return builder.create<vector::ReductionOp>(vector.getLoc(),
474 CombiningKind::MINF, vector);
475 case arith::AtomicRMWKind::mins:
476 return builder.create<vector::ReductionOp>(vector.getLoc(),
477 CombiningKind::MINSI, vector);
478 case arith::AtomicRMWKind::minu:
479 return builder.create<vector::ReductionOp>(vector.getLoc(),
480 CombiningKind::MINUI, vector);
481 case arith::AtomicRMWKind::maxf:
482 return builder.create<vector::ReductionOp>(vector.getLoc(),
483 CombiningKind::MAXF, vector);
484 case arith::AtomicRMWKind::maxs:
485 return builder.create<vector::ReductionOp>(vector.getLoc(),
486 CombiningKind::MAXSI, vector);
487 case arith::AtomicRMWKind::maxu:
488 return builder.create<vector::ReductionOp>(vector.getLoc(),
489 CombiningKind::MAXUI, vector);
490 case arith::AtomicRMWKind::andi:
491 return builder.create<vector::ReductionOp>(vector.getLoc(),
492 CombiningKind::AND, vector);
493 case arith::AtomicRMWKind::ori:
494 return builder.create<vector::ReductionOp>(vector.getLoc(),
495 CombiningKind::OR, vector);
496 // TODO: Add remaining reduction operations.
497 default:
498 (void)emitOptionalError(loc, "Reduction operation type not supported");
499 break;
500 }
501 return nullptr;
502}
503
504std::optional<SmallVector<int64_t, 4>> ReductionOp::getShapeForUnroll() {
505 return llvm::to_vector<4>(getVectorType().getShape());
506}
507
508namespace {
509struct ElideSingleElementReduction : public OpRewritePattern<ReductionOp> {
510 using OpRewritePattern::OpRewritePattern;
511
512 LogicalResult matchAndRewrite(ReductionOp reductionOp,
513 PatternRewriter &rewriter) const override {
514 if (reductionOp.getVectorType().getDimSize(0) != 1)
515 return failure();
516
517 Location loc = reductionOp.getLoc();
518 Value result = rewriter.create<ExtractOp>(loc, reductionOp.getType(),
519 reductionOp.getVector(),
520 rewriter.getI64ArrayAttr(0));
521
522 if (Value acc = reductionOp.getAcc())
523 result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(),
524 result, acc);
525
526 rewriter.replaceOp(reductionOp, result);
527 return success();
528 }
529};
530} // namespace
531
532void ReductionOp::getCanonicalizationPatterns(RewritePatternSet &results,
533 MLIRContext *context) {
534 results.add<ElideSingleElementReduction>(context);
535}
536
537//===----------------------------------------------------------------------===//
538// ContractionOp
539//===----------------------------------------------------------------------===//
540
541void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
542 Value lhs, Value rhs, Value acc,
543 ArrayRef<ArrayRef<AffineExpr>> indexingExprs,
544 ArrayRef<IteratorType> iteratorTypes) {
545 result.addOperands({lhs, rhs, acc});
546 result.addTypes(acc.getType());
547 result.addAttribute(getIndexingMapsAttrName(result.name),
548 builder.getAffineMapArrayAttr(
549 AffineMap::inferFromExprList(indexingExprs)));
550 result.addAttribute(
551 getIteratorTypesAttrName(result.name),
552 builder.getArrayAttr(llvm::to_vector(llvm::map_range(
553 iteratorTypes, [&](IteratorType t) -> mlir::Attribute {
554 return IteratorTypeAttr::get(builder.getContext(), t);
555 }))));
556}
557
558void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
559 Value lhs, Value rhs, Value acc,
560 ArrayAttr indexingMaps,
561 ArrayAttr iteratorTypes) {
562 build(builder, result, lhs, rhs, acc, indexingMaps, iteratorTypes,
563 ContractionOp::getDefaultKind());
564}
565
566void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
567 Value lhs, Value rhs, Value acc,
568 ArrayAttr indexingMaps,
569 ArrayAttr iteratorTypes, CombiningKind kind) {
570 result.addOperands({lhs, rhs, acc});
571 result.addTypes(acc.getType());
572 result.addAttribute(getIndexingMapsAttrName(result.name), indexingMaps);
573 result.addAttribute(getIteratorTypesAttrName(result.name), iteratorTypes);
574 result.addAttribute(getKindAttrName(result.name),
575 CombiningKindAttr::get(builder.getContext(), kind));
576}
577
578ParseResult ContractionOp::parse(OpAsmParser &parser, OperationState &result) {
579 OpAsmParser::UnresolvedOperand lhsInfo;
580 OpAsmParser::UnresolvedOperand rhsInfo;
581 OpAsmParser::UnresolvedOperand accInfo;
582 SmallVector<OpAsmParser::UnresolvedOperand, 2> masksInfo;
583 SmallVector<Type, 2> types;
584 Type resultType;
585 auto loc = parser.getCurrentLocation();
586 DictionaryAttr dictAttr;
587 // TODO: Unify linalg op attribute parsing.
588 if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
589 parser.parseOperand(lhsInfo) || parser.parseComma() ||
590 parser.parseOperand(rhsInfo) || parser.parseComma() ||
591 parser.parseOperand(accInfo) ||
592 parser.parseTrailingOperandList(masksInfo) ||
593 parser.parseOptionalAttrDict(result.attributes) ||
594 parser.parseColonTypeList(types) ||
595 parser.parseKeywordType("into", resultType) ||
596 parser.resolveOperand(lhsInfo, types[0], result.operands) ||
597 parser.resolveOperand(rhsInfo, types[1], result.operands) ||
598 parser.resolveOperand(accInfo, resultType, result.operands) ||
599 parser.addTypeToList(resultType, result.types))
600 return failure();
601 result.attributes.assign(dictAttr.getValue().begin(),
602 dictAttr.getValue().end());
603
604 // Convert an array of strings into an array of IteratorType enums. This is
605 // needed because tests still use the old format, where the 'iterator_types'
606 // attribute is represented as an array of strings.
607 // TODO: Remove this conversion once tests are fixed.
608 ArrayAttr iteratorTypes =
609 result.attributes.get(getIteratorTypesAttrName(result.name))
610 .cast<ArrayAttr>();
611
612 SmallVector<Attribute> iteratorTypeAttrs;
613
614 for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
615 auto maybeIteratorType = symbolizeIteratorType(s);
616 if (!maybeIteratorType.has_value())
617 return parser.emitError(loc) << "unexpected iterator_type (" << s << ")";
618
619 iteratorTypeAttrs.push_back(
620 IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
621 }
622 result.attributes.set(getIteratorTypesAttrName(result.name),
623 parser.getBuilder().getArrayAttr(iteratorTypeAttrs));
624
625 if (!result.attributes.get(getKindAttrName(result.name))) {
626 result.addAttribute(
627 getKindAttrName(result.name),
628 CombiningKindAttr::get(result.getContext(),
629 ContractionOp::getDefaultKind()));
630 }
631 if (masksInfo.empty())
632 return success();
633 if (masksInfo.size() != 2)
634 return parser.emitError(parser.getNameLoc(),
635 "expected zero or exactly 2 vector mask operands");
636 auto lhsType = types[0].cast<VectorType>();
637 auto rhsType = types[1].cast<VectorType>();
638 auto maskElementType = parser.getBuilder().getI1Type();
639 std::array<Type, 2> maskTypes = {
640 VectorType::Builder(lhsType).setElementType(maskElementType),
641 VectorType::Builder(rhsType).setElementType(maskElementType)};
642 if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
643 return failure();
644 return success();
645}
646
647void ContractionOp::print(OpAsmPrinter &p) {
648 // TODO: Unify printing code with linalg ops.
649 auto attrNames = getTraitAttrNames();
650 llvm::StringSet<> traitAttrsSet;
651 traitAttrsSet.insert(attrNames.begin(), attrNames.end());
652 SmallVector<NamedAttribute, 8> attrs;
653 for (auto attr : (*this)->getAttrs()) {
654 if (attr.getName() == getIteratorTypesAttrName()) {
655 auto iteratorTypes =
656 attr.getValue()
657 .cast<ArrayAttr>()
658 .getAsValueRange<IteratorTypeAttr, IteratorType>();
659 // Convert IteratorType enums into the string representation. This is
660 // needed because tests still use the old format, where the 'iterator_types'
661 // attribute is represented as an array of strings.
662 // TODO: Remove this conversion once tests are fixed.
663 SmallVector<Attribute> iteratorTypeNames = llvm::to_vector(
664 llvm::map_range(iteratorTypes, [&](IteratorType t) -> Attribute {
665 return StringAttr::get(getContext(), stringifyIteratorType(t));
666 }));
667
668 attrs.emplace_back(getIteratorTypesAttrName(),
669 ArrayAttr::get(getContext(), iteratorTypeNames));
670 } else if (traitAttrsSet.count(attr.getName().strref()) > 0)
671 attrs.push_back(attr);
672 }
673
674 auto dictAttr = DictionaryAttr::get(getContext(), attrs);
675 p << " " << dictAttr << " " << getLhs() << ", ";
676 p << getRhs() << ", " << getAcc();
677 if (getMasks().size() == 2)
678 p << ", " << getMasks();
679
680 p.printOptionalAttrDict((*this)->getAttrs(), attrNames);
681 p << " : " << getLhs().getType() << ", " << getRhs().getType() << " into "
682 << getResultType();
683}
684
685static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
686 const std::vector<std::pair<int64_t, int64_t>> &map) {
687 for (auto &dimPair : map) {
688 if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
689 dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
690 lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
691 return false;
692 }
693 return true;
694}
695
696static LogicalResult verifyOutputShape(
697 ContractionOp op, VectorType lhsType, VectorType rhsType, Type accType,
698 Type resType,
699 const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
700 const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
701 DenseSet<int64_t> lhsContractingDimSet;
702 DenseSet<int64_t> rhsContractingDimSet;
703 for (auto &dimPair : contractingDimMap) {
704 lhsContractingDimSet.insert(dimPair.first);
705 rhsContractingDimSet.insert(dimPair.second);
706 }
707 DenseSet<int64_t> rhsBatchDimSet;
708 for (auto &dimPair : batchDimMap)
709 rhsBatchDimSet.insert(dimPair.second);
710
711 // Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
712 SmallVector<int64_t, 4> expectedResultDims;
713 for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
714 if (lhsContractingDimSet.count(i) > 0)
715 continue;
716 expectedResultDims.push_back(lhsType.getDimSize(i));
717 }
718
719 // Add free dimensions from 'rhsType' to 'expectedResultDims'.
720 for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
721 if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
722 continue;
723 expectedResultDims.push_back(rhsType.getDimSize(i));
724 }
725
726 // Verify 'expectedResultDims'.
727 if (expectedResultDims.empty()) {
728 // No batch or free dimension implies a scalar result.
729 if (resType.isa<VectorType>() || accType.isa<VectorType>())
730 return op.emitOpError("invalid accumulator/result vector shape");
731 } else {
732 // At least one batch or free dimension implies a vector result.
733 auto resVectorType = resType.dyn_cast<VectorType>();
734 auto accVectorType = accType.dyn_cast<VectorType>();
735 if (!resVectorType || !accVectorType)
736 return op.emitOpError("invalid accumulator/result vector shape");
737
738 // Infer expected result vector type. Lhs + rhs map and lhs + rhs vector
739 // types fully define the result vector type. This assumes the affine maps
740 // are well-formed, which must have been verified already.
741 MLIRContext *ctx = op.getContext();
742 AffineMap lhsMap = op.getIndexingMapsArray()[0];
743 AffineMap rhsMap = op.getIndexingMapsArray()[1];
744 if (getUnusedDimsBitVector({lhsMap, rhsMap}).any())
745 return op.emitOpError(
746 "expected all dimensions to be either a LHS or a RHS dimension");
747 SmallVector<AffineExpr, 4> extents(lhsMap.getNumInputs());
748 for (auto pair :
749 {std::make_pair(lhsType, lhsMap), std::make_pair(rhsType, rhsMap)}) {
750 VectorType v = pair.first;
751 auto map = pair.second;
752 for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
753 unsigned pos = map.getDimPosition(idx);
754 if (!extents[pos])
755 extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
756 }
757 }
758 if (!llvm::all_of(extents, [](AffineExpr e) { return e; }))
759 return op.emitOpError("expected all dimensions to get an extent as "
760 "either a LHS or a RHS dimension");
761
762 AffineMap resMap = op.getIndexingMapsArray()[2];
763 auto extentsMap = AffineMap::get(/*dimCount=*/extents.size(),
764 /*symCount=*/0, extents, ctx);
765 // Compose the resMap with the extentsMap, which is a constant map.
766 AffineMap expectedMap = simplifyAffineMap(resMap.compose(extentsMap));
767 assert(llvm::all_of(
768 expectedMap.getResults(),
769 [](AffineExpr e) { return e.isa<AffineConstantExpr>(); }) &&
770 "expected constant extent along all dimensions.");
771 // Extract the expected shape and build the type.
772 auto expectedShape = llvm::to_vector<4>(
773 llvm::map_range(expectedMap.getResults(), [](AffineExpr e) {
774 return e.cast<AffineConstantExpr>().getValue();
775 }));
776 auto expected =
777 VectorType::get(expectedShape, resVectorType.getElementType());
778 if (resVectorType != expected || accVectorType != expected)
779 return op.emitOpError(
780 "invalid accumulator/result vector shape, expected: ")
781 << expected;
782 }
783 return success();
784}
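
The first half of verifyOutputShape builds the expected result shape as the
free (non-contracting) lhs dimensions, batch dimensions included, followed by
the free, non-batch rhs dimensions. A condensed sketch under those
assumptions (hypothetical `expectedShape` helper; the real verifier then
cross-checks the shape against the indexing maps):

    #include <cstdint>
    #include <set>
    #include <vector>

    // E.g. matmul: lhs {2,3} x rhs {3,4}, contracting lhs dim 1 with rhs
    // dim 0 and no batch dims, yields {2, 4}.
    static std::vector<int64_t>
    expectedShape(const std::vector<int64_t> &lhs,
                  const std::vector<int64_t> &rhs,
                  const std::set<int64_t> &lhsContracting,
                  const std::set<int64_t> &rhsContracting,
                  const std::set<int64_t> &rhsBatch) {
      std::vector<int64_t> dims;
      for (int64_t i = 0, e = static_cast<int64_t>(lhs.size()); i < e; ++i)
        if (!lhsContracting.count(i))
          dims.push_back(lhs[i]); // free and batch dims come from the lhs
      for (int64_t i = 0, e = static_cast<int64_t>(rhs.size()); i < e; ++i)
        if (!rhsContracting.count(i) && !rhsBatch.count(i))
          dims.push_back(rhs[i]); // only free dims come from the rhs
      return dims;
    }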
785
786LogicalResult ContractionOp::verify() {
787 auto lhsType = getLhsType();
788 auto rhsType = getRhsType();
789 auto accType = getAccType();
790 auto resType = getResultType();
791
792 // Verify that an indexing map was specified for each vector operand.
793 if (getIndexingMapsArray().size() != 3)
794 return emitOpError("expected an indexing map for each vector operand");
795
796 // Verify that each index map has 'numIterators' inputs, no symbols, and
797 // that the number of map outputs equals the rank of its associated
798 // vector operand.
799 unsigned numIterators = getIteratorTypes().getValue().size();
800 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
801 auto index = it.index();
802 auto map = it.value();
803 if (map.getNumSymbols() != 0)
804 return emitOpError("expected indexing map ")
805 << index << " to have no symbols";
806 auto vectorType = getOperand(index).getType().dyn_cast<VectorType>();
807 unsigned rank = vectorType ? vectorType.getShape().size() : 0;
808 // Verify that the map has the right number of inputs, outputs, and indices.
809 // This also correctly accounts for (..) -> () for rank-0 results.
810 if (map.getNumDims() != numIterators)
811 return emitOpError("expected indexing map ")
812 << index << " to have " << numIterators << " number of inputs";
813 if (map.getNumResults() != rank)
814 return emitOpError("expected indexing map ")
815 << index << " to have " << rank << " number of outputs";
816 if (!map.isProjectedPermutation())
817 return emitOpError("expected indexing map ")
818 << index << " to be a projected permutation of its inputs";
819 }
820
821 auto contractingDimMap = getContractingDimMap();
822 auto batchDimMap = getBatchDimMap();
823
824 // Verify at least one contracting dimension pair was specified.
825 if (contractingDimMap.empty())
826 return emitOpError("expected at least one contracting dimension pair");
827
828 // Verify contracting dimension map was properly constructed.
829 if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
830 return emitOpError("invalid contracting dimension map");
831
832 // Verify batch dimension map was properly constructed.
833 if (!verifyDimMap(lhsType, rhsType, batchDimMap))
834 return emitOpError("invalid batch dimension map");
835
836 // Verify 'accType' and 'resType' shape.
837 if (failed(verifyOutputShape(*this, lhsType, rhsType, accType, resType,
838 contractingDimMap, batchDimMap)))
839 return failure();
840
841 // Verify that either two vector masks are set or none are set.
842 auto lhsMaskType = getLHSVectorMaskType();
843 auto rhsMaskType = getRHSVectorMaskType();
844 if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
845 return emitOpError("invalid number of vector masks specified");
846 if (lhsMaskType && rhsMaskType) {
847 // Verify mask rank == argument rank.
848 if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
849 rhsMaskType.getShape().size() != rhsType.getShape().size())
850 return emitOpError("invalid vector mask rank");
851 }
852
853 // Verify supported combining kind.
854 auto vectorType = resType.dyn_cast<VectorType>();
855 auto elementType = vectorType ? vectorType.getElementType() : resType;
856 if (!isSupportedCombiningKind(getKind(), elementType))
857 return emitOpError("unsupported contraction type");
858
859 return success();
860}
861
862SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
863 return SmallVector<StringRef>{getIndexingMapsAttrName(),
864 getIteratorTypesAttrName(), getKindAttrName()};
865}
866
867static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
868 for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
869 if (targetExpr == map.getResult(i))
870 return i;
871 return -1;
872}
873
874static std::vector<std::pair<int64_t, int64_t>>
875getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
876 IteratorType targetIteratorType, MLIRContext *context) {
877 std::vector<std::pair<int64_t, int64_t>> dimMap;
878 for (const auto &it : llvm::enumerate(iteratorTypes)) {
879 auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
880 if (iteratorType != targetIteratorType)
881 continue;
882 // Search lhs/rhs map results for 'targetExpr'.
883 auto targetExpr = getAffineDimExpr(it.index(), context);
884 int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
885 int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
886 if (lhsDim >= 0 && rhsDim >= 0)
887 dimMap.emplace_back(lhsDim, rhsDim);
888 }
889 return dimMap;
890}
891
892void ContractionOp::getIterationBounds(
893 SmallVectorImpl<int64_t> &iterationBounds) {
894 auto lhsShape = getLhsType().getShape();
895 auto resVectorType = getResultType().dyn_cast<VectorType>();
896 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
897 SmallVector<int64_t, 2> iterationShape;
898 for (const auto &it : llvm::enumerate(getIteratorTypes())) {
899 // Search lhs/rhs map results for 'targetExpr'.
900 auto targetExpr = getAffineDimExpr(it.index(), getContext());
901 auto iteratorType = it.value().cast<IteratorTypeAttr>().getValue();
902 if (iteratorType == IteratorType::reduction) {
903 // Get reduction dim size from lhs shape (same size in rhsShape).
904 int64_t lhsDimIndex = getResultIndex(indexingMaps[0], targetExpr);
905 assert(lhsDimIndex >= 0);
906 iterationBounds.push_back(lhsShape[lhsDimIndex]);
907 continue;
908 }
909 // Get parallel dimension size from result shape.
910 int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
911 assert(resDimIndex >= 0);
912 assert(resVectorType != nullptr);
913 iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
914 }
915}
916
917void ContractionOp::getIterationIndexMap(
918 std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMap) {
919 unsigned numMaps = getIndexingMapsArray().size();
920 iterationIndexMap.resize(numMaps);
921 for (const auto &it : llvm::enumerate(getIndexingMapsArray())) {
922 auto index = it.index();
923 auto map = it.value();
924 for (unsigned i = 0, e = map.getNumResults(); i < e; ++i) {
925 auto dim = map.getResult(i).cast<AffineDimExpr>();
926 iterationIndexMap[index][dim.getPosition()] = i;
927 }
928 }
929}
930
931std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
932 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
933 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::reduction,
934 getContext());
935}
936
937std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
938 SmallVector<AffineMap, 4> indexingMaps(getIndexingMapsArray());
939 return getDimMap(indexingMaps, getIteratorTypes(), IteratorType::parallel,
940 getContext());
941}
942
943std::optional<SmallVector<int64_t, 4>> ContractionOp::getShapeForUnroll() {
944 SmallVector<int64_t, 4> shape;
945 getIterationBounds(shape);
946 return shape;
947}
948
949/// Return a fused vector::ContractionOp which represents a pattern such as:
950///
951/// ```mlir
952/// %c0 = arith.constant 0: ...
953/// %c = vector.contract %a, %b, %c0: ...
954/// %e = add %c, %d: ...
955/// ```
956///
957/// by:
958///
959/// ```mlir
960/// %e = vector.contract %a, %b, %d: ...
961/// ```
962///
963/// Return null if the canonicalization does not apply.
964// TODO: This should be a folding of Add into Contract in core but while they
965// live in different dialects, it is not possible without unnatural
966// dependencies.
967template <typename AddOpType>
968struct CanonicalizeContractAdd : public OpRewritePattern<AddOpType> {
969 using OpRewritePattern<AddOpType>::OpRewritePattern;
970
971 LogicalResult matchAndRewrite(AddOpType addOp,
972 PatternRewriter &rewriter) const override {
973 auto canonicalize = [&](Value maybeContraction,
974 Value otherOperand) -> vector::ContractionOp {
975 vector::ContractionOp contractionOp =
976 dyn_cast_or_null<vector::ContractionOp>(
977 maybeContraction.getDefiningOp());
978 if (!contractionOp)
979 return vector::ContractionOp();
980 if (auto maybeZero = dyn_cast_or_null<arith::ConstantOp>(
981 contractionOp.getAcc().getDefiningOp())) {
982 if (maybeZero.getValue() ==
983 rewriter.getZeroAttr(contractionOp.getAcc().getType())) {
984 BlockAndValueMapping bvm;
985 bvm.map(contractionOp.getAcc(), otherOperand);
986 auto newContraction =
987 cast<vector::ContractionOp>(rewriter.clone(*contractionOp, bvm));
988 rewriter.replaceOp(addOp, newContraction.getResult());
989 return newContraction;
990 }
991 }
992 return vector::ContractionOp();
993 };
994
995 Value a = addOp->getOperand(0), b = addOp->getOperand(1);
996 vector::ContractionOp contract = canonicalize(a, b);
997 contract = contract ? contract : canonicalize(b, a);
998 return contract ? success() : failure();
999 }
1000};
1001
1002void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results,
1003 MLIRContext *context) {
1004 results.add<CanonicalizeContractAdd<arith::AddIOp>,
1005 CanonicalizeContractAdd<arith::AddFOp>>(context);
1006}
1007
1008//===----------------------------------------------------------------------===//
1009// ExtractElementOp
1010//===----------------------------------------------------------------------===//
1011
1012void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1013 Value source) {
1014 result.addOperands({source});
1015 result.addTypes(source.getType().cast<VectorType>().getElementType());
1016}
1017
1018void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result,
1019 Value source, Value position) {
1020 result.addOperands({source, position});
1021 result.addTypes(source.getType().cast<VectorType>().getElementType());
1022}
1023
1024LogicalResult vector::ExtractElementOp::verify() {
1025 VectorType vectorType = getVectorType();
1026 if (vectorType.getRank() == 0) {
1027 if (getPosition())
1028 return emitOpError("expected position to be empty with 0-D vector");
1029 return success();
1030 }
1031 if (vectorType.getRank() != 1)
1032 return emitOpError("unexpected >1 vector rank");
1033 if (!getPosition())
1034 return emitOpError("expected position for 1-D vector");
1035 return success();
1036}
1037
1038OpFoldResult vector::ExtractElementOp::fold(ArrayRef<Attribute> operands) {
1039 // Skip the 0-D vector case for now; it has no position operand.
1040 if (operands.size() < 2)
1041 return {};
1042
1043 Attribute src = operands[0];
1044 Attribute pos = operands[1];
1045
1046 // Fold extractelement (splat X) -> X.
1047 if (auto splat = getVector().getDefiningOp<vector::SplatOp>())
1048 return splat.getInput();
1049
1050 if (!pos || !src)
1051 return {};
1052
1053 auto srcElements = src.cast<DenseElementsAttr>().getValues<Attribute>();
1054
1055 auto attr = pos.dyn_cast<IntegerAttr>();
1056 uint64_t posIdx = attr.getInt();
1057
1058 return srcElements[posIdx];
1059}
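
Once both the source and the position are known constants, the fold is plain
indexing into the dense element list. A minimal sketch of that last step
(hypothetical helper; the real code indexes Attribute values rather than
integers):

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // E.g. extracting position 2 of dense<[10, 20, 30, 40]> folds to 30.
    static int64_t foldExtract(const std::vector<int64_t> &denseElements,
                               uint64_t pos) {
      assert(pos < denseElements.size() && "position out of range");
      return denseElements[pos];
    }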
1060
1061//===----------------------------------------------------------------------===//
1062// ExtractOp
1063//===----------------------------------------------------------------------===//
1064
1065void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1066 Value source, ArrayRef<int64_t> position) {
1067 build(builder, result, source, getVectorSubscriptAttr(builder, position));
1068}
1069
1070// Convenience builder which assumes the values are constant indices.
1071void vector::ExtractOp::build(OpBuilder &builder, OperationState &result,
1072 Value source, ValueRange position) {
1073 SmallVector<int64_t, 4> positionConstants =
1074 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
1075 return pos.getDefiningOp<arith::ConstantIndexOp>().value();
1076 }));
1077 build(builder, result, source, positionConstants);
1078}
1079
1080LogicalResult
1081ExtractOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
1082 ValueRange operands, DictionaryAttr attributes,
1083 RegionRange,
1084 SmallVectorImpl<Type> &inferredReturnTypes) {
1085 ExtractOp::Adaptor op(operands, attributes);
1086 auto vectorType = op.getVector().getType().cast<VectorType>();
1087 if (static_cast<int64_t>(op.getPosition().size()) == vectorType.getRank()) {
1088 inferredReturnTypes.push_back(vectorType.getElementType());
1089 } else {
1090 auto n =
1091 std::min<size_t>(op.getPosition().size(), vectorType.getRank() - 1);
1092 inferredReturnTypes.push_back(VectorType::get(
1093 vectorType.getShape().drop_front(n), vectorType.getElementType()));
1094 }
1095 return success();
1096}
1097
1098bool ExtractOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1099 // Allow extracting 1-element vectors instead of scalars.
1100 auto isCompatible = [](TypeRange l, TypeRange r) {
1101 auto vectorType = l.front().dyn_cast<VectorType>();
1102 return vectorType && vectorType.getShape().equals({1}) &&
1103 vectorType.getElementType() == r.front();
1104 };
1105 if (l.size() == 1 && r.size() == 1 &&
1106 (isCompatible(l, r) || isCompatible(r, l)))
1107 return true;
1108 return l == r;
1109}
1110
1111LogicalResult vector::ExtractOp::verify() {
1112 auto positionAttr = getPosition().getValue();
1113 if (positionAttr.size() > static_cast<unsigned>(getVectorType().getRank()))
1114 return emitOpError(
1115 "expected position attribute of rank smaller than vector rank");
1116 for (const auto &en : llvm::enumerate(positionAttr)) {
1117 auto attr = en.value().dyn_cast<IntegerAttr>();
1118 if (!attr || attr.getInt() < 0 ||
1119 attr.getInt() >= getVectorType().getDimSize(en.index()))
1120 return emitOpError("expected position attribute #")
1121 << (en.index() + 1)
1122 << " to be a non-negative integer smaller than the corresponding "
1123 "vector dimension";
1124 }
1125 return success();
1126}
1127
1128template <typename IntType>
1129static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
1130 return llvm::to_vector<4>(llvm::map_range(
1131 arrayAttr.getAsRange<IntegerAttr>(),
1132 [](IntegerAttr attr) { return static_cast<IntType>(attr.getInt()); }));
1133}
1134
1135/// Fold the result of chains of ExtractOp in place by simply concatenating the
1136/// positions.
1137static LogicalResult foldExtractOpFromExtractChain(ExtractOp extractOp) {
1138 if (!extractOp.getVector().getDefiningOp<ExtractOp>())
1139 return failure();
1140
1141 SmallVector<int64_t, 4> globalPosition;
1142 ExtractOp currentOp = extractOp;
1143 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1144 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1145 while (ExtractOp nextOp = currentOp.getVector().getDefiningOp<ExtractOp>()) {
1146 currentOp = nextOp;
1147 auto extrPos = extractVector<int64_t>(currentOp.getPosition());
1148 globalPosition.append(extrPos.rbegin(), extrPos.rend());
1149 }
1150 extractOp.setOperand(currentOp.getVector());
1151 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1152 OpBuilder b(extractOp.getContext());
1153 std::reverse(globalPosition.begin(), globalPosition.end());
1154 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1155 b.getI64ArrayAttr(globalPosition));
1156 return success();
1157}
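
The fold relies on extract(extract(v, p1), p2) == extract(v, p1 ++ p2).
Walking from the outermost op toward the source, each position is appended
reversed, and one final reverse restores source order. The same bookkeeping
as a standalone sketch (hypothetical helper):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Positions given outermost first, e.g. {{2, 3}, {1}} -> {1, 2, 3}.
    static std::vector<int64_t>
    concatPositions(const std::vector<std::vector<int64_t>> &outerToInner) {
      std::vector<int64_t> global;
      for (const auto &pos : outerToInner)
        global.insert(global.end(), pos.rbegin(), pos.rend());
      std::reverse(global.begin(), global.end());
      return global;
    }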
1158
1159namespace {
1160/// Fold an ExtractOp that is fed by a chain of InsertOps and TransposeOps.
1161/// Walk back a chain of InsertOp/TransposeOp until we hit a match.
1162/// Compose TransposeOp permutations as we walk back.
1163/// This helper class keeps an updated extraction position `extractPosition`
1164/// with extra trailing sentinels.
1165/// The sentinels encode the internal transposition status of the result vector.
1166/// As we iterate, extractPosition is permuted and updated.
1167class ExtractFromInsertTransposeChainState {
1168public:
1169 ExtractFromInsertTransposeChainState(ExtractOp e);
1170
1171 /// Iterate over producing insert and transpose ops until we find a fold.
1172 Value fold();
1173
1174private:
1175 /// Return true if the vector at position `a` is contained within the vector
1176 /// at position `b`. Under insert/extract semantics, this is the same as `a`
1177 /// is a prefix of `b`.
1178 template <typename ContainerA, typename ContainerB>
1179 bool isContainedWithin(const ContainerA &a, const ContainerB &b) {
1180 return a.size() <= b.size() &&
1181 std::equal(a.begin(), a.begin() + a.size(), b.begin());
1182 }
1183
1184 /// Return true if the vector at position `a` intersects the vector at
1185 /// position `b`. Under insert/extract semantics, this is the same as equality
1186 /// of all entries of `a` that are >=0 with the corresponding entries of b.
1187 /// Comparison is on the common prefix (i.e. zip).
1188 template <typename ContainerA, typename ContainerB>
1189 bool intersectsWhereNonNegative(const ContainerA &a, const ContainerB &b) {
1190 for (auto [elemA, elemB] : llvm::zip(a, b)) {
1191 if (elemA < 0 || elemB < 0)
1192 continue;
1193 if (elemA != elemB)
1194 return false;
1195 }
1196 return true;
1197 }
1198
1199 /// Folding is only possible in the absence of an internal permutation in the
1200 /// result vector.
1201 bool canFold() {
1202 return (sentinels ==
1203 makeArrayRef(extractPosition).drop_front(extractedRank));
1204 }
1205
1206 // Helper to get the next defining op of interest.
1207 void updateStateForNextIteration(Value v) {
1208 nextInsertOp = v.getDefiningOp<vector::InsertOp>();
1209 nextTransposeOp = v.getDefiningOp<vector::TransposeOp>();
1210 };
1211
1212 // Case 1. If we hit a transpose, just compose the map and iterate.
1213 // Invariant: insert + transpose do not change rank, we can always compose.
1214 LogicalResult handleTransposeOp();
1215
1216 // Case 2: the insert position matches extractPosition exactly, early return.
1217 LogicalResult handleInsertOpWithMatchingPos(Value &res);
1218
1219 /// Case 3: if the insert position is a prefix of extractPosition, extract a
1220 /// portion of the source of the insert.
1221 /// Example:
1222 /// ```
1223 /// %ins = vector.insert %source, %dest[1]: vector<3x4> into vector<2x3x4x5>
1224 /// // extractPosition == [1, 2, 3]
1225 /// %ext = vector.extract %ins[1, 0]: vector<3x4x5>
1226 /// // can fold to vector.extract %source[0, 3]
1227 /// %ext = vector.extract %source[3]: vector<5x6>
1228 /// ```
1229 /// To traverse through %source, we need to set the leading dims to 0 and
1230 /// drop the extra leading dims.
1231 /// This method updates the internal state.
1232 LogicalResult handleInsertOpWithPrefixPos(Value &res);
1233
1234 /// Try to fold in place to extract(source, extractPosition) and return the
1235 /// folded result. Return null if folding is not possible (e.g. due to an
1236 /// internal transposition in the result).
1237 Value tryToFoldExtractOpInPlace(Value source);
1238
1239 ExtractOp extractOp;
1240 int64_t vectorRank;
1241 int64_t extractedRank;
1242
1243 InsertOp nextInsertOp;
1244 TransposeOp nextTransposeOp;
1245
1246 /// Sentinel values that encode the internal permutation status of the result.
1247 /// They are set to (-1, ... , -k) at the beginning and appended to
1248 /// `extractPosition`.
1249 /// In the end, the tail of `extractPosition` must be exactly `sentinels` to
1250 /// ensure that there is no internal transposition.
1251 /// Internal transposition cannot be accounted for with a folding pattern.
1252 // TODO: We could relax the internal transposition with an extra transposition
1253 // operation in a future canonicalizer.
1254 SmallVector<int64_t> sentinels;
1255 SmallVector<int64_t> extractPosition;
1256};
1257} // namespace
1258
1259ExtractFromInsertTransposeChainState::ExtractFromInsertTransposeChainState(
1260 ExtractOp e)
1261 : extractOp(e), vectorRank(extractOp.getVectorType().getRank()),
1262 extractedRank(extractOp.getPosition().size()) {
1263 assert(vectorRank >= extractedRank && "extracted pos overflow");
1264 sentinels.reserve(vectorRank - extractedRank);
1265 for (int64_t i = 0, e = vectorRank - extractedRank; i < e; ++i)
1266 sentinels.push_back(-(i + 1));
1267 extractPosition = extractVector<int64_t>(extractOp.getPosition());
1268 llvm::append_range(extractPosition, sentinels);
1269}
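
The sentinels appended here stand in for the not-yet-extracted trailing
dimensions. After all transpositions have been composed, folding is legal
only if the tail of `extractPosition` is still exactly the sentinel sequence
(the canFold test above). A sketch of that tail check (hypothetical helper):

    #include <cstdint>
    #include <vector>

    // position[extractedRank + k] must still be -(k + 1) for all k.
    static bool tailIsUnpermuted(const std::vector<int64_t> &position,
                                 int64_t extractedRank) {
      int64_t n = static_cast<int64_t>(position.size());
      for (int64_t i = extractedRank; i < n; ++i)
        if (position[i] != -(i - extractedRank + 1))
          return false;
      return true;
    }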
1270
1271// Case 1. If we hit a transpose, just compose the map and iterate.
1272// Invariant: insert + transpose do not change rank, we can always compose.
1273LogicalResult ExtractFromInsertTransposeChainState::handleTransposeOp() {
1274 if (!nextTransposeOp)
1275 return failure();
1276 auto permutation = extractVector<unsigned>(nextTransposeOp.getTransp());
1277 AffineMap m = inversePermutation(
1278 AffineMap::getPermutationMap(permutation, extractOp.getContext()));
1279 extractPosition = applyPermutationMap(m, makeArrayRef(extractPosition));
1280 return success();
1281}
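// For example, when handleTransposeOp traverses `vector.transpose %t, [1, 0]`
// (whose result at [i, j] is the source at [j, i]), extractPosition = [p0, p1]
// becomes [p1, p0] before the walk continues on %t.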
1282
1283// Case 2: the insert position matches extractPosition exactly, early return.
1284LogicalResult
1285ExtractFromInsertTransposeChainState::handleInsertOpWithMatchingPos(
1286 Value &res) {
1287 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1288 if (makeArrayRef(insertedPos) !=
1289 llvm::makeArrayRef(extractPosition).take_front(extractedRank))
1290 return failure();
1291 // Case 2.a. early-exit fold.
1292 res = nextInsertOp.getSource();
1293 // Case 2.b. if internal transposition is present, canFold will be false.
1294 return success();
1295}
1296
1297 /// Case 3: if the inserted position is a prefix of extractPosition,
1298/// extract a portion of the source of the insertion.
1299/// This method updates the internal state.
1300LogicalResult
1301ExtractFromInsertTransposeChainState::handleInsertOpWithPrefixPos(Value &res) {
1302 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1303 if (!isContainedWithin(insertedPos, extractPosition))
1304 return failure();
1305 // Set leading dims to zero.
1306 std::fill_n(extractPosition.begin(), insertedPos.size(), 0);
1307 // Drop extra leading dims.
1308 extractPosition.erase(extractPosition.begin(),
1309 extractPosition.begin() + insertedPos.size());
1310 extractedRank = extractPosition.size() - sentinels.size();
1311 // Case 3.a. early-exit fold (break and delegate to post-while path).
1312 res = nextInsertOp.getSource();
1313 // Case 3.b. if internal transposition is present, canFold will be false.
1314 return success();
1315}
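// For example, handleInsertOpWithPrefixPos with insertedPos = [1] and
// extractPosition = [1, 2, 3, -1] zeroes the leading entry ([0, 2, 3, -1]),
// drops it ([2, 3, -1]) and updates extractedRank to 2 (one sentinel remains).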
1316
1317/// Try to fold in place to extract(source, extractPosition) and return the
1318/// folded result. Return null if folding is not possible (e.g. due to an
1319 /// internal transposition in the result).
1320Value ExtractFromInsertTransposeChainState::tryToFoldExtractOpInPlace(
1321 Value source) {
1322 // If we can't fold (either internal transposition, or nothing to fold), bail.
1323 bool nothingToFold = (source == extractOp.getVector());
1324 if (nothingToFold || !canFold())
1325 return Value();
1326 // Otherwise, fold by updating the op in place and return its result.
1327 OpBuilder b(extractOp.getContext());
1328 extractOp->setAttr(
1329 extractOp.getPositionAttrName(),
1330 b.getI64ArrayAttr(
1331 makeArrayRef(extractPosition).take_front(extractedRank)));
1332 extractOp.getVectorMutable().assign(source);
1333 return extractOp.getResult();
1334}
1335
1336/// Iterate over producing insert and transpose ops until we find a fold.
1337Value ExtractFromInsertTransposeChainState::fold() {
1338 Value valueToExtractFrom = extractOp.getVector();
1339 updateStateForNextIteration(valueToExtractFrom);
1340 while (nextInsertOp || nextTransposeOp) {
1341 // Case 1. If we hit a transpose, just compose the map and iterate.
1342 // Invariant: insert + transpose do not change rank, so we can always compose.
1343 if (succeeded(handleTransposeOp())) {
1344 valueToExtractFrom = nextTransposeOp.getVector();
1345 updateStateForNextIteration(valueToExtractFrom);
1346 continue;
1347 }
1348
1349 Value result;
1350 // Case 2: the positions match exactly.
1351 if (succeeded(handleInsertOpWithMatchingPos(result)))
1352 return result;
1353
1354 // Case 3: if the inserted position is a prefix of extractPosition, we can
1355 // just extract a portion of the source of the insert.
1356 if (succeeded(handleInsertOpWithPrefixPos(result)))
1357 return tryToFoldExtractOpInPlace(result);
1358
1359 // Case 4: extractPosition intersects insertedPos on non-sentinel
1360 // values. This is a more difficult case and we bail.
1361 auto insertedPos = extractVector<int64_t>(nextInsertOp.getPosition());
1362 if (isContainedWithin(extractPosition, insertedPos) ||
1363 intersectsWhereNonNegative(extractPosition, insertedPos))
1364 return Value();
1365
1366 // Case 5: no intersection; forward the extract to insertOp.dest().
1367 valueToExtractFrom = nextInsertOp.getDest();
1368 updateStateForNextIteration(valueToExtractFrom);
1369 }
1370 // If after all this we can fold, go for it.
1371 return tryToFoldExtractOpInPlace(valueToExtractFrom);
1372}
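// For example, the chain:
//   %0 = vector.insert %s, %d [0] : vector<3xf32> into vector<2x3xf32>
//   %1 = vector.extract %0[0, 1] : vector<2x3xf32>
// folds in place to `vector.extract %s[1] : vector<3xf32>` because the
// inserted position [0] is a prefix of the extract position [0, 1].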
1373
1374 /// Fold extractOp coming from BroadcastOp or SplatOp.
1375static Value foldExtractFromBroadcast(ExtractOp extractOp) {
1376 Operation *defOp = extractOp.getVector().getDefiningOp();
1377 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1378 return Value();
1379 Value source = defOp->getOperand(0);
1380 if (extractOp.getType() == source.getType())
1381 return source;
1382 auto getRank = [](Type type) {
1383 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1384 };
1385 // If splat or broadcast from a scalar, just return the source scalar.
1386 unsigned broadcastSrcRank = getRank(source.getType());
1387 if (broadcastSrcRank == 0)
1388 return source;
1389
1390 unsigned extractResultRank = getRank(extractOp.getType());
1391 if (extractResultRank >= broadcastSrcRank)
1392 return Value();
1393 // Check that the dimensions of the result haven't been broadcasted.
1394 auto extractVecType = extractOp.getType().dyn_cast<VectorType>();
1395 auto broadcastVecType = source.getType().dyn_cast<VectorType>();
1396 if (extractVecType && broadcastVecType &&
1397 extractVecType.getShape() !=
1398 broadcastVecType.getShape().take_back(extractResultRank))
1399 return Value();
1400
1401 auto broadcastOp = cast<vector::BroadcastOp>(defOp);
1402 int64_t rankDiff = broadcastSrcRank - extractResultRank;
1403 // Detect all the positions that come from "dim-1" broadcasting.
1404 // These dimensions correspond to "dim-1" broadcasted dims; set the matching
1405 // extract position to `0` when extracting from the source operand.
1406 llvm::SetVector<int64_t> broadcastedUnitDims =
1407 broadcastOp.computeBroadcastedUnitDims();
1408 auto extractPos = extractVector<int64_t>(extractOp.getPosition());
1409 for (int64_t i = rankDiff, e = extractPos.size(); i < e; ++i)
1410 if (broadcastedUnitDims.contains(i))
1411 extractPos[i] = 0;
1412 // All but the last `rankDiff` extract positions index into the new
1413 // broadcasted dims; drop them when extracting from the source operand.
1414 extractPos.erase(extractPos.begin(),
1415 std::next(extractPos.begin(), extractPos.size() - rankDiff));
1416 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1417 OpBuilder b(extractOp.getContext());
1418 extractOp.setOperand(source);
1419 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1420 b.getI64ArrayAttr(extractPos));
1421 return extractOp.getResult();
1422}
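// For example, with %b = vector.broadcast %src : vector<3x4xf32> to
// vector<2x3x4xf32>, the op `vector.extract %b[0, 1]` folds in place to
// `vector.extract %src[1]`: the leading position indexes a duplicated leading
// dim and is dropped.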
1423
1424// Fold extractOp with source coming from ShapeCast op.
1425static Value foldExtractFromShapeCast(ExtractOp extractOp) {
1426 auto shapeCastOp = extractOp.getVector().getDefiningOp<vector::ShapeCastOp>();
1427 if (!shapeCastOp)
1428 return Value();
1429 // Get the nth dimension size starting from lowest dimension.
1430 auto getDimReverse = [](VectorType type, int64_t n) {
1431 return type.getShape().take_back(n + 1).front();
1432 };
1433 int64_t destinationRank =
1434 extractOp.getType().isa<VectorType>()
1435 ? extractOp.getType().cast<VectorType>().getRank()
1436 : 0;
1437 if (destinationRank > shapeCastOp.getSourceVectorType().getRank())
1438 return Value();
1439 if (destinationRank > 0) {
1440 auto destinationType = extractOp.getResult().getType().cast<VectorType>();
1441 for (int64_t i = 0; i < destinationRank; i++) {
1442 // The lowest dimension of the destination must match the lowest
1443 // dimension of the shapecast op source.
1444 // TODO: This case could be supported in a canonicalization pattern.
1445 if (getDimReverse(shapeCastOp.getSourceVectorType(), i) !=
1446 getDimReverse(destinationType, i))
1447 return Value();
1448 }
1449 }
1450 // Extract the strides associated with the extract op vector source. Then use
1451 // this to calculate a linearized position for the extract.
1452 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1453 std::reverse(extractedPos.begin(), extractedPos.end());
1454 SmallVector<int64_t, 4> strides;
1455 int64_t stride = 1;
1456 for (int64_t i = 0, e = extractedPos.size(); i < e; i++) {
1457 strides.push_back(stride);
1458 stride *= getDimReverse(extractOp.getVectorType(), i + destinationRank);
1459 }
1460
1461 int64_t position = linearize(extractedPos, strides);
1462 // Then extract the strides associated with the shapeCast op vector source and
1463 // delinearize the position using those strides.
1464 SmallVector<int64_t, 4> newStrides;
1465 int64_t numDimension =
1466 shapeCastOp.getSourceVectorType().getRank() - destinationRank;
1467 stride = 1;
1468 for (int64_t i = 0; i < numDimension; i++) {
1469 newStrides.push_back(stride);
1470 stride *=
1471 getDimReverse(shapeCastOp.getSourceVectorType(), i + destinationRank);
1472 }
1473 std::reverse(newStrides.begin(), newStrides.end());
1474 SmallVector<int64_t, 4> newPosition = delinearize(newStrides, position);
1475 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1476 OpBuilder b(extractOp.getContext());
1477 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1478 b.getI64ArrayAttr(newPosition));
1479 extractOp.setOperand(shapeCastOp.getSource());
1480 return extractOp.getResult();
1481}
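// For example, with %c = vector.shape_cast %src : vector<8xf32> to
// vector<2x4xf32>, the op `vector.extract %c[1, 2]` linearizes [1, 2] to
// position 1 * 4 + 2 = 6 and folds in place to `vector.extract %src[6]`.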
1482
1483/// Fold an ExtractOp from ExtractStridedSliceOp.
1484static Value foldExtractFromExtractStrided(ExtractOp extractOp) {
1485 auto extractStridedSliceOp =
1486 extractOp.getVector().getDefiningOp<vector::ExtractStridedSliceOp>();
1487 if (!extractStridedSliceOp)
1488 return Value();
1489 // Return if 'extractStridedSliceOp' has non-unit strides.
1490 if (extractStridedSliceOp.hasNonUnitStrides())
1491 return Value();
1492
1493 // Trim offsets for dimensions fully extracted.
1494 auto sliceOffsets =
1495 extractVector<int64_t>(extractStridedSliceOp.getOffsets());
1496 while (!sliceOffsets.empty()) {
1497 size_t lastOffset = sliceOffsets.size() - 1;
1498 if (sliceOffsets.back() != 0 ||
1499 extractStridedSliceOp.getType().getDimSize(lastOffset) !=
1500 extractStridedSliceOp.getVectorType().getDimSize(lastOffset))
1501 break;
1502 sliceOffsets.pop_back();
1503 }
1504 unsigned destinationRank = 0;
1505 if (auto vecType = extractOp.getType().dyn_cast<VectorType>())
1506 destinationRank = vecType.getRank();
1507 // The dimensions of the result need to be untouched by the
1508 // extractStridedSlice op.
1509 if (destinationRank >
1510 extractStridedSliceOp.getVectorType().getRank() - sliceOffsets.size())
1511 return Value();
1512 auto extractedPos = extractVector<int64_t>(extractOp.getPosition());
1513 assert(extractedPos.size() >= sliceOffsets.size());
1514 for (size_t i = 0, e = sliceOffsets.size(); i < e; i++)
1515 extractedPos[i] = extractedPos[i] + sliceOffsets[i];
1516 extractOp.getVectorMutable().assign(extractStridedSliceOp.getVector());
1517 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1518 OpBuilder b(extractOp.getContext());
1519 extractOp->setAttr(ExtractOp::getPositionAttrStrName(),
1520 b.getI64ArrayAttr(extractedPos));
1521 return extractOp.getResult();
1522}
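// For example, if %s is extracted from %v with unit strides at offsets [1, 2],
// then `vector.extract %s[0, 1]` folds in place to `vector.extract %v[1, 3]`
// by adding the slice offsets back onto the extract position.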
1523
1524/// Fold extract_op fed from a chain of insertStridedSlice ops.
1525static Value foldExtractStridedOpFromInsertChain(ExtractOp op) {
1526 int64_t destinationRank = op.getType().isa<VectorType>()
1527 ? op.getType().cast<VectorType>().getRank()
1528 : 0;
1529 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
1530 while (insertOp) {
1531 int64_t insertRankDiff = insertOp.getDestVectorType().getRank() -
1532 insertOp.getSourceVectorType().getRank();
1533 if (destinationRank > insertOp.getSourceVectorType().getRank())
1534 return Value();
1535 auto insertOffsets = extractVector<int64_t>(insertOp.getOffsets());
1536 auto extractOffsets = extractVector<int64_t>(op.getPosition());
1537
1538 if (llvm::any_of(insertOp.getStrides(), [](Attribute attr) {
1539 return attr.cast<IntegerAttr>().getInt() != 1;
1540 }))
1541 return Value();
1542 bool disjoint = false;
1543 SmallVector<int64_t, 4> offsetDiffs;
1544 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1545 int64_t start = insertOffsets[dim];
1546 int64_t size =
1547 (dim < insertRankDiff)
1548 ? 1
1549 : insertOp.getSourceVectorType().getDimSize(dim - insertRankDiff);
1550 int64_t end = start + size;
1551 int64_t offset = extractOffsets[dim];
1552 // Check if the start of the extract offset is in the interval inserted.
1553 if (start <= offset && offset < end) {
1554 if (dim >= insertRankDiff)
1555 offsetDiffs.push_back(offset - start);
1556 continue;
1557 }
1558 disjoint = true;
1559 break;
1560 }
1561 // The extracted element chunk overlaps with the inserted vector.
1562 if (!disjoint) {
1563 // If any of the inner dimensions are only partially inserted we have a
1564 // partial overlap.
1565 int64_t srcRankDiff =
1566 insertOp.getSourceVectorType().getRank() - destinationRank;
1567 for (int64_t i = 0; i < destinationRank; i++) {
1568 if (insertOp.getSourceVectorType().getDimSize(i + srcRankDiff) !=
1569 insertOp.getDestVectorType().getDimSize(i + srcRankDiff +
1570 insertRankDiff))
1571 return Value();
1572 }
1573 op.getVectorMutable().assign(insertOp.getSource());
1574 // OpBuilder is only used as a helper to build an I64ArrayAttr.
1575 OpBuilder b(op.getContext());
1576 op->setAttr(ExtractOp::getPositionAttrStrName(),
1577 b.getI64ArrayAttr(offsetDiffs));
1578 return op.getResult();
1579 }
1580 // If the chunk extracted is disjoint from the chunk inserted, keep
1581 // looking in the insert chain.
1582 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
1583 }
1584 return Value();
1585}
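// For example, if %ins inserts %src (vector<2xf32>) into %dst (vector<8xf32>)
// at offsets [2], then `vector.extract %ins[3]` lands inside the inserted
// chunk and folds to `vector.extract %src[1]`, while an extract at [6] is
// disjoint and the walk continues on %dst.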
1586
1587OpFoldResult ExtractOp::fold(ArrayRef<Attribute>) {
1588 if (getPosition().empty())
1589 return getVector();
1590 if (succeeded(foldExtractOpFromExtractChain(*this)))
1591 return getResult();
1592 if (auto res = ExtractFromInsertTransposeChainState(*this).fold())
1593 return res;
1594 if (auto res = foldExtractFromBroadcast(*this))
1595 return res;
1596 if (auto res = foldExtractFromShapeCast(*this))
1597 return res;
1598 if (auto val = foldExtractFromExtractStrided(*this))
1599 return val;
1600 if (auto val = foldExtractStridedOpFromInsertChain(*this))
1601 return val;
1602 return OpFoldResult();
1603}
1604
1605namespace {
1606
1607 // Pattern to rewrite an ExtractOp(Broadcast) -> Broadcast.
1608class ExtractOpFromBroadcast final : public OpRewritePattern<ExtractOp> {
1609public:
1610 using OpRewritePattern::OpRewritePattern;
1611
1612 LogicalResult matchAndRewrite(ExtractOp extractOp,
1613 PatternRewriter &rewriter) const override {
1614 Operation *defOp = extractOp.getVector().getDefiningOp();
1615 if (!defOp || !isa<vector::BroadcastOp, SplatOp>(defOp))
1616 return failure();
1617
1618 Value source = defOp->getOperand(0);
1619 if (extractOp.getType() == source.getType())
1620 return failure();
1621 auto getRank = [](Type type) {
1622 return type.isa<VectorType>() ? type.cast<VectorType>().getRank() : 0;
1623 };
1624 unsigned broadcastSrcRank = getRank(source.getType());
1625 unsigned extractResultRank = getRank(extractOp.getType());
1626 // We only consider the case where the rank of the source is less than or
1627 // equal to the rank of the extract dst. The other cases are handled in the
1628 // folding patterns.
1629 if (extractResultRank < broadcastSrcRank)
1630 return failure();
1631 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
1632 extractOp, extractOp.getType(), source);
1633 return success();
1634 }
1635};
1636
1637 // Pattern to rewrite an ExtractOp(splat ConstantOp) -> ConstantOp.
1638class ExtractOpSplatConstantFolder final : public OpRewritePattern<ExtractOp> {
1639public:
1640 using OpRewritePattern::OpRewritePattern;
1641
1642 LogicalResult matchAndRewrite(ExtractOp extractOp,
1643 PatternRewriter &rewriter) const override {
1644 // Return if 'ExtractOp' operand is not defined by a splat vector
1645 // ConstantOp.
1646 Value sourceVector = extractOp.getVector();
1647 Attribute vectorCst;
1648 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1649 return failure();
1650 auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
1651 if (!splat)
1652 return failure();
1653 Attribute newAttr = splat.getSplatValue<Attribute>();
1654 if (auto vecDstType = extractOp.getType().dyn_cast<VectorType>())
1655 newAttr = DenseElementsAttr::get(vecDstType, newAttr);
1656 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1657 return success();
1658 }
1659};
1660
1661 // Pattern to rewrite an ExtractOp(non-splat ConstantOp)[...] -> ConstantOp.
1662class ExtractOpNonSplatConstantFolder final
1663 : public OpRewritePattern<ExtractOp> {
1664public:
1665 using OpRewritePattern::OpRewritePattern;
1666
1667 LogicalResult matchAndRewrite(ExtractOp extractOp,
1668 PatternRewriter &rewriter) const override {
1669 // Return if 'ExtractOp' operand is not defined by a compatible vector
1670 // ConstantOp.
1671 Value sourceVector = extractOp.getVector();
1672 Attribute vectorCst;
1673 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
1674 return failure();
1675
1676 auto vecTy = sourceVector.getType().cast<VectorType>();
1677 if (vecTy.isScalable())
1678 return failure();
1679
1680 // The splat case is handled by `ExtractOpSplatConstantFolder`.
1681 auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
1682 if (!dense || dense.isSplat())
1683 return failure();
1684
1685 // Calculate the linearized position of the contiguous chunk of elements to
1686 // extract.
1687 llvm::SmallVector<int64_t> completePositions(vecTy.getRank(), 0);
1688 copy(getI64SubArray(extractOp.getPosition()), completePositions.begin());
1689 int64_t elemBeginPosition =
1690 linearize(completePositions, computeStrides(vecTy.getShape()));
1691 auto denseValuesBegin = dense.value_begin<Attribute>() + elemBeginPosition;
1692
1693 Attribute newAttr;
1694 if (auto resVecTy = extractOp.getType().dyn_cast<VectorType>()) {
1695 SmallVector<Attribute> elementValues(
1696 denseValuesBegin, denseValuesBegin + resVecTy.getNumElements());
1697 newAttr = DenseElementsAttr::get(resVecTy, elementValues);
1698 } else {
1699 newAttr = *denseValuesBegin;
1700 }
1701
1702 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractOp, newAttr);
1703 return success();
1704 }
1705};
1706
1707} // namespace
1708
1709void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1710 MLIRContext *context) {
1711 results.add<ExtractOpSplatConstantFolder, ExtractOpNonSplatConstantFolder,
1712 ExtractOpFromBroadcast>(context);
1713}
1714
1715static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
1716 SmallVectorImpl<int64_t> &results) {
1717 for (auto attr : arrayAttr)
1718 results.push_back(attr.cast<IntegerAttr>().getInt());
1719}
1720
1721//===----------------------------------------------------------------------===//
1722 // FMAOp
1723//===----------------------------------------------------------------------===//
1724
1725std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
1726 return llvm::to_vector<4>(getVectorType().getShape());
1727}
1728
1729//===----------------------------------------------------------------------===//
1730// BroadcastOp
1731//===----------------------------------------------------------------------===//
1732
1733/// Return the dimensions of the result vector that were formerly ones in the
1734 /// source vector and thus correspond to "dim-1" broadcasting.
1735static llvm::SetVector<int64_t>
1736computeBroadcastedUnitDims(ArrayRef<int64_t> srcShape,
1737 ArrayRef<int64_t> dstShape) {
1738 int64_t rankDiff = dstShape.size() - srcShape.size();
1739 int64_t dstDim = rankDiff;
1740 llvm::SetVector<int64_t> res;
1741 for (auto [s1, s2] :
1742 llvm::zip_equal(srcShape, dstShape.drop_front(rankDiff))) {
1743 if (s1 != s2) {
1744 assert(s1 == 1 && "expected dim-1 broadcasting");
1745 res.insert(dstDim);
1746 }
1747 ++dstDim;
1748 }
1749 return res;
1750}
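// For example, srcShape = 1x4 and dstShape = 2x3x5x4 gives rankDiff = 2; only
// the zipped pair (1, 5) mismatches, so the result is {2}: dst dim 2 comes
// from "dim-1" broadcasting while dims 0 and 1 are new leading dims.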
1751
1752llvm::SetVector<int64_t> BroadcastOp::computeBroadcastedUnitDims() {
1753 // Scalar broadcast is without any unit dim broadcast.
1754 auto srcVectorType = getSourceType().dyn_cast<VectorType>();
1755 if (!srcVectorType)
1756 return {};
1757 return ::computeBroadcastedUnitDims(srcVectorType.getShape(),
1758 getVectorType().getShape());
1759}
1760
1761/// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
1762/// `broadcastedDims` dimensions in the dstShape are broadcasted.
1763/// This requires (and asserts) that the broadcast is free of dim-1
1764/// broadcasting.
1765/// Since vector.broadcast only allows expanding leading dimensions, an extra
1766/// vector.transpose may be inserted to make the broadcast possible.
1767/// `value`, `dstShape` and `broadcastedDims` must be properly specified or
1768/// the helper will assert. This means:
1769/// 1. `dstShape` must not be empty.
1770/// 2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
1771/// 2. `dstShape` trimmed of the dimensions specified in `broadcastedDims`
1772// must match the `value` shape.
1773Value BroadcastOp::createOrFoldBroadcastOp(
1774 OpBuilder &b, Value value, ArrayRef<int64_t> dstShape,
1775 const llvm::SetVector<int64_t> &broadcastedDims) {
1776 assert(!dstShape.empty() && "unexpected empty dst shape");
1777
1778 // Step 1. Well-formedness check.
1779 SmallVector<int64_t> checkShape;
1780 for (int i = 0, e = dstShape.size(); i < e; ++i) {
1781 if (broadcastedDims.contains(i))
1782 continue;
1783 checkShape.push_back(dstShape[i]);
1784 }
1785 assert(broadcastedDims.size() == dstShape.size() - checkShape.size() &&
1786        "ill-formed broadcastedDims contains values not confined to "
1787        "destVectorShape");
1788
1789 Location loc = value.getLoc();
1790 Type elementType = getElementTypeOrSelf(value.getType());
1791 VectorType srcVectorType = value.getType().dyn_cast<VectorType>();
1792 VectorType dstVectorType = VectorType::get(dstShape, elementType);
1793
1794 // Step 2. If scalar -> dstShape broadcast, just do it.
1795 if (!srcVectorType) {
1796 assert(checkShape.empty() &&
1797        "ill-formed createOrFoldBroadcastOp arguments");
1798 return b.createOrFold<vector::BroadcastOp>(loc, dstVectorType, value);
1799 }
1800
1801 assert(srcVectorType.getShape().equals(checkShape) &&
1802        "ill-formed createOrFoldBroadcastOp arguments");
1803
1804 // Step 3. Since vector.broadcast only allows creating leading dims,
1805 // vector -> dstShape broadcast may require a transpose.
1806 // Traverse the dims in order and construct:
1807 // 1. The leading entries of the broadcastShape that are guaranteed to be
1808 // achievable by a simple broadcast.
1809 // 2. The induced permutation for the subsequent vector.transpose that will
1810 // bring us from `broadcastShape` back to the desired `dstShape`.
1811 // If the induced permutation is not the identity, create a vector.transpose.
1812 SmallVector<int64_t> broadcastShape, permutation(dstShape.size(), -1);
1813 broadcastShape.reserve(dstShape.size());
1814 // Consider the example:
1815 // srcShape = 2x4
1816 // dstShape = 1x2x3x4x5
1817 // broadcastedDims = [0, 2, 4]
1818 //
1819 // We want to build:
1820 // broadcastShape = 1x3x5x2x4
1821 // (leading broadcast part 1x3x5, trailing src shape part 2x4)
1822 // permutation = [0, 3, 1, 4, 2]
1823 // (for each dst dim, the position it occupies in broadcastShape)
1824 //
1825 // Note that the trailing dims of broadcastShape are exactly the srcShape
1826 // by construction.
1827 // nextSrcShapeDim is used to keep track of where in the permutation the
1828 // "src shape part" occurs.
1829 int64_t nextSrcShapeDim = broadcastedDims.size();
1830 for (int64_t i = 0, e = dstShape.size(); i < e; ++i) {
1831 if (broadcastedDims.contains(i)) {
1832 // 3.a. For each dim in the dst shape, if it is a broadcasted dim,
1833 // bring it to the head of the broadcastShape.
1834 // It will need to be permuted back from `broadcastShape.size() - 1` into
1835 // position `i`.
1836 broadcastShape.push_back(dstShape[i]);
1837 permutation[i] = broadcastShape.size() - 1;
1838 } else {
1839 // 3.b. Otherwise, the dim is not broadcasted, it comes from the src
1840 // shape and needs to be permuted into position `i`.
1841 // Don't touch `broadcastShape` here, the whole srcShape will be
1842 // appended after.
1843 permutation[i] = nextSrcShapeDim++;
1844 }
1845 }
1846 // 3.c. Append the srcShape.
1847 llvm::append_range(broadcastShape, srcVectorType.getShape());
1848
1849 // Ensure there are no dim-1 broadcasts.
1850 assert(::computeBroadcastedUnitDims(srcVectorType.getShape(), broadcastShape)
1851            .empty() &&
1852        "unexpected dim-1 broadcast");
1853
1854 VectorType broadcastType = VectorType::get(broadcastShape, elementType);
1855 assert(vector::isBroadcastableTo(value.getType(), broadcastType) ==
1856            vector::BroadcastableToResult::Success &&
1857        "must be broadcastable");
1858 Value res = b.createOrFold<vector::BroadcastOp>(loc, broadcastType, value);
1859 // Step 4. If we find any dimension that indeed needs to be permuted,
1860 // immediately return a new vector.transpose.
1861 for (int64_t i = 0, e = permutation.size(); i < e; ++i)
1862 if (permutation[i] != i)
1863 return b.createOrFold<vector::TransposeOp>(loc, res, permutation);
1864 // Otherwise return res.
1865 return res;
1866}
1867
1868BroadcastableToResult
1869mlir::vector::isBroadcastableTo(Type srcType, VectorType dstVectorType,
1870 std::pair<int, int> *mismatchingDims) {
1871 // Broadcast scalar to vector of the same element type.
1872 if (srcType.isIntOrIndexOrFloat() && dstVectorType &&
1873 getElementTypeOrSelf(srcType) == getElementTypeOrSelf(dstVectorType))
1874 return BroadcastableToResult::Success;
1875 // From now on, only vectors broadcast.
1876 VectorType srcVectorType = srcType.dyn_cast<VectorType>();
1877 if (!srcVectorType)
1878 return BroadcastableToResult::SourceTypeNotAVector;
1879
1880 int64_t srcRank = srcVectorType.getRank();
1881 int64_t dstRank = dstVectorType.getRank();
1882 if (srcRank > dstRank)
1883 return BroadcastableToResult::SourceRankHigher;
1884 // Source has an exact match or singleton value for all trailing dimensions
1885 // (all leading dimensions are simply duplicated).
1886 int64_t lead = dstRank - srcRank;
1887 for (int64_t r = 0; r < srcRank; ++r) {
1888 int64_t srcDim = srcVectorType.getDimSize(r);
1889 int64_t dstDim = dstVectorType.getDimSize(lead + r);
1890 if (srcDim != 1 && srcDim != dstDim) {
1891 if (mismatchingDims) {
1892 mismatchingDims->first = srcDim;
1893 mismatchingDims->second = dstDim;
1894 }
1895 return BroadcastableToResult::DimensionMismatch;
1896 }
1897 }
1898
1899 return BroadcastableToResult::Success;
1900}
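// For example, vector<4xf32> is broadcastable to vector<2x4xf32> (new leading
// dim) and vector<1x4xf32> to vector<3x4xf32> (dim-1 stretch), whereas
// vector<3xf32> to vector<2x4xf32> yields DimensionMismatch with (3, 4).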
1901
1902LogicalResult BroadcastOp::verify() {
1903 std::pair<int, int> mismatchingDims;
1904 BroadcastableToResult res =
1905 isBroadcastableTo(getSourceType(), getVectorType(), &mismatchingDims);
1906 if (res == BroadcastableToResult::Success)
1907 return success();
1908 if (res == BroadcastableToResult::SourceRankHigher)
1909 return emitOpError("source rank higher than destination rank");
1910 if (res == BroadcastableToResult::DimensionMismatch)
1911 return emitOpError("dimension mismatch (")
1912 << mismatchingDims.first << " vs. " << mismatchingDims.second << ")";
1913 if (res == BroadcastableToResult::SourceTypeNotAVector)
1914 return emitOpError("source type is not a vector");
1915 llvm_unreachable("unexpected vector.broadcast op error");
1916}
1917
1918OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
1919 if (getSourceType() == getVectorType())
1920 return getSource();
1921 if (!operands[0])
1922 return {};
1923 auto vectorType = getVectorType();
1924 if (operands[0].isa<IntegerAttr, FloatAttr>())
1925 return DenseElementsAttr::get(vectorType, operands[0]);
1926 if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
1927 return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
1928 return {};
1929}
1930
1931namespace {
1932
1933// Fold broadcast1(broadcast2(x)) into broadcast1(x).
1934struct BroadcastFolder : public OpRewritePattern<BroadcastOp> {
1935 using OpRewritePattern::OpRewritePattern;
1936
1937 LogicalResult matchAndRewrite(BroadcastOp broadcastOp,
1938 PatternRewriter &rewriter) const override {
1939 auto srcBroadcast = broadcastOp.getSource().getDefiningOp<BroadcastOp>();
1940 if (!srcBroadcast)
1941 return failure();
1942 rewriter.replaceOpWithNewOp<BroadcastOp>(
1943 broadcastOp, broadcastOp.getVectorType(), srcBroadcast.getSource());
1944 return success();
1945 }
1946};
1947} // namespace
1948
1949void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
1950 MLIRContext *context) {
1951 // BroadcastToShapeCast is not a default canonicalization; it is opt-in by
1952 // calling `populateCastAwayVectorLeadingOneDimPatterns`.
1953 results.add<BroadcastFolder>(context);
1954}
1955
1956//===----------------------------------------------------------------------===//
1957// ShuffleOp
1958//===----------------------------------------------------------------------===//
1959
1960void ShuffleOp::build(OpBuilder &builder, OperationState &result, Value v1,
1961 Value v2, ArrayRef<int64_t> mask) {
1962 build(builder, result, v1, v2, getVectorSubscriptAttr(builder, mask));
1963}
1964
1965LogicalResult ShuffleOp::verify() {
1966 VectorType resultType = getVectorType();
1967 VectorType v1Type = getV1VectorType();
1968 VectorType v2Type = getV2VectorType();
1969 // Verify ranks.
1970 int64_t resRank = resultType.getRank();
1971 int64_t v1Rank = v1Type.getRank();
1972 int64_t v2Rank = v2Type.getRank();
1973 bool wellFormed0DCase = v1Rank == 0 && v2Rank == 0 && resRank == 1;
1974 bool wellFormedNDCase = v1Rank == resRank && v2Rank == resRank;
1975 if (!wellFormed0DCase && !wellFormedNDCase)
1976 return emitOpError("rank mismatch");
1977
1978 // Verify all but leading dimension sizes.
1979 for (int64_t r = 1; r < v1Rank; ++r) {
1980 int64_t resDim = resultType.getDimSize(r);
1981 int64_t v1Dim = v1Type.getDimSize(r);
1982 int64_t v2Dim = v2Type.getDimSize(r);
1983 if (resDim != v1Dim || v1Dim != v2Dim)
1984 return emitOpError("dimension mismatch");
1985 }
1986 // Verify mask length.
1987 auto maskAttr = getMask().getValue();
1988 int64_t maskLength = maskAttr.size();
1989 if (maskLength <= 0)
1990 return emitOpError("invalid mask length");
1991 if (maskLength != resultType.getDimSize(0))
1992 return emitOpError("mask length mismatch");
1993 // Verify all indices.
1994 int64_t indexSize = (v1Type.getRank() == 0 ? 1 : v1Type.getDimSize(0)) +
1995 (v2Type.getRank() == 0 ? 1 : v2Type.getDimSize(0));
1996 for (const auto &en : llvm::enumerate(maskAttr)) {
1997 auto attr = en.value().dyn_cast<IntegerAttr>();
1998 if (!attr || attr.getInt() < 0 || attr.getInt() >= indexSize)
1999 return emitOpError("mask index #") << (en.index() + 1) << " out of range";
2000 }
2001 return success();
2002}
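// For example, `vector.shuffle %a, %b [0, 4, 1, 5]` on two vector<4xf32>
// operands verifies: the ranks agree, the mask length 4 matches the leading
// result dim, and every mask index is below indexSize = 4 + 4 = 8.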
2003
2004LogicalResult
2005ShuffleOp::inferReturnTypes(MLIRContext *, std::optional<Location>,
2006 ValueRange operands, DictionaryAttr attributes,
2007 RegionRange,
2008 SmallVectorImpl<Type> &inferredReturnTypes) {
2009 ShuffleOp::Adaptor op(operands, attributes);
2010 auto v1Type = op.getV1().getType().cast<VectorType>();
2011 auto v1Rank = v1Type.getRank();
2012 // Construct resulting type: leading dimension matches mask
2013 // length, all trailing dimensions match the operands.
2014 SmallVector<int64_t, 4> shape;
2015 shape.reserve(v1Rank);
2016 shape.push_back(std::max<size_t>(1, op.getMask().size()));
2017 // In the 0-D case there is no trailing shape to append.
2018 if (v1Rank > 0)
2019 llvm::append_range(shape, v1Type.getShape().drop_front());
2020 inferredReturnTypes.push_back(
2021 VectorType::get(shape, v1Type.getElementType()));
2022 return success();
2023}
2024
2025static bool isStepIndexArray(ArrayAttr idxArr, uint64_t begin, size_t width) {
2026 uint64_t expected = begin;
2027 return idxArr.size() == width &&
2028 llvm::all_of(idxArr.getAsValueRange<IntegerAttr>(),
2029 [&expected](auto attr) {
2030 return attr.getZExtValue() == expected++;
2031 });
2032}
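// For example, a mask of [4, 5] satisfies isStepIndexArray(mask, /*begin=*/4,
// /*width=*/2); this is how the "take all of V2" fold below is detected.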
2033
2034OpFoldResult vector::ShuffleOp::fold(ArrayRef<Attribute> operands) {
2035 VectorType v1Type = getV1VectorType();
2036 // For consistency: the 0-D shuffle return type is 1-D, so this cannot be a
2037 // fold; it must be a canonicalization into a vector.broadcast.
2038 if (v1Type.getRank() == 0)
2039 return {};
2040
2041 // fold shuffle V1, V2, [0, 1, 2, 3] : <4xi32>, <2xi32> -> V1
2042 if (!v1Type.isScalable() &&
2043 isStepIndexArray(getMask(), 0, v1Type.getDimSize(0)))
2044 return getV1();
2045 // fold shuffle V1, V2, [4, 5] : <4xi32>, <2xi32> -> V2
2046 if (!getV1VectorType().isScalable() && !getV2VectorType().isScalable() &&
2047 isStepIndexArray(getMask(), getV1VectorType().getDimSize(0),
2048 getV2VectorType().getDimSize(0)))
2049 return getV2();
2050
2051 Attribute lhs = operands.front(), rhs = operands.back();
2052 if (!lhs || !rhs)
2053 return {};
2054
2055 auto lhsType = lhs.cast<DenseElementsAttr>().getType().cast<VectorType>();
2056 // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
2057 // manipulation.
2058 if (lhsType.getRank() != 1)
2059 return {};
2060 int64_t lhsSize = lhsType.getDimSize(0);
2061
2062 SmallVector<Attribute> results;
2063 auto lhsElements = lhs.cast<DenseElementsAttr>().getValues<Attribute>();
2064 auto rhsElements = rhs.cast<DenseElementsAttr>().getValues<Attribute>();
2065 for (const auto &index : this->getMask().getAsValueRange<IntegerAttr>()) {
2066 int64_t i = index.getZExtValue();
2067 if (i >= lhsSize) {
2068 results.push_back(rhsElements[i - lhsSize]);
2069 } else {
2070 results.push_back(lhsElements[i]);
2071 }
2072 }
2073
2074 return DenseElementsAttr::get(getVectorType(), results);
2075}
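// For example, shuffling the constants [0, 1, 2, 3] and [4, 5] with mask
// [1, 4] folds to the constant [1, 4]: index 1 reads lhs[1] and index 4 reads
// rhs[4 - lhsSize].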
2076
2077namespace {
2078
2079// Pattern to rewrite a 0-D shuffle with [0] or [1] mask returning a 1-D vector
2080// to a broadcast.
2081struct Canonicalize0DShuffleOp : public OpRewritePattern<ShuffleOp> {
2082 using OpRewritePattern::OpRewritePattern;
2083
2084 LogicalResult matchAndRewrite(ShuffleOp shuffleOp,
2085 PatternRewriter &rewriter) const override {
2086 VectorType v1VectorType = shuffleOp.getV1VectorType();
2087 ArrayAttr mask = shuffleOp.getMask();
2088 if (v1VectorType.getRank() > 0)
2089 return failure();
2090 if (mask.size() != 1)
2091 return failure();
2092 Type resType = VectorType::Builder(v1VectorType).setShape({1});
2093 if (mask[0].cast<IntegerAttr>().getInt() == 0)
2094 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2095 shuffleOp.getV1());
2096 else
2097 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(shuffleOp, resType,
2098 shuffleOp.getV2());
2099 return success();
2100 }
2101};
2102
2103/// Pattern to rewrite a ShuffleOp(SplatOp, SplatOp) to SplatOp.
2104class ShuffleSplat final : public OpRewritePattern<ShuffleOp> {
2105public:
2106 using OpRewritePattern::OpRewritePattern;
2107
2108 LogicalResult matchAndRewrite(ShuffleOp op,
2109 PatternRewriter &rewriter) const override {
2110 auto v1Splat = op.getV1().getDefiningOp<SplatOp>();
2111 auto v2Splat = op.getV2().getDefiningOp<SplatOp>();
2112
2113 if (!v1Splat || !v2Splat)
2114 return failure();
2115
2116 if (v1Splat.getInput() != v2Splat.getInput())
2117 return failure();
2118
2119 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), v1Splat.getInput());
2120 return success();
2121 }
2122};
2123
2124} // namespace
2125
2126void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results,
2127 MLIRContext *context) {
2128 results.add<ShuffleSplat, Canonicalize0DShuffleOp>(context);
2129}
2130
2131//===----------------------------------------------------------------------===//
2132// InsertElementOp
2133//===----------------------------------------------------------------------===//
2134
2135void InsertElementOp::build(OpBuilder &builder, OperationState &result,
2136 Value source, Value dest) {
2137 build(builder, result, source, dest, {});
2138}
2139
2140LogicalResult InsertElementOp::verify() {
2141 auto dstVectorType = getDestVectorType();
2142 if (dstVectorType.getRank() == 0) {
2143 if (getPosition())
2144 return emitOpError("expected position to be empty with 0-D vector");
2145 return success();
2146 }
2147 if (dstVectorType.getRank() != 1)
2148 return emitOpError("unexpected >1 vector rank");
2149 if (!getPosition())
2150 return emitOpError("expected position for 1-D vector");
2151 return success();
2152}
2153
2154OpFoldResult vector::InsertElementOp::fold(ArrayRef<Attribute> operands) {
2155 // Skip the 0-D vector here.
2156 if (operands.size() < 3)
2157 return {};
2158
2159 Attribute src = operands[0];
2160 Attribute dst = operands[1];
2161 Attribute pos = operands[2];
2162 if (!src || !dst || !pos)
2163 return {};
2164
2165 auto dstElements = dst.cast<DenseElementsAttr>().getValues<Attribute>();
2166
2167 SmallVector<Attribute> results(dstElements);
2168
2169 auto attr = pos.dyn_cast<IntegerAttr>();
2170 uint64_t posIdx = attr.getInt();
2171
2172 results[posIdx] = src;
2173
2174 return DenseElementsAttr::get(getDestVectorType(), results);
2175}
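// For example, inserting the constant 7 into the constant vector [0, 0, 0, 0]
// at position 2 folds to the constant vector [0, 0, 7, 0].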
2176
2177//===----------------------------------------------------------------------===//
2178// InsertOp
2179//===----------------------------------------------------------------------===//
2180
2181void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
2182 Value dest, ArrayRef<int64_t> position) {
2183 result.addOperands({source, dest});
2184 auto positionAttr = getVectorSubscriptAttr(builder, position);
2185 result.addTypes(dest.getType());
2186 result.addAttribute(getPositionAttrStrName(), positionAttr);
2187}
2188
2189// Convenience builder which assumes the values are constant indices.
2190void InsertOp::build(OpBuilder &builder, OperationState &result, Value source,
2191 Value dest, ValueRange position) {
2192 SmallVector<int64_t, 4> positionConstants =
2193 llvm::to_vector<4>(llvm::map_range(position, [](Value pos) {
2194 return pos.getDefiningOp<arith::ConstantIndexOp>().value();
2195 }));
2196 build(builder, result, source, dest, positionConstants);
2197}
2198
2199LogicalResult InsertOp::verify() {
2200 auto positionAttr = getPosition().getValue();
2201 auto destVectorType = getDestVectorType();
2202 if (positionAttr.size() > static_cast<unsigned>(destVectorType.getRank()))
2203 return emitOpError(
2204 "expected position attribute of rank smaller than dest vector rank");
2205 auto srcVectorType = getSourceType().dyn_cast<VectorType>();
2206 if (srcVectorType &&
2207 (static_cast<unsigned>(srcVectorType.getRank()) + positionAttr.size() !=
2208 static_cast<unsigned>(destVectorType.getRank())))
2209 return emitOpError("expected position attribute rank + source rank to "
2210 "match dest vector rank");
2211 if (!srcVectorType &&
2212 (positionAttr.size() != static_cast<unsigned>(destVectorType.getRank())))
2213 return emitOpError(
2214 "expected position attribute rank to match the dest vector rank");
2215 for (const auto &en : llvm::enumerate(positionAttr)) {
2216 auto attr = en.value().dyn_cast<IntegerAttr>();
2217 if (!attr || attr.getInt() < 0 ||
2218 attr.getInt() >= destVectorType.getDimSize(en.index()))
2219 return emitOpError("expected position attribute #")
2220 << (en.index() + 1)
2221 << " to be a non-negative integer smaller than the corresponding "
2222 "dest vector dimension";
2223 }
2224 return success();
2225}
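// For example, inserting a vector<4x5xf32> into a vector<2x3x4x5xf32> needs a
// position of rank 4 - 2 = 2, say [1, 2], with each entry smaller than the
// corresponding leading dest dim (here 2 and 3).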
2226
2227namespace {
2228
2229 // If insertOp is only inserting unit dimensions, it can be transformed to a
2230// broadcast.
2231class InsertToBroadcast final : public OpRewritePattern<InsertOp> {
2232public:
2233 using OpRewritePattern::OpRewritePattern;
2234
2235 LogicalResult matchAndRewrite(InsertOp insertOp,
2236 PatternRewriter &rewriter) const override {
2237 auto srcVecType = insertOp.getSourceType().dyn_cast<VectorType>();
2238 if (!srcVecType || insertOp.getDestVectorType().getNumElements() !=
2239 srcVecType.getNumElements())
2240 return failure();
2241 rewriter.replaceOpWithNewOp<BroadcastOp>(
2242 insertOp, insertOp.getDestVectorType(), insertOp.getSource());
2243 return success();
2244 }
2245};
2246
2247 /// Pattern to rewrite an InsertOp(SplatOp, SplatOp) to SplatOp.
2248class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
2249public:
2250 using OpRewritePattern::OpRewritePattern;
2251
2252 LogicalResult matchAndRewrite(InsertOp op,
2253 PatternRewriter &rewriter) const override {
2254 auto srcSplat = op.getSource().getDefiningOp<SplatOp>();
2255 auto dstSplat = op.getDest().getDefiningOp<SplatOp>();
2256
2257 if (!srcSplat || !dstSplat)
2258 return failure();
2259
2260 if (srcSplat.getInput() != dstSplat.getInput())
2261 return failure();
2262
2263 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), srcSplat.getInput());
2264 return success();
2265 }
2266};
2267
2268 // Pattern to rewrite an InsertOp(ConstantOp into ConstantOp) -> ConstantOp.
2269class InsertOpConstantFolder final : public OpRewritePattern<InsertOp> {
2270public:
2271 using OpRewritePattern::OpRewritePattern;
2272
2273 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2274 // unless the source vector constant has a single use.
2275 static constexpr int64_t vectorSizeFoldThreshold = 256;
2276
2277 LogicalResult matchAndRewrite(InsertOp op,
2278 PatternRewriter &rewriter) const override {
2279 // Return if 'InsertOp' operand is not defined by a compatible vector
2280 // ConstantOp.
2281 TypedValue<VectorType> destVector = op.getDest();
2282 Attribute vectorDestCst;
2283 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2284 return failure();
2285
2286 VectorType destTy = destVector.getType();
2287 if (destTy.isScalable())
2288 return failure();
2289
2290 // Make sure we do not create too many large constants.
2291 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2292 !destVector.hasOneUse())
2293 return failure();
2294
2295 auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
2296
2297 Value sourceValue = op.getSource();
2298 Attribute sourceCst;
2299 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2300 return failure();
2301
2302 // Calculate the linearized position of the contiguous chunk of elements to
2303 // insert.
2304 llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
2305 copy(getI64SubArray(op.getPosition()), completePositions.begin());
2306 int64_t insertBeginPosition =
2307 linearize(completePositions, computeStrides(destTy.getShape()));
2308
2309 SmallVector<Attribute> insertedValues;
2310 if (auto denseSource = sourceCst.dyn_cast<DenseElementsAttr>())
2311 llvm::append_range(insertedValues, denseSource.getValues<Attribute>());
2312 else
2313 insertedValues.push_back(sourceCst);
2314
2315 auto allValues = llvm::to_vector(denseDest.getValues<Attribute>());
2316 copy(insertedValues, allValues.begin() + insertBeginPosition);
2317 auto newAttr = DenseElementsAttr::get(destTy, allValues);
2318
2319 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2320 return success();
2321 }
2322};
2323
2324} // namespace
2325
2326void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
2327 MLIRContext *context) {
2328 results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
2329 InsertOpConstantFolder>(context);
2330}
2331
2332// Eliminates insert operations that produce values identical to their source
2333// value. This happens when the source and destination vectors have identical
2334// sizes.
2335OpFoldResult vector::InsertOp::fold(ArrayRef<Attribute> operands) {
2336 if (getPosition().empty())
2337 return getSource();
2338 return {};
2339}
2340
2341//===----------------------------------------------------------------------===//
2342// InsertStridedSliceOp
2343//===----------------------------------------------------------------------===//
2344
2345void InsertStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2346 Value source, Value dest,
2347 ArrayRef<int64_t> offsets,
2348 ArrayRef<int64_t> strides) {
2349 result.addOperands({source, dest});
2350 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2351 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2352 result.addTypes(dest.getType());
2353 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2354 result.addAttribute(getStridesAttrStrName(), stridesAttr);
2355}
2356
2357// TODO: Should be moved to Tablegen ConfinedAttr attributes.
2358template <typename OpType>
2359static LogicalResult isIntegerArrayAttrSmallerThanShape(OpType op,
2360 ArrayAttr arrayAttr,
2361 ArrayRef<int64_t> shape,
2362 StringRef attrName) {
2363 if (arrayAttr.size() > shape.size())
2364 return op.emitOpError("expected ")
2365 << attrName << " attribute of rank smaller than vector rank";
2366 return success();
2367}
2368
2369 // Returns success if all integers in `arrayAttr` are in the admissible
2370 // interval. If `halfOpen` is true then the admissible interval is [min, max).
2371 // Otherwise, the admissible interval is [min, max].
2372template <typename OpType>
2373static LogicalResult
2374isIntegerArrayAttrConfinedToRange(OpType op, ArrayAttr arrayAttr, int64_t min,
2375 int64_t max, StringRef attrName,
2376 bool halfOpen = true) {
2377 for (auto attr : arrayAttr) {
2378 auto val = attr.cast<IntegerAttr>().getInt();
2379 auto upper = max;
2380 if (!halfOpen)
2381 upper += 1;
2382 if (val < min || val >= upper)
2383 return op.emitOpError("expected ") << attrName << " to be confined to ["
2384 << min << ", " << upper << ")";
2385 }
2386 return success();
2387}
2388
2389 // Returns success if all integers in `arrayAttr` are confined to the given
2390 // shape. If `halfOpen` is true then the admissible interval for dimension d is
2391 // [min, shape[d]). Otherwise, it is [min, shape[d]].
2392template <typename OpType>
2393static LogicalResult
2394isIntegerArrayAttrConfinedToShape(OpType op, ArrayAttr arrayAttr,
2395 ArrayRef<int64_t> shape, StringRef attrName,
2396 bool halfOpen = true, int64_t min = 0) {
2397 for (auto [index, attrDimPair] :
2398 llvm::enumerate(llvm::zip_first(arrayAttr, shape))) {
2399 int64_t val =
2400 std::get<0>(attrDimPair).template cast<IntegerAttr>().getInt();
2401 int64_t max = std::get<1>(attrDimPair);
2402 if (!halfOpen)
2403 max += 1;
2404 if (val < min || val >= max)
2405 return op.emitOpError("expected ")
2406 << attrName << " dimension " << index << " to be confined to ["
2407 << min << ", " << max << ")";
2408 }
2409 return success();
2410}
2411
2412 // Returns success if, for each dimension, the sum of the integers in
2413 // `arrayAttr1` and `arrayAttr2` is confined to the shape. If `halfOpen` is
2414 // true then the admissible interval is [min, max). Otherwise it is [min, max].
2415template <typename OpType>
2416static LogicalResult isSumOfIntegerArrayAttrConfinedToShape(
2417 OpType op, ArrayAttr arrayAttr1, ArrayAttr arrayAttr2,
2418 ArrayRef<int64_t> shape, StringRef attrName1, StringRef attrName2,
2419 bool halfOpen = true, int64_t min = 1) {
2420 assert(arrayAttr1.size() <= shape.size());
2421 assert(arrayAttr2.size() <= shape.size());
2422 for (auto [index, it] :
2423 llvm::enumerate(llvm::zip(arrayAttr1, arrayAttr2, shape))) {
2424 auto val1 = std::get<0>(it).template cast<IntegerAttr>().getInt();
2425 auto val2 = std::get<1>(it).template cast<IntegerAttr>().getInt();
2426 int64_t max = std::get<2>(it);
2427 if (!halfOpen)
2428 max += 1;
2429 if (val1 + val2 < 0 || val1 + val2 >= max)
2430 return op.emitOpError("expected sum(")
2431 << attrName1 << ", " << attrName2 << ") dimension " << index
2432 << " to be confined to [" << min << ", " << max << ")";
2433 }
2434 return success();
2435}
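// For example, inserting a vector<4xf32> into a vector<8xf32> at offset 6
// fails this check: the sum 6 + 4 = 10 exceeds the destination dimension, so
// InsertStridedSliceOp::verify below rejects the op.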
2436
2437static ArrayAttr makeI64ArrayAttr(ArrayRef<int64_t> values,
2438 MLIRContext *context) {
2439 auto attrs = llvm::map_range(values, [context](int64_t v) -> Attribute {
2440 return IntegerAttr::get(IntegerType::get(context, 64), APInt(64, v));
2441 });
2442 return ArrayAttr::get(context, llvm::to_vector<8>(attrs));
2443}
2444
2445LogicalResult InsertStridedSliceOp::verify() {
2446 auto sourceVectorType = getSourceVectorType();
2447 auto destVectorType = getDestVectorType();
2448 auto offsets = getOffsetsAttr();
2449 auto strides = getStridesAttr();
2450 if (offsets.size() != static_cast<unsigned>(destVectorType.getRank()))
2451 return emitOpError(
2452 "expected offsets of same size as destination vector rank");
2453 if (strides.size() != static_cast<unsigned>(sourceVectorType.getRank()))
2454 return emitOpError("expected strides of same size as source vector rank");
2455 if (sourceVectorType.getRank() > destVectorType.getRank())
2456 return emitOpError(
2457 "expected source rank to be smaller than destination rank");
2458
2459 auto sourceShape = sourceVectorType.getShape();
2460 auto destShape = destVectorType.getShape();
2461 SmallVector<int64_t, 4> sourceShapeAsDestShape(
2462 destShape.size() - sourceShape.size(), 0);
2463 sourceShapeAsDestShape.append(sourceShape.begin(), sourceShape.end());
2464 auto offName = InsertStridedSliceOp::getOffsetsAttrName();
2465 auto stridesName = InsertStridedSliceOp::getStridesAttrName();
2466 if (failed(isIntegerArrayAttrConfinedToShape(*this, offsets, destShape,
2467 offName)) ||
2468 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2469 stridesName,
2470 /*halfOpen=*/false)) ||
2471 failed(isSumOfIntegerArrayAttrConfinedToShape(
2472 *this, offsets,
2473 makeI64ArrayAttr(sourceShapeAsDestShape, getContext()), destShape,
2474 offName, "source vector shape",
2475 /*halfOpen=*/false, /*min=*/1)))
2476 return failure();
2477
2478 return success();
2479}
2480
2481namespace {
2482/// Pattern to rewrite an InsertStridedSliceOp(SplatOp(X):src_type,
2483/// SplatOp(X):dst_type) to SplatOp(X):dst_type.
2484class FoldInsertStridedSliceSplat final
2485 : public OpRewritePattern<InsertStridedSliceOp> {
2486public:
2487 using OpRewritePattern::OpRewritePattern;
2488
2489 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2490 PatternRewriter &rewriter) const override {
2491 auto srcSplatOp =
2492 insertStridedSliceOp.getSource().getDefiningOp<vector::SplatOp>();
2493 auto destSplatOp =
2494 insertStridedSliceOp.getDest().getDefiningOp<vector::SplatOp>();
2495
2496 if (!srcSplatOp || !destSplatOp)
2497 return failure();
2498
2499 if (srcSplatOp.getInput() != destSplatOp.getInput())
2500 return failure();
2501
2502 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2503 return success();
2504 }
2505};
2506
2507/// Pattern to rewrite an InsertStridedSliceOp(ExtractStridedSliceOp(dst), dst)
2508/// to dst.
2509class FoldInsertStridedSliceOfExtract final
2510 : public OpRewritePattern<InsertStridedSliceOp> {
2511public:
2512 using OpRewritePattern::OpRewritePattern;
2513
2514 LogicalResult matchAndRewrite(InsertStridedSliceOp insertStridedSliceOp,
2515 PatternRewriter &rewriter) const override {
2516 auto extractStridedSliceOp =
2517 insertStridedSliceOp.getSource()
2518 .getDefiningOp<vector::ExtractStridedSliceOp>();
2519
2520 if (!extractStridedSliceOp)
2521 return failure();
2522
2523 if (extractStridedSliceOp.getOperand() != insertStridedSliceOp.getDest())
2524 return failure();
2525
2526 // Check if they have the same strides and offsets.
2527 if (extractStridedSliceOp.getStrides() !=
2528 insertStridedSliceOp.getStrides() ||
2529 extractStridedSliceOp.getOffsets() != insertStridedSliceOp.getOffsets())
2530 return failure();
2531
2532 rewriter.replaceOp(insertStridedSliceOp, insertStridedSliceOp.getDest());
2533 return success();
2534 }
2535};
2536
2537// Pattern to rewrite an InsertStridedSliceOp(ConstantOp into ConstantOp) ->
2538// ConstantOp.
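// Illustrative example (editor's sketch, hypothetical values):
// ```
// %0 = arith.constant dense<7> : vector<2xi32>
// %1 = arith.constant dense<0> : vector<4xi32>
// %2 = vector.insert_strided_slice %0, %1 {offsets = [1], strides = [1]}
//     : vector<2xi32> into vector<4xi32>
// ```
// rewrites to `arith.constant dense<[0, 7, 7, 0]> : vector<4xi32>`.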
2539class InsertStridedSliceConstantFolder final
2540 : public OpRewritePattern<InsertStridedSliceOp> {
2541public:
2542 using OpRewritePattern::OpRewritePattern;
2543
2544 // Do not create constants with more than `vectorSizeFoldThreshold` elements,
2545 // unless the destination vector constant has a single use.
2546 static constexpr int64_t vectorSizeFoldThreshold = 256;
2547
2548 LogicalResult matchAndRewrite(InsertStridedSliceOp op,
2549 PatternRewriter &rewriter) const override {
2550 // Return if the 'InsertStridedSliceOp' dest operand is not defined by a
2551 // compatible vector ConstantOp.
2552 TypedValue<VectorType> destVector = op.getDest();
2553 Attribute vectorDestCst;
2554 if (!matchPattern(destVector, m_Constant(&vectorDestCst)))
2555 return failure();
2556
2557 VectorType destTy = destVector.getType();
2558 if (destTy.isScalable())
2559 return failure();
2560
2561 // Make sure we do not create too many large constants.
2562 if (destTy.getNumElements() > vectorSizeFoldThreshold &&
2563 !destVector.hasOneUse())
2564 return failure();
2565
2566 auto denseDest = vectorDestCst.cast<DenseElementsAttr>();
2567
2568 TypedValue<VectorType> sourceValue = op.getSource();
2569 Attribute sourceCst;
2570 if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
2571 return failure();
2572
2573 // TODO: Handle non-unit strides when they become available.
2574 if (op.hasNonUnitStrides())
2575 return failure();
2576
2577 VectorType sliceVecTy = sourceValue.getType();
2578 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
2579 int64_t rankDifference = destTy.getRank() - sliceVecTy.getRank();
2580 SmallVector<int64_t, 4> offsets = getI64SubArray(op.getOffsets());
2581 SmallVector<int64_t, 4> destStrides = computeStrides(destTy.getShape());
2582
2583 // Calculate the destination element indices by enumerating all slice
2584 // positions within the destination and linearizing them. The enumeration
2585 // order is lexicographic which yields a sequence of monotonically
2586 // increasing linearized position indices.
2587 // Because the destination may have higher dimensionality than the slice,
2588 // we keep track of two overlapping sets of positions and offsets.
2589 auto denseSlice = sourceCst.cast<DenseElementsAttr>();
2590 auto sliceValuesIt = denseSlice.value_begin<Attribute>();
2591 auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
2592 SmallVector<int64_t> currDestPosition(offsets.begin(), offsets.end());
2593 MutableArrayRef<int64_t> currSlicePosition(
2594 currDestPosition.begin() + rankDifference, currDestPosition.end());
2595 ArrayRef<int64_t> sliceOffsets(offsets.begin() + rankDifference,
2596 offsets.end());
2597 do {
2598 int64_t linearizedPosition = linearize(currDestPosition, destStrides);
2599 assert(linearizedPosition < destTy.getNumElements() && "Invalid index");
2600 assert(sliceValuesIt != denseSlice.value_end<Attribute>() &&
2601 "Invalid slice element");
2602 newValues[linearizedPosition] = *sliceValuesIt;
2603 ++sliceValuesIt;
2604 } while (succeeded(
2605 incSlicePosition(currSlicePosition, sliceShape, sliceOffsets)));
2606
2607 auto newAttr = DenseElementsAttr::get(destTy, newValues);
2608 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
2609 return success();
2610 }
2611};
2612
2613} // namespace
2614
2615void vector::InsertStridedSliceOp::getCanonicalizationPatterns(
2616 RewritePatternSet &results, MLIRContext *context) {
2617 results.add<FoldInsertStridedSliceSplat, FoldInsertStridedSliceOfExtract,
2618 InsertStridedSliceConstantFolder>(context);
2619}
2620
2621OpFoldResult InsertStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2622 if (getSourceVectorType() == getDestVectorType())
2623 return getSource();
2624 return {};
2625}
2626
2627//===----------------------------------------------------------------------===//
2628// OuterProductOp
2629//===----------------------------------------------------------------------===//
2630
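// The custom assembly format handled by print/parse below looks like this
// (editor's illustrative sketch, hypothetical values):
// ```
// %2 = vector.outerproduct %0, %1, %acc {kind = #vector.kind<add>}
//     : vector<4xf32>, vector<8xf32>
// ```
// yielding a vector<4x8xf32> result; when %1 has a scalar type the op is the
// 1-D AXPY variant instead.
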
2631/// Build an op without mask, use the type of `acc` as the return type.
2632void OuterProductOp::build(OpBuilder &builder, OperationState &result,
2633 Value lhs, Value rhs, Value acc) {
2634 result.addOperands({lhs, rhs, acc});
2635 result.addTypes(acc.getType());
2636}
2637
2638void OuterProductOp::print(OpAsmPrinter &p) {
2639 p << " " << getLhs() << ", " << getRhs();
2640 if (!getAcc().empty()) {
2641 p << ", " << getAcc();
2642 p.printOptionalAttrDict((*this)->getAttrs());
2643 }
2644 p << " : " << getLhs().getType() << ", " << getRhs().getType();
2645}
2646
2647ParseResult OuterProductOp::parse(OpAsmParser &parser, OperationState &result) {
2648 SmallVector<OpAsmParser::UnresolvedOperand, 3> operandsInfo;
2649 Type tLHS, tRHS;
2650 if (parser.parseOperandList(operandsInfo) ||
2651 parser.parseOptionalAttrDict(result.attributes) ||
2652 parser.parseColonType(tLHS) || parser.parseComma() ||
2653 parser.parseType(tRHS))
2654 return failure();
2655 if (operandsInfo.size() < 2)
2656 return parser.emitError(parser.getNameLoc(),
2657 "expected at least 2 operands");
2658 VectorType vLHS = tLHS.dyn_cast<VectorType>();
2659 VectorType vRHS = tRHS.dyn_cast<VectorType>();
2660 if (!vLHS)
2661 return parser.emitError(parser.getNameLoc(),
2662 "expected vector type for operand #1");
2663 VectorType resType =
2664 vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
2665 vLHS.getElementType())
2666 : VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
2667
2668 if (!result.attributes.get(OuterProductOp::getKindAttrStrName())) {
2669 result.attributes.append(
2670 OuterProductOp::getKindAttrStrName(),
2671 CombiningKindAttr::get(result.getContext(),
2672 OuterProductOp::getDefaultKind()));
2673 }
2674
2675 return failure(
2676 parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
2677 parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
2678 (operandsInfo.size() > 2 &&
2679 parser.resolveOperand(operandsInfo[2], resType, result.operands)) ||
2680 parser.addTypeToList(resType, result.types));
2681}
2682
2683LogicalResult OuterProductOp::verify() {
2684 Type tRHS = getOperandTypeRHS();
2685 VectorType vLHS = getOperandVectorTypeLHS(),
2686 vRHS = tRHS.dyn_cast<VectorType>(),
2687 vACC = getOperandVectorTypeACC(), vRES = getVectorType();
2688
2689 if (vLHS.getRank() != 1)
2690 return emitOpError("expected 1-d vector for operand #1");
2691
2692 if (vRHS) {
2693 // Proper OUTER operation.
2694 if (vRHS.getRank() != 1)
2695 return emitOpError("expected 1-d vector for operand #2");
2696 if (vRES.getRank() != 2)
2697 return emitOpError("expected 2-d vector result");
2698 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2699 return emitOpError("expected #1 operand dim to match result dim #1");
2700 if (vRHS.getDimSize(0) != vRES.getDimSize(1))
2701 return emitOpError("expected #2 operand dim to match result dim #2");
2702 } else {
2703 // An AXPY operation.
2704 if (vRES.getRank() != 1)
2705 return emitOpError("expected 1-d vector result");
2706 if (vLHS.getDimSize(0) != vRES.getDimSize(0))
2707 return emitOpError("expected #1 operand dim to match result dim #1");
2708 }
2709
2710 if (vACC && vACC != vRES)
2711 return emitOpError("expected operand #3 of same type as result type");
2712
2713 // Verify supported combining kind.
2714 if (!isSupportedCombiningKind(getKind(), vRES.getElementType()))
2715 return emitOpError("unsupported outerproduct type");
2716
2717 return success();
2718}
2719
2720//===----------------------------------------------------------------------===//
2721// ReshapeOp
2722//===----------------------------------------------------------------------===//
2723
2724LogicalResult ReshapeOp::verify() {
2725 // Verify that rank(numInputs/outputs) + numFixedVec dim matches vec rank.
2726 auto inputVectorType = getInputVectorType();
2727 auto outputVectorType = getOutputVectorType();
2728 int64_t inputShapeRank = getNumInputShapeSizes();
2729 int64_t outputShapeRank = getNumOutputShapeSizes();
2730 SmallVector<int64_t, 4> fixedVectorSizes;
2731 getFixedVectorSizes(fixedVectorSizes);
2732 int64_t numFixedVectorSizes = fixedVectorSizes.size();
2733
2734 if (inputVectorType.getRank() != inputShapeRank + numFixedVectorSizes)
2735 return emitError("invalid input shape for vector type ") << inputVectorType;
2736
2737 if (outputVectorType.getRank() != outputShapeRank + numFixedVectorSizes)
2738 return emitError("invalid output shape for vector type ")
2739 << outputVectorType;
2740
2741 // Verify that the 'fixedVectorSizes' match an input/output vector shape
2742 // suffix.
2743 unsigned inputVectorRank = inputVectorType.getRank();
2744 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2745 unsigned index = inputVectorRank - numFixedVectorSizes - i;
2746 if (fixedVectorSizes[i] != inputVectorType.getShape()[index])
2747 return emitError("fixed vector size must match input vector for dim ")
2748 << i;
2749 }
2750
2751 unsigned outputVectorRank = outputVectorType.getRank();
2752 for (unsigned i = 0; i < numFixedVectorSizes; ++i) {
2753 unsigned index = outputVectorRank - numFixedVectorSizes - i;
2754 if (fixedVectorSizes[i] != outputVectorType.getShape()[index])
2755 return emitError("fixed vector size must match output vector for dim ")
2756 << i;
2757 }
2758
2759 // If all shape operands are produced by constant ops, verify that product
2760 // of dimensions for input/output shape match.
2761 auto isDefByConstant = [](Value operand) {
2762 return isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
2763 };
2764 if (llvm::all_of(getInputShape(), isDefByConstant) &&
2765 llvm::all_of(getOutputShape(), isDefByConstant)) {
2766 int64_t numInputElements = 1;
2767 for (auto operand : getInputShape())
2768 numInputElements *=
2769 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2770 int64_t numOutputElements = 1;
2771 for (auto operand : getOutputShape())
2772 numOutputElements *=
2773 cast<arith::ConstantIndexOp>(operand.getDefiningOp()).value();
2774 if (numInputElements != numOutputElements)
2775 return emitError("product of input and output shape sizes must match");
2776 }
2777 return success();
2778}
2779
2780void ReshapeOp::getFixedVectorSizes(SmallVectorImpl<int64_t> &results) {
2781 populateFromInt64AttrArray(getFixedVectorSizes(), results);
2782}
2783
2784//===----------------------------------------------------------------------===//
2785// ExtractStridedSliceOp
2786//===----------------------------------------------------------------------===//
2787
2788// Inference works as follows:
2789// 1. Add 'sizes' from prefix of dims in 'offsets'.
2790// 2. Add sizes from 'vectorType' for remaining dims.
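// For example (editor's sketch): with vectorType = vector<4x8x16xf32> and
// sizes = [2, 4] (offsets/strides of matching rank), the inferred result type
// is vector<2x4x16xf32>: dims 0 and 1 come from 'sizes', the trailing dim
// from 'vectorType'.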
2791static Type inferStridedSliceOpResultType(VectorType vectorType,
2792 ArrayAttr offsets, ArrayAttr sizes,
2793 ArrayAttr strides) {
2794 assert(offsets.size() == sizes.size() && offsets.size() == strides.size());
2795 SmallVector<int64_t, 4> shape;
2796 shape.reserve(vectorType.getRank());
2797 unsigned idx = 0;
2798 for (unsigned e = offsets.size(); idx < e; ++idx)
2799 shape.push_back(sizes[idx].cast<IntegerAttr>().getInt());
2800 for (unsigned e = vectorType.getShape().size(); idx < e; ++idx)
2801 shape.push_back(vectorType.getShape()[idx]);
2802
2803 return VectorType::get(shape, vectorType.getElementType());
2804}
2805
2806void ExtractStridedSliceOp::build(OpBuilder &builder, OperationState &result,
2807 Value source, ArrayRef<int64_t> offsets,
2808 ArrayRef<int64_t> sizes,
2809 ArrayRef<int64_t> strides) {
2810 result.addOperands(source);
2811 auto offsetsAttr = getVectorSubscriptAttr(builder, offsets);
2812 auto sizesAttr = getVectorSubscriptAttr(builder, sizes);
2813 auto stridesAttr = getVectorSubscriptAttr(builder, strides);
2814 result.addTypes(
2815 inferStridedSliceOpResultType(source.getType().cast<VectorType>(),
2816 offsetsAttr, sizesAttr, stridesAttr));
2817 result.addAttribute(getOffsetsAttrStrName(), offsetsAttr);
2818 result.addAttribute(getSizesAttrStrName(), sizesAttr);
2819 result.addAttribute(getStridesAttrStrName(), stridesAttr);
2820}
2821
2822LogicalResult ExtractStridedSliceOp::verify() {
2823 auto type = getVectorType();
2824 auto offsets = getOffsetsAttr();
2825 auto sizes = getSizesAttr();
2826 auto strides = getStridesAttr();
2827 if (offsets.size() != sizes.size() || offsets.size() != strides.size())
2828 return emitOpError(
2829 "expected offsets, sizes and strides attributes of same size");
2830
2831 auto shape = type.getShape();
2832 auto offName = getOffsetsAttrName();
2833 auto sizesName = getSizesAttrName();
2834 auto stridesName = getStridesAttrName();
2835 if (failed(
2836 isIntegerArrayAttrSmallerThanShape(*this, offsets, shape, offName)) ||
2837 failed(
2838 isIntegerArrayAttrSmallerThanShape(*this, sizes, shape, sizesName)) ||
2839 failed(isIntegerArrayAttrSmallerThanShape(*this, strides, shape,
2840 stridesName)) ||
2841 failed(
2842 isIntegerArrayAttrConfinedToShape(*this, offsets, shape, offName)) ||
2843 failed(isIntegerArrayAttrConfinedToShape(*this, sizes, shape, sizesName,
2844 /*halfOpen=*/false,
2845 /*min=*/1)) ||
2846 failed(isIntegerArrayAttrConfinedToRange(*this, strides, 1, 1,
2847 stridesName,
2848 /*halfOpen=*/false)) ||
2849 failed(isSumOfIntegerArrayAttrConfinedToShape(*this, offsets, sizes,
2850 shape, offName, sizesName,
2851 /*halfOpen=*/false)))
2852 return failure();
2853
2854 auto resultType =
2855 inferStridedSliceOpResultType(getVectorType(), offsets, sizes, strides);
2856 if (getResult().getType() != resultType)
2857 return emitOpError("expected result type to be ") << resultType;
2858
2859 return success();
2860}
2861
2862// When the source of an ExtractStridedSliceOp comes from a chain of
2863// InsertStridedSliceOps, try to use the source of one of those inserts if we
2864// can detect that the extracted vector is a subset of an inserted vector.
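// Illustrative example (editor's sketch, hypothetical values):
// ```
// %0 = vector.insert_strided_slice %v, %dst {offsets = [4], strides = [1]}
//     : vector<4xf32> into vector<16xf32>
// %1 = vector.extract_strided_slice %0 {offsets = [4], sizes = [4],
//     strides = [1]} : vector<16xf32> to vector<4xf32>
// ```
// The extracted chunk lies entirely inside the inserted one, so %1 is
// rewritten to read from %v with offsets [0].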
2865static LogicalResult
2866foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
2867 // Helper to extract integer out of ArrayAttr.
2868 auto getElement = [](ArrayAttr array, int idx) {
2869 return array[idx].cast<IntegerAttr>().getInt();
2870 };
2871 ArrayAttr extractOffsets = op.getOffsets();
2872 ArrayAttr extractStrides = op.getStrides();
2873 ArrayAttr extractSizes = op.getSizes();
2874 auto insertOp = op.getVector().getDefiningOp<InsertStridedSliceOp>();
2875 while (insertOp) {
2876 if (op.getVectorType().getRank() !=
2877 insertOp.getSourceVectorType().getRank())
2878 return failure();
2879 ArrayAttr insertOffsets = insertOp.getOffsets();
2880 ArrayAttr insertStrides = insertOp.getStrides();
2881 // If the rank of extract is greater than the rank of insert, we are likely
2882 // extracting a partial chunk of the vector inserted.
2883 if (extractOffsets.size() > insertOffsets.size())
2884 return failure();
2885 bool partialOverlap = false;
2886 bool disjoint = false;
2887 SmallVector<int64_t, 4> offsetDiffs;
2888 for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
2889 if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
2890 return failure();
2891 int64_t start = getElement(insertOffsets, dim);
2892 int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
2893 int64_t offset = getElement(extractOffsets, dim);
2894 int64_t size = getElement(extractSizes, dim);
2895 // Check if the start of the extract offset is in the interval inserted.
2896 if (start <= offset && offset < end) {
2897 // If the extract interval overlaps but is not fully included we may
2898 // have a partial overlap that will prevent any folding.
2899 if (offset + size > end)
2900 partialOverlap = true;
2901 offsetDiffs.push_back(offset - start);
2902 continue;
2903 }
2904 disjoint = true;
2905 break;
2906 }
2907 // The extracted chunk is a subset of the inserted chunk.
2908 if (!disjoint && !partialOverlap) {
2909 op.setOperand(insertOp.getSource());
2910 // OpBuilder is only used as a helper to build an I64ArrayAttr.
2911 OpBuilder b(op.getContext());
2912 op->setAttr(ExtractStridedSliceOp::getOffsetsAttrStrName(),
2913 b.getI64ArrayAttr(offsetDiffs));
2914 return success();
2915 }
2916 // If the chunk extracted is disjoint from the chunk inserted, keep looking
2917 // in the insert chain.
2918 if (disjoint)
2919 insertOp = insertOp.getDest().getDefiningOp<InsertStridedSliceOp>();
2920 else {
2921 // The extracted vector partially overlaps the inserted vector; we
2922 // cannot fold.
2923 return failure();
2924 }
2925 }
2926 return failure();
2927}
2928
2929OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
2930 if (getVectorType() == getResult().getType())
2931 return getVector();
2932 if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
2933 return getResult();
2934 return {};
2935}
2936
2937void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
2938 populateFromInt64AttrArray(getOffsets(), results);
2939}
2940
2941namespace {
2942
2943// Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) to
2944// ConstantMaskOp.
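// Illustrative example (editor's sketch, hypothetical values):
// ```
// %0 = vector.constant_mask [4] : vector<8xi1>
// %1 = vector.extract_strided_slice %0 {offsets = [2], sizes = [4],
//     strides = [1]} : vector<8xi1> to vector<4xi1>
// ```
// rewrites to `vector.constant_mask [2] : vector<4xi1>`: only the first two
// elements of the slice fall inside the masked region.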
2945class StridedSliceConstantMaskFolder final
2946 : public OpRewritePattern<ExtractStridedSliceOp> {
2947public:
2948 using OpRewritePattern::OpRewritePattern;
2949
2950 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
2951 PatternRewriter &rewriter) const override {
2952 // Return if 'extractStridedSliceOp' operand is not defined by a
2953 // ConstantMaskOp.
2954 auto *defOp = extractStridedSliceOp.getVector().getDefiningOp();
2955 auto constantMaskOp = dyn_cast_or_null<ConstantMaskOp>(defOp);
2956 if (!constantMaskOp)
2957 return failure();
2958 // Return if 'extractStridedSliceOp' has non-unit strides.
2959 if (extractStridedSliceOp.hasNonUnitStrides())
2960 return failure();
2961 // Gather constant mask dimension sizes.
2962 SmallVector<int64_t, 4> maskDimSizes;
2963 populateFromInt64AttrArray(constantMaskOp.getMaskDimSizes(), maskDimSizes);
2964 // Gather strided slice offsets and sizes.
2965 SmallVector<int64_t, 4> sliceOffsets;
2966 populateFromInt64AttrArray(extractStridedSliceOp.getOffsets(),
2967 sliceOffsets);
2968 SmallVector<int64_t, 4> sliceSizes;
2969 populateFromInt64AttrArray(extractStridedSliceOp.getSizes(), sliceSizes);
2970
2971 // Compute slice of vector mask region.
2972 SmallVector<int64_t, 4> sliceMaskDimSizes;
2973 sliceMaskDimSizes.reserve(maskDimSizes.size());
2974 for (auto [maskDimSize, sliceOffset, sliceSize] :
2975 llvm::zip(maskDimSizes, sliceOffsets, sliceSizes)) {
2976 int64_t sliceMaskDimSize = std::max(
2977 static_cast<int64_t>(0),
2978 std::min(sliceOffset + sliceSize, maskDimSize) - sliceOffset);
2979 sliceMaskDimSizes.push_back(sliceMaskDimSize);
2980 }
2981 // Add unchanged dimensions.
2982 if (sliceMaskDimSizes.size() < maskDimSizes.size())
2983 for (size_t i = sliceMaskDimSizes.size(); i < maskDimSizes.size(); ++i)
2984 sliceMaskDimSizes.push_back(maskDimSizes[i]);
2985 // If any of 'sliceMaskDimSizes' are zero, then set all to zero (masked
2986 // region is a conjunction of mask dim intervals).
2987 if (llvm::is_contained(sliceMaskDimSizes, 0))
2988 sliceMaskDimSizes.assign(maskDimSizes.size(), 0);
2989
2990 // Replace 'extractStridedSliceOp' with ConstantMaskOp with sliced mask
2991 // region.
2992 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
2993 extractStridedSliceOp, extractStridedSliceOp.getResult().getType(),
2994 vector::getVectorSubscriptAttr(rewriter, sliceMaskDimSizes));
2995 return success();
2996 }
2997};
2998
2999// Pattern to rewrite an ExtractStridedSliceOp(splat ConstantOp) -> ConstantOp.
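// Illustrative example (editor's sketch, hypothetical values):
// ```
// %0 = arith.constant dense<1.0> : vector<8xf32>
// %1 = vector.extract_strided_slice %0 {offsets = [2], sizes = [4],
//     strides = [1]} : vector<8xf32> to vector<4xf32>
// ```
// rewrites to `arith.constant dense<1.0> : vector<4xf32>`.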
3000class StridedSliceSplatConstantFolder final
3001 : public OpRewritePattern<ExtractStridedSliceOp> {
3002public:
3003 using OpRewritePattern::OpRewritePattern;
3004
3005 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3006 PatternRewriter &rewriter) const override {
3007 // Return if 'ExtractStridedSliceOp' operand is not defined by a splat
3008 // ConstantOp.
3009 Value sourceVector = extractStridedSliceOp.getVector();
3010 Attribute vectorCst;
3011 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3012 return failure();
3013
3014 auto splat = vectorCst.dyn_cast<SplatElementsAttr>();
3015 if (!splat)
3016 return failure();
3017
3018 auto newAttr = SplatElementsAttr::get(extractStridedSliceOp.getType(),
3019 splat.getSplatValue<Attribute>());
3020 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3021 newAttr);
3022 return success();
3023 }
3024};
3025
3026// Pattern to rewrite an ExtractStridedSliceOp(non-splat ConstantOp) ->
3027// ConstantOp.
3028class StridedSliceNonSplatConstantFolder final
3029 : public OpRewritePattern<ExtractStridedSliceOp> {
3030public:
3031 using OpRewritePattern::OpRewritePattern;
3032
3033 LogicalResult matchAndRewrite(ExtractStridedSliceOp extractStridedSliceOp,
3034 PatternRewriter &rewriter) const override {
3035 // Return if 'ExtractStridedSliceOp' operand is not defined by a non-splat
3036 // ConstantOp.
3037 Value sourceVector = extractStridedSliceOp.getVector();
3038 Attribute vectorCst;
3039 if (!matchPattern(sourceVector, m_Constant(&vectorCst)))
3040 return failure();
3041
3042 // The splat case is handled by `StridedSliceSplatConstantFolder`.
3043 auto dense = vectorCst.dyn_cast<DenseElementsAttr>();
3044 if (!dense || dense.isSplat())
3045 return failure();
3046
3047 // TODO: Handle non-unit strides when they become available.
3048 if (extractStridedSliceOp.hasNonUnitStrides())
3049 return failure();
3050
3051 auto sourceVecTy = sourceVector.getType().cast<VectorType>();
3052 ArrayRef<int64_t> sourceShape = sourceVecTy.getShape();
3053 SmallVector<int64_t, 4> sourceStrides = computeStrides(sourceShape);
3054
3055 VectorType sliceVecTy = extractStridedSliceOp.getType();
3056 ArrayRef<int64_t> sliceShape = sliceVecTy.getShape();
3057 int64_t sliceRank = sliceVecTy.getRank();
3058
3059 // Expand offsets and sizes to match the vector rank.
3060 SmallVector<int64_t, 4> offsets(sliceRank, 0);
3061 copy(getI64SubArray(extractStridedSliceOp.getOffsets()), offsets.begin());
3062
3063 SmallVector<int64_t, 4> sizes(sourceShape.begin(), sourceShape.end());
3064 copy(getI64SubArray(extractStridedSliceOp.getSizes()), sizes.begin());
3065
3066 // Calculate the slice elements by enumerating all slice positions and
3067 // linearizing them. The enumeration order is lexicographic which yields a
3068 // sequence of monotonically increasing linearized position indices.
3069 auto denseValuesBegin = dense.value_begin<Attribute>();
3070 SmallVector<Attribute> sliceValues;
3071 sliceValues.reserve(sliceVecTy.getNumElements());
3072 SmallVector<int64_t> currSlicePosition(offsets.begin(), offsets.end());
3073 do {
3074 int64_t linearizedPosition = linearize(currSlicePosition, sourceStrides);
3075 assert(linearizedPosition < sourceVecTy.getNumElements() &&
3076 "Invalid index");
3077 sliceValues.push_back(*(denseValuesBegin + linearizedPosition));
3078 } while (
3079 succeeded(incSlicePosition(currSlicePosition, sliceShape, offsets)));
3080
3081 assert(static_cast<int64_t>(sliceValues.size()) ==
3082 sliceVecTy.getNumElements() &&
3083 "Invalid number of slice elements");
3084 auto newAttr = DenseElementsAttr::get(sliceVecTy, sliceValues);
3085 rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
3086 newAttr);
3087 return success();
3088 }
3089};
3090
3091// Pattern to rewrite an ExtractStridedSliceOp(BroadcastOp) to
3092// BroadcastOp(ExtractStridedSliceOp).
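// Illustrative example (editor's sketch, hypothetical values), where the
// inner dimension matches and the extract can be dropped entirely:
// ```
// %b = vector.broadcast %src : vector<8xf32> to vector<4x8xf32>
// %e = vector.extract_strided_slice %b {offsets = [1, 0], sizes = [2, 8],
//     strides = [1, 1]} : vector<4x8xf32> to vector<2x8xf32>
// ```
// rewrites to `vector.broadcast %src : vector<8xf32> to vector<2x8xf32>`.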
3093class StridedSliceBroadcast final
3094 : public OpRewritePattern<ExtractStridedSliceOp> {
3095public:
3096 using OpRewritePattern::OpRewritePattern;
3097
3098 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3099 PatternRewriter &rewriter) const override {
3100 auto broadcast = op.getVector().getDefiningOp<BroadcastOp>();
3101 if (!broadcast)
3102 return failure();
3103 auto srcVecType = broadcast.getSource().getType().dyn_cast<VectorType>();
3104 unsigned srcRank = srcVecType ? srcVecType.getRank() : 0;
3105 auto dstVecType = op.getType().cast<VectorType>();
3106 unsigned dstRank = dstVecType.getRank();
3107 unsigned rankDiff = dstRank - srcRank;
3108 // Check if the innermost dimensions of the source of the broadcast are the
3109 // same as the destination of the extract. If so, we can just use a
3110 // broadcast, as the original dimensions are untouched.
3111 bool lowerDimMatch = true;
3112 for (unsigned i = 0; i < srcRank; i++) {
3113 if (srcVecType.getDimSize(i) != dstVecType.getDimSize(i + rankDiff)) {
3114 lowerDimMatch = false;
3115 break;
3116 }
3117 }
3118 Value source = broadcast.getSource();
3119 // If the inner dimensions don't match, we need to extract from the
3120 // source of the original broadcast and then broadcast the extracted value.
3121 // We also need to handle degenerate cases where the source is effectively
3122 // just a single scalar.
3123 bool isScalarSrc = (srcRank == 0 || srcVecType.getNumElements() == 1);
3124 if (!lowerDimMatch && !isScalarSrc) {
3125 source = rewriter.create<ExtractStridedSliceOp>(
3126 op->getLoc(), source,
3127 getI64SubArray(op.getOffsets(), /* dropFront=*/rankDiff),
3128 getI64SubArray(op.getSizes(), /* dropFront=*/rankDiff),
3129 getI64SubArray(op.getStrides(), /* dropFront=*/rankDiff));
3130 }
3131 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(), source);
3132 return success();
3133 }
3134};
3135
3136/// Pattern to rewrite an ExtractStridedSliceOp(SplatOp) to SplatOp.
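/// Illustrative example (editor's sketch, hypothetical values):
/// ```
/// %0 = vector.splat %x : vector<8xf32>
/// %1 = vector.extract_strided_slice %0 {offsets = [2], sizes = [4],
///     strides = [1]} : vector<8xf32> to vector<4xf32>
/// ```
/// rewrites to `vector.splat %x : vector<4xf32>`.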
3137class StridedSliceSplat final : public OpRewritePattern<ExtractStridedSliceOp> {
3138public:
3139 using OpRewritePattern::OpRewritePattern;
3140
3141 LogicalResult matchAndRewrite(ExtractStridedSliceOp op,
3142 PatternRewriter &rewriter) const override {
3143 auto splat = op.getVector().getDefiningOp<SplatOp>();
3144 if (!splat)
3145 return failure();
3146 rewriter.replaceOpWithNewOp<SplatOp>(op, op.getType(), splat.getInput());
3147 return success();
3148 }
3149};
3150
3151} // namespace
3152
3153void ExtractStridedSliceOp::getCanonicalizationPatterns(
3154 RewritePatternSet &results, MLIRContext *context) {
3155 // Pattern to rewrite an ExtractStridedSliceOp(ConstantMaskOp) ->
3156 // ConstantMaskOp and ExtractStridedSliceOp(ConstantOp) -> ConstantOp.
3157 results.add<StridedSliceConstantMaskFolder, StridedSliceSplatConstantFolder,
3158 StridedSliceNonSplatConstantFolder, StridedSliceBroadcast,
3159 StridedSliceSplat>(context);
3160}
3161
3162//===----------------------------------------------------------------------===//
3163// TransferReadOp
3164//===----------------------------------------------------------------------===//
3165
3166/// 1. Builder that sets padding to zero and an empty mask (variant with attrs).
3167void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3168 VectorType vectorType, Value source,
3169 ValueRange indices, AffineMapAttr permutationMapAttr,
3170 /*optional*/ ArrayAttr inBoundsAttr) {
3171 Type elemType = source.getType().cast<ShapedType>().getElementType();
3172 Value padding = builder.create<arith::ConstantOp>(
3173 result.location, elemType, builder.getZeroAttr(elemType));
3174 build(builder, result, vectorType, source, indices, permutationMapAttr,
3175 padding, /*mask=*/Value(), inBoundsAttr);
3176}
3177
3178/// 2. Builder that sets padding to zero and an empty mask (variant without attrs).
3179void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3180 VectorType vectorType, Value source,
3181 ValueRange indices, AffineMap permutationMap,
3182 std::optional<ArrayRef<bool>> inBounds) {
3183 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3184 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3185 ? builder.getBoolArrayAttr(inBounds.value())
3186 : ArrayAttr();
3187 build(builder, result, vectorType, source, indices, permutationMapAttr,
3188 inBoundsAttr);
3189}
3190
3191/// 3. Builder that sets permutation map to 'getMinorIdentityMap'.
3192void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3193 VectorType vectorType, Value source,
3194 ValueRange indices, Value padding,
3195 std::optional<ArrayRef<bool>> inBounds) {
3196 AffineMap permutationMap = getTransferMinorIdentityMap(
3197 source.getType().cast<ShapedType>(), vectorType);
3198 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3199 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3200 ? builder.getBoolArrayAttr(inBounds.value())
3201 : ArrayAttr();
3202 build(builder, result, vectorType, source, indices, permutationMapAttr,
3203 padding,
3204 /*mask=*/Value(), inBoundsAttr);
3205}
3206
3207/// 4. Builder that sets padding to zero and permutation map to
3208/// 'getMinorIdentityMap'.
3209void TransferReadOp::build(OpBuilder &builder, OperationState &result,
3210 VectorType vectorType, Value source,
3211 ValueRange indices,
3212 std::optional<ArrayRef<bool>> inBounds) {
3213 Type elemType = source.getType().cast<ShapedType>().getElementType();
3214 Value padding = builder.create<arith::ConstantOp>(
3215 result.location, elemType, builder.getZeroAttr(elemType));
3216 build(builder, result, vectorType, source, indices, padding, inBounds);
3217}
3218
3219template <typename EmitFun>
3220static LogicalResult verifyPermutationMap(AffineMap permutationMap,
3221 EmitFun emitOpError) {
3222 SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
3223 for (auto expr : permutationMap.getResults()) {
3224 auto dim = expr.dyn_cast<AffineDimExpr>();
3225 auto zero = expr.dyn_cast<AffineConstantExpr>();
3226 if (zero) {
3227 if (zero.getValue() != 0) {
3228 return emitOpError(
3229 "requires a projected permutation_map (at most one dim or the zero "
3230 "constant can appear in each result)");
3231 }
3232 continue;
3233 }
3234 if (!dim) {
3235 return emitOpError("requires a projected permutation_map (at most one "
3236 "dim or the zero constant can appear in each result)");
3237 }
3238 if (seen[dim.getPosition()]) {
3239 return emitOpError(
3240 "requires a permutation_map that is a permutation (found one dim "
3241 "used more than once)");
3242 }
3243 seen[dim.getPosition()] = true;
3244 }
3245 return success();
3246}
3247
3248static LogicalResult
3249verifyTransferOp(VectorTransferOpInterface op, ShapedType shapedType,
3250 VectorType vectorType, VectorType maskType,
3251 VectorType inferredMaskType, AffineMap permutationMap,
3252 ArrayAttr inBounds) {
3253 if (op->hasAttr("masked")) {
3254 return op->emitOpError("masked attribute has been removed. "
3255 "Use in_bounds instead.");
3256 }
3257
3258 if (!shapedType.isa<MemRefType, RankedTensorType>())
3259 return op->emitOpError(
3260 "requires source to be a memref or ranked tensor type");
3261
3262 auto elementType = shapedType.getElementType();
3263 DataLayout dataLayout = DataLayout::closest(op);
3264 if (auto vectorElementType = elementType.dyn_cast<VectorType>()) {
3265 // Memref or tensor has vector element type.
3266 unsigned sourceVecSize =
3267 dataLayout.getTypeSizeInBits(vectorElementType.getElementType()) *
3268 vectorElementType.getShape().back();
3269 unsigned resultVecSize =
3270 dataLayout.getTypeSizeInBits(vectorType.getElementType()) *
3271 vectorType.getShape().back();
3272 if (resultVecSize % sourceVecSize != 0)
3273 return op->emitOpError(
3274 "requires the bitwidth of the minor 1-D vector to be an integral "
3275 "multiple of the bitwidth of the minor 1-D vector of the source");
3276
3277 unsigned sourceVecEltRank = vectorElementType.getRank();
3278 unsigned resultVecRank = vectorType.getRank();
3279 if (sourceVecEltRank > resultVecRank)
3280 return op->emitOpError(
3281 "requires source vector element and vector result ranks to match.");
3282 unsigned rankOffset = resultVecRank - sourceVecEltRank;
3283 // Check that permutation map results match 'rankOffset' of vector type.
3284 if (permutationMap.getNumResults() != rankOffset)
3285 return op->emitOpError("requires a permutation_map with result dims of "
3286 "the same rank as the vector type");
3287
3288 if (maskType)
3289 return op->emitOpError("does not support masks with vector element type");
3290 } else {
3291 // Memref or tensor has scalar element type.
3292 unsigned minorSize =
3293 vectorType.getRank() == 0 ? 1 : vectorType.getShape().back();
3294 unsigned resultVecSize =
3295 dataLayout.getTypeSizeInBits(vectorType.getElementType()) * minorSize;
3296 if (resultVecSize % dataLayout.getTypeSizeInBits(elementType) != 0)
3297 return op->emitOpError(
3298 "requires the bitwidth of the minor 1-D vector to be an integral "
3299 "multiple of the bitwidth of the source element type");
3300
3301 // Check that permutation map results match rank of vector type.
3302 if (permutationMap.getNumResults() != vectorType.getRank())
3303 return op->emitOpError("requires a permutation_map with result dims of "
3304 "the same rank as the vector type");
3305 }
3306
3307 if (permutationMap.getNumSymbols() != 0)
3308 return op->emitOpError("requires permutation_map without symbols");
3309
3310 if (permutationMap.getNumInputs() != shapedType.getRank())
3311 return op->emitOpError("requires a permutation_map with input dims of the "
3312 "same rank as the source type");
3313
3314 if (maskType && maskType != inferredMaskType)
3315 return op->emitOpError("inferred mask type (")
3316 << inferredMaskType << ") and mask operand type (" << maskType
3317 << ") don't match";
3318
3319 if (inBounds) {
3320 if (permutationMap.getNumResults() != static_cast<int64_t>(inBounds.size()))
3321 return op->emitOpError("expects the optional in_bounds attr of same rank "
3322 "as permutation_map results: ")
3323 << AffineMapAttr::get(permutationMap)
3324 << " vs inBounds of size: " << inBounds.size();
3325 for (unsigned int i = 0; i < permutationMap.getNumResults(); ++i)
3326 if (permutationMap.getResult(i).isa<AffineConstantExpr>() &&
3327 !inBounds.getValue()[i].cast<BoolAttr>().getValue())
3328 return op->emitOpError("requires broadcast dimensions to be in-bounds");
3329 }
3330
3331 return success();
3332}
3333
3334static void printTransferAttrs(OpAsmPrinter &p, VectorTransferOpInterface op) {
3335 SmallVector<StringRef, 3> elidedAttrs;
3336 elidedAttrs.push_back(TransferReadOp::getOperandSegmentSizeAttr());
3337 if (op.permutation_map().isMinorIdentity())
3338 elidedAttrs.push_back(op.getPermutationMapAttrStrName());
3339 bool elideInBounds = true;
3340 if (auto inBounds = op.in_bounds()) {
3341 for (auto attr : *inBounds) {
3342 if (attr.template cast<BoolAttr>().getValue()) {
3343 elideInBounds = false;
3344 break;
3345 }
3346 }
3347 }
3348 if (elideInBounds)
3349 elidedAttrs.push_back(op.getInBoundsAttrStrName());
3350 p.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
3351}
3352
3353void TransferReadOp::print(OpAsmPrinter &p) {
3354 p << " " << getSource() << "[" << getIndices() << "], " << getPadding();
3355 if (getMask())
3356 p << ", " << getMask();
3357 printTransferAttrs(p, *this);
3358 p << " : " << getShapedType() << ", " << getVectorType();
3359}
3360
3361/// Infers the mask type for a transfer read given its vector type and
3362/// permutation map. The mask in a transfer read operation applies to the
3363/// tensor/buffer reading part of it and its type should match the shape read
3364/// *before* any permutation or broadcasting.
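/// For example (editor's sketch): with a vector<4x8xf32> transfer and
/// permutation_map = affine_map<(d0, d1) -> (d1, d0)>, the inverse
/// permutation yields a mask type of vector<8x4xi1>, i.e. the shape of the
/// memory region actually read.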
3365static VectorType inferTransferReadMaskType(VectorType vecType,
3366 AffineMap permMap) {
3367 auto i1Type = IntegerType::get(permMap.getContext(), 1);
3368 AffineMap invPermMap = inversePermutation(compressUnusedDims(permMap));
3369 assert(invPermMap && "Inverse permutation map couldn't be computed");
3370 SmallVector<int64_t, 8> maskShape = invPermMap.compose(vecType.getShape());
3371 return VectorType::get(maskShape, i1Type);
3372}
3373
3374ParseResult TransferReadOp::parse(OpAsmParser &parser, OperationState &result) {
3375 auto &builder = parser.getBuilder();
3376 SMLoc typesLoc;
3377 OpAsmParser::UnresolvedOperand sourceInfo;
3378 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3379 OpAsmParser::UnresolvedOperand paddingInfo;
3380 SmallVector<Type, 2> types;
3381 OpAsmParser::UnresolvedOperand maskInfo;
3382 // Parsing with support for paddingValue.
3383 if (parser.parseOperand(sourceInfo) ||
3384 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
3385 parser.parseComma() || parser.parseOperand(paddingInfo))
3386 return failure();
3387 ParseResult hasMask = parser.parseOptionalComma();
3388 if (hasMask.succeeded()) {
3389 if (parser.parseOperand(maskInfo))
3390 return failure();
3391 }
3392 if (parser.parseOptionalAttrDict(result.attributes) ||
3393 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3394 return failure();
3395 if (types.size() != 2)
3396 return parser.emitError(typesLoc, "requires two types");
3397 auto indexType = builder.getIndexType();
3398 auto shapedType = types[0].dyn_cast<ShapedType>();
3399 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3400 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3401 VectorType vectorType = types[1].dyn_cast<VectorType>();
3402 if (!vectorType)
3403 return parser.emitError(typesLoc, "requires vector type");
3404 auto permMapAttrName = TransferReadOp::getPermutationMapAttrStrName();
3405 Attribute permMapAttr = result.attributes.get(permMapAttrName);
3406 AffineMap permMap;
3407 if (!permMapAttr) {
3408 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3409 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3410 } else {
3411 permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3412 }
3413 if (parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3414 parser.resolveOperands(indexInfo, indexType, result.operands) ||
3415 parser.resolveOperand(paddingInfo, shapedType.getElementType(),
3416 result.operands))
3417 return failure();
3418 if (hasMask.succeeded()) {
3419 if (shapedType.getElementType().dyn_cast<VectorType>())
3420 return parser.emitError(
3421 maskInfo.location, "does not support masks with vector element type");
3422 // Instead of adding the mask type as an op type, compute it based on the
3423 // vector type and the permutation map (to keep the type signature small).
3424 auto maskType = inferTransferReadMaskType(vectorType, permMap);
3425 if (parser.resolveOperand(maskInfo, maskType, result.operands))
3426 return failure();
3427 }
3428 result.addAttribute(TransferReadOp::getOperandSegmentSizeAttr(),
3429 builder.getDenseI32ArrayAttr(
3430 {1, static_cast<int32_t>(indexInfo.size()), 1,
3431 static_cast<int32_t>(hasMask.succeeded())}));
3432 return parser.addTypeToList(vectorType, result.types);
3433}
3434
3435LogicalResult TransferReadOp::verify() {
3436 // Consistency of elemental types in source and vector.
3437 ShapedType shapedType = getShapedType();
3438 VectorType vectorType = getVectorType();
3439 VectorType maskType = getMaskType();
3440 auto paddingType = getPadding().getType();
3441 auto permutationMap = getPermutationMap();
3442 VectorType inferredMaskType =
3443 maskType ? inferTransferReadMaskType(vectorType, permutationMap)
3444 : VectorType();
3445 auto sourceElementType = shapedType.getElementType();
3446
3447 if (static_cast<int64_t>(getIndices().size()) != shapedType.getRank())
3448 return emitOpError("requires ") << shapedType.getRank() << " indices";
3449
3450 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3451 shapedType, vectorType, maskType,
3452 inferredMaskType, permutationMap,
3453 getInBounds() ? *getInBounds() : ArrayAttr())))
3454 return failure();
3455
3456 if (auto sourceVectorElementType = sourceElementType.dyn_cast<VectorType>()) {
3457 // Source has vector element type.
3458 // Check that 'sourceVectorElementType' and 'paddingType' types match.
3459 if (sourceVectorElementType != paddingType)
3460 return emitOpError(
3461 "requires source element type and padding type to match.");
3462
3463 } else {
3464 // Check that 'paddingType' is valid to store in a vector type.
3465 if (!VectorType::isValidElementType(paddingType))
3466 return emitOpError("requires valid padding vector elemental type");
3467
3468 // Check that padding type and vector element types match.
3469 if (paddingType != sourceElementType)
3470 return emitOpError(
3471 "requires formal padding and source of the same elemental type");
3472 }
3473
3474 return verifyPermutationMap(permutationMap,
3475 [&](Twine t) { return emitOpError(t); });
3476}
3477
3478// MaskableOpInterface methods.
3479
3480/// Returns the mask type expected by this operation. Mostly used for
3481/// verification purposes. It requires the operation to be vectorized.
3482Type TransferReadOp::getExpectedMaskType() {
3483 return inferTransferReadMaskType(getVectorType(), getPermutationMap());
3484}
3485
3486template <typename TransferOp>
3487static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
3488 // TODO: support more aggressive createOrFold on:
3489 // `op.indices()[indicesIdx] + vectorType < dim(op.source(), indicesIdx)`
3490 if (op.getShapedType().isDynamicDim(indicesIdx))
3491 return false;
3492 Value index = op.getIndices()[indicesIdx];
3493 auto cstOp = index.getDefiningOp<arith::ConstantIndexOp>();
3494 if (!cstOp)
3495 return false;
3496
3497 int64_t sourceSize = op.getShapedType().getDimSize(indicesIdx);
3498 int64_t vectorSize = op.getVectorType().getDimSize(resultIdx);
3499
3500 return cstOp.value() + vectorSize <= sourceSize;
3501}
3502
3503template <typename TransferOp>
3504static LogicalResult foldTransferInBoundsAttribute(TransferOp op) {
3505 // TODO: support 0-d corner case.
3506 // TODO: Be less conservative.
3507 if (op.getTransferRank() == 0)
3508 return failure();
3509 AffineMap permutationMap = op.getPermutationMap();
3510 bool changed = false;
3511 SmallVector<bool, 4> newInBounds;
3512 newInBounds.reserve(op.getTransferRank());
3513 for (unsigned i = 0; i < op.getTransferRank(); ++i) {
3514 // Already marked as in-bounds, nothing to see here.
3515 if (op.isDimInBounds(i)) {
3516 newInBounds.push_back(true);
3517 continue;
3518 }
3519 // Currently out-of-bounds, check whether we can statically determine it is
3520 // inBounds.
3521 auto dimExpr = permutationMap.getResult(i).dyn_cast<AffineDimExpr>();
3522 assert(dimExpr && "Broadcast dims must be in-bounds");
3523 auto inBounds =
3524 isInBounds(op, /*resultIdx=*/i, /*indicesIdx=*/dimExpr.getPosition());
3525 newInBounds.push_back(inBounds);
3526 // We commit the pattern if it is "more inbounds".
3527 changed |= inBounds;
3528 }
3529 if (!changed)
3530 return failure();
3531 // OpBuilder is only used as a helper to build an I64ArrayAttr.
3532 OpBuilder b(op.getContext());
3533 op->setAttr(TransferOp::getInBoundsAttrStrName(),
3534 b.getBoolArrayAttr(newInBounds));
3535 return success();
3536}
3537
3538/// ```
3539/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
3540/// : vector<1x4xf32>, tensor<4x4xf32>
3541/// %0 = vector.transfer_read %w0[%c1, %c0], %cf0 {in_bounds = [true, true]}
3542/// : tensor<4x4xf32>, vector<1x4xf32>
3543/// ```
3544/// -> Folds into
3545/// ```
3546/// %v0
3547/// ```
3548static Value foldRAW(TransferReadOp readOp) {
3549 if (!readOp.getShapedType().isa<RankedTensorType>())
3550 return {};
3551 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3552 while (defWrite) {
3553 if (checkSameValueRAW(defWrite, readOp))
3554 return defWrite.getVector();
3555 if (!isDisjointTransferIndices(
3556 cast<VectorTransferOpInterface>(defWrite.getOperation()),
3557 cast<VectorTransferOpInterface>(readOp.getOperation())))
3558 break;
3559 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
3560 }
3561 return {};
3562}
3563
3564OpFoldResult TransferReadOp::fold(ArrayRef<Attribute>) {
3565 if (Value vec = foldRAW(*this))
3566 return vec;
3567 /// transfer_read(memrefcast) -> transfer_read
3568 if (succeeded(foldTransferInBoundsAttribute(*this)))
3569 return getResult();
3570 if (succeeded(memref::foldMemRefCast(*this)))
3571 return getResult();
3572 if (succeeded(tensor::foldTensorCast(*this)))
3573 return getResult();
3574 return OpFoldResult();
3575}
3576
3577std::optional<SmallVector<int64_t, 4>> TransferReadOp::getShapeForUnroll() {
3578 return llvm::to_vector<4>(getVectorType().getShape());
3579}
3580
3581void TransferReadOp::getEffects(
3582 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
3583 &effects) {
3584 if (getShapedType().isa<MemRefType>())
3585 effects.emplace_back(MemoryEffects::Read::get(), getSource(),
3586 SideEffects::DefaultResource::get());
3587}
3588
3589/// Returns true if all rank reductions in the given `extractOp` happen in
3590/// leading dimensions, earlier than the last `trailingRank` dimensions.
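/// For example (editor's sketch): tensor<2x1x4xf32> -> tensor<1x4xf32> only
/// reduces the leading dim, so the check succeeds for trailingRank = 2,
/// whereas tensor<2x1x4xf32> -> tensor<2x4xf32> drops a middle dim and fails.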
3591static bool areAllRankReducedLeadingDim(tensor::ExtractSliceOp extractOp,
3592 unsigned trailingRank) {
3593 // If no ranks are reduced at all, it's a degenerate case; always true.
3594 if (extractOp.getSourceType().getRank() == extractOp.getType().getRank())
3595 return true;
3596
3597 RankedTensorType inferredType = extractOp.inferResultType(
3598 extractOp.getSourceType(), extractOp.getMixedOffsets(),
3599 extractOp.getMixedSizes(), extractOp.getMixedStrides());
3600 return extractOp.getType().getShape().take_back(trailingRank) ==
3601 inferredType.getShape().take_back(trailingRank);
3602}
3603
3604namespace {
3605/// Fold transfer_reads of a tensor.extract_slice op. E.g.:
3606///
3607/// ```
3608/// %0 = tensor.extract_slice %t[%a, %b] [%c, %d] [1, 1]
3609/// : tensor<?x?xf32> to tensor<?x?xf32>
3610/// %1 = vector.transfer_read %0[%e, %f], %cst {in_bounds = [true, true]}
3611/// : tensor<?x?xf32>, vector<4x5xf32>
3612/// ```
3613/// is rewritten to:
3614/// ```
3615/// %p0 = arith.addi %a, %e : index
3616/// %p1 = arith.addi %b, %f : index
3617/// %1 = vector.transfer_read %t[%p0, %p1], %cst {in_bounds = [true, true]}
3618/// : tensor<?x?xf32>, vector<4x5xf32>
3619/// ```
3620struct FoldExtractSliceIntoTransferRead
3621 : public OpRewritePattern<TransferReadOp> {
3622public:
3623 using OpRewritePattern::OpRewritePattern;
3624
3625 LogicalResult matchAndRewrite(TransferReadOp xferOp,
3626 PatternRewriter &rewriter) const override {
3627 // TODO: support 0-d corner case.
3628 if (xferOp.getTransferRank() == 0)
3629 return failure();
3630 if (xferOp.hasOutOfBoundsDim())
3631 return failure();
3632 if (!xferOp.getPermutationMap().isMinorIdentity())
3633 return failure();
3634 if (xferOp.getMask())
3635 return failure();
3636 auto extractOp = xferOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
3637 if (!extractOp)
3638 return failure();
3639 if (!extractOp.hasUnitStride())
3640 return failure();
3641
3642 // Bail on illegal rank-reduction: we need to check that the rank-reduced
3643 // dims are exactly the leading dims. I.e. the following is illegal:
3644 // ```
3645 // %0 = tensor.extract_slice %t[0,0,0][2,1,4][1,1,1] :
3646 // tensor<2x1x4xf32> to tensor<2x4xf32>
3647 // %1 = vector.transfer_read %0[0,0], %cst :
3648 // tensor<2x4xf32>, vector<2x4xf32>
3649 // ```
3650 //
3651 // Cannot fold into:
3652 // ```
3653 // %0 = vector.transfer_read %t[0,0,0], %cst :
3654 // tensor<2x1x4xf32>, vector<2x4xf32>
3655 // ```
3656 // For this, check the trailing `vectorRank` dims of the extract_slice
3657 // result tensor match the trailing dims of the inferred result tensor.
3658 if (!areAllRankReducedLeadingDim(extractOp, extractOp.getType().getRank()))
3659 return failure();
3660
3661 int64_t rankReduced =
3662 extractOp.getSourceType().getRank() - extractOp.getType().getRank();
3663
3664 SmallVector<Value> newIndices;
3665 // In case this is a rank-reducing ExtractSliceOp, copy rank-reduced
3666 // indices first.
3667 for (int64_t i = 0; i < rankReduced; ++i) {
3668 OpFoldResult offset = extractOp.getMixedOffsets()[i];
3669 newIndices.push_back(getValueOrCreateConstantIndexOp(
3670 rewriter, extractOp.getLoc(), offset));
3671 }
3672 for (const auto &it : llvm::enumerate(xferOp.getIndices())) {
3673 OpFoldResult offset =
3674 extractOp.getMixedOffsets()[it.index() + rankReduced];
3675 newIndices.push_back(rewriter.create<arith::AddIOp>(
3676 xferOp->getLoc(), it.value(),
3677 getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(),
3678 offset)));
3679 }
3680 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
3681 rewriter.replaceOpWithNewOp<TransferReadOp>(
3682 xferOp, xferOp.getVectorType(), extractOp.getSource(), newIndices,
3683 xferOp.getPadding(), ArrayRef<bool>{inBounds});
3684
3685 return success();
3686 }
3687};
3688
3689/// Store to load forwarding for transfer operations with permutation maps.
3690/// Even if the permutation maps are different we can still propagate the store
3691/// into the load if the size of the dimensions read and written match. Then we
3692/// can replace the transfer_read + transfer_write by vector.broadcast and
3693/// vector.transpose.
3694/// Example:
3695/// ```
3696/// %w0 = vector.transfer_write %v0, %arg0[%c0, %c0, %c0]
3697/// {in_bounds = [true, true],
3698/// permutation_map = affine_map<(d0, d1, d2) -> (d2, d1)>} :
3699/// vector<4x1xf32>, tensor<4x4x4xf32>
3700/// %r = vector.transfer_read %w0[%c0, %c0, %c0], %cf0
3701/// {in_bounds = [true, true, true, true],
3702/// permutation_map = affine_map<(d0, d1, d2) -> (d1, 0, d2, 0)>} :
3703/// tensor<4x4x4xf32>, vector<1x100x4x5xf32>
3704/// ```
3705/// To:
3706/// ```
3707/// %0 = vector.broadcast %arg1 : vector<4x1xf32> to vector<100x5x4x1xf32>
3708/// %r = vector.transpose %0, [3, 0, 2, 1] :
3709/// vector<100x5x4x1xf32> to vector<1x100x4x5xf32>
3710/// ```
3711struct TransferReadAfterWriteToBroadcast
3712 : public OpRewritePattern<TransferReadOp> {
3713 using OpRewritePattern::OpRewritePattern;
3714
3715 LogicalResult matchAndRewrite(TransferReadOp readOp,
3716 PatternRewriter &rewriter) const override {
3717 if (readOp.hasOutOfBoundsDim() ||
3718 !readOp.getShapedType().isa<RankedTensorType>())
3719 return failure();
3720 auto defWrite = readOp.getSource().getDefiningOp<vector::TransferWriteOp>();
3721 if (!defWrite)
3722 return failure();
3723
3724 SmallVector<int64_t> readDims = readOp.getTransferChunkAccessed();
3725 Value vec;
3726 if (readOp.getIndices() == defWrite.getIndices() &&
3727 readOp.getMask() == defWrite.getMask()) {
3728 SmallVector<int64_t> writeDims = defWrite.getTransferChunkAccessed();
3729 // TODO: If the writeDim is a superset of the read dims we could do an
3730 // extract_strided_slice.
3731 if (writeDims == readDims)
3732 vec = defWrite.getVector();
3733 }
3734 // TODO: loop through the chain of transfer_write if we can prove that they
3735 // don't overlap with the transfer_read. This requires improving
3736 // `isDisjointTransferIndices` helper.
3737 if (!vec)
3738 return failure();
3739 SmallVector<unsigned> permutation;
3740 AffineMap readMap = compressUnusedDims(readOp.getPermutationMap());
3741 AffineMap writeMap = compressUnusedDims(defWrite.getPermutationMap());
3742 AffineMap map = readMap.compose(writeMap);
3743 if (map.getNumResults() == 0)
3744 return failure();
3745 // Calculate the permutation to apply to go from the vector stored to the
3746 // vector read.
3747 if (!map.isPermutationOfMinorIdentityWithBroadcasting(permutation))
3748 return failure();
3749
3750 Location loc = readOp.getLoc();
3751 // Calculate the broadcast shape by applying the reverse permutation to the
3752 // final shape we want.
3753 ArrayRef<int64_t> destShape = readOp.getVectorType().getShape();
3754 SmallVector<int64_t> broadcastShape(destShape.size());
3755 for (const auto &pos : llvm::enumerate(permutation))
3756 broadcastShape[pos.value()] = destShape[pos.index()];
3757 VectorType broadcastedType = VectorType::get(
3758 broadcastShape, defWrite.getVectorType().getElementType());
3759 vec = rewriter.create<vector::BroadcastOp>(loc, broadcastedType, vec);
3760 SmallVector<int64_t> transposePerm(permutation.begin(), permutation.end());
3761 rewriter.replaceOpWithNewOp<vector::TransposeOp>(readOp, vec,
3762 transposePerm);
3763 return success();
3764 }
3765};
3766} // namespace
3767
3768void TransferReadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3769 MLIRContext *context) {
3770 results
3771 .add<FoldExtractSliceIntoTransferRead, TransferReadAfterWriteToBroadcast>(
3772 context);
3773}
3774
3775//===----------------------------------------------------------------------===//
3776// TransferWriteOp
3777//===----------------------------------------------------------------------===//
3778
3779/// 1. Builder with type inference.
3780void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3781 Value vector, Value dest, ValueRange indices,
3782 AffineMapAttr permutationMapAttr,
3783 /*optional*/ Value mask,
3784 /*optional*/ ArrayAttr inBoundsAttr) {
3785 Type resultType = dest.getType().dyn_cast<RankedTensorType>();
3786 build(builder, result, resultType, vector, dest, indices, permutationMapAttr,
3787 mask, inBoundsAttr);
3788}
3789
3790/// 2. Builder with type inference that sets an empty mask (variant with attrs).
3791void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3792 Value vector, Value dest, ValueRange indices,
3793 AffineMapAttr permutationMapAttr,
3794 /*optional*/ ArrayAttr inBoundsAttr) {
3795 build(builder, result, vector, dest, indices, permutationMapAttr,
3796 /*mask=*/Value(), inBoundsAttr);
3797}
3798
3799/// 3. Builder with type inference that sets an empty mask (variant without
3800/// attrs).
3801void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3802 Value vector, Value dest, ValueRange indices,
3803 AffineMap permutationMap,
3804 std::optional<ArrayRef<bool>> inBounds) {
3805 auto permutationMapAttr = AffineMapAttr::get(permutationMap);
3806 auto inBoundsAttr = (inBounds && !inBounds.value().empty())
3807 ? builder.getBoolArrayAttr(inBounds.value())
3808 : ArrayAttr();
3809 build(builder, result, vector, dest, indices, permutationMapAttr,
3810 /*mask=*/Value(), inBoundsAttr);
3811}
3812
3813/// 4. Builder with type inference that sets an empty mask and sets permutation
3814/// map to 'getMinorIdentityMap'.
3815void TransferWriteOp::build(OpBuilder &builder, OperationState &result,
3816 Value vector, Value dest, ValueRange indices,
3817 std::optional<ArrayRef<bool>> inBounds) {
3818 auto vectorType = vector.getType().cast<VectorType>();
3819 AffineMap permutationMap = getTransferMinorIdentityMap(
3820 dest.getType().cast<ShapedType>(), vectorType);
3821 build(builder, result, vector, dest, indices, permutationMap, inBounds);
3822}
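// A minimal usage sketch for the builder above (hypothetical names: an
// OpBuilder `b`, a Location `loc`, a vector value `vec`, and a tensor or
// memref value `dest` of rank `rank`):
//   SmallVector<Value> zeros(rank,
//                            b.create<arith::ConstantIndexOp>(loc, 0));
//   b.create<vector::TransferWriteOp>(loc, vec, dest, ValueRange(zeros),
//                                     /*inBounds=*/std::nullopt);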
3823
3824/// Infers the mask type for a transfer write given its vector type and
3825/// permutation map. The mask in a transfer write operation applies to the
3826/// tensor/buffer part being written and its type should match the shape
3827/// written *after* any permutation.
3828static VectorType inferTransferWriteMaskType(VectorType vecType,
3829 AffineMap permMap) {
3830 auto i1Type = IntegerType::get(permMap.getContext(), 1);
3831 SmallVector<int64_t, 8> maskShape =
3832 compressUnusedDims(permMap).compose(vecType.getShape());
3833 return VectorType::get(maskShape, i1Type);
3834}
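// For instance (a worked example): a transfer of vector<4x8xf32> with the
// transposing permutation map (d0, d1) -> (d1, d0) composes the map with the
// vector shape [4, 8], so the inferred mask type is vector<8x4xi1>.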
3835
3836ParseResult TransferWriteOp::parse(OpAsmParser &parser,
3837 OperationState &result) {
3838 auto &builder = parser.getBuilder();
3839 SMLoc typesLoc;
3840 OpAsmParser::UnresolvedOperand vectorInfo, sourceInfo;
3841 SmallVector<OpAsmParser::UnresolvedOperand, 8> indexInfo;
3842 SmallVector<Type, 2> types;
3843 OpAsmParser::UnresolvedOperand maskInfo;
3844 if (parser.parseOperand(vectorInfo) || parser.parseComma() ||
3845 parser.parseOperand(sourceInfo) ||
3846 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square))
3847 return failure();
3848 ParseResult hasMask = parser.parseOptionalComma();
3849 if (hasMask.succeeded() && parser.parseOperand(maskInfo))
3850 return failure();
3851 if (parser.parseOptionalAttrDict(result.attributes) ||
3852 parser.getCurrentLocation(&typesLoc) || parser.parseColonTypeList(types))
3853 return failure();
3854 if (types.size() != 2)
3855 return parser.emitError(typesLoc, "requires two types");
3856 auto indexType = builder.getIndexType();
3857 VectorType vectorType = types[0].dyn_cast<VectorType>();
3858 if (!vectorType)
3859 return parser.emitError(typesLoc, "requires vector type");
3860 ShapedType shapedType = types[1].dyn_cast<ShapedType>();
3861 if (!shapedType || !shapedType.isa<MemRefType, RankedTensorType>())
3862 return parser.emitError(typesLoc, "requires memref or ranked tensor type");
3863 auto permMapAttrName = TransferWriteOp::getPermutationMapAttrStrName();
3864 auto permMapAttr = result.attributes.get(permMapAttrName);
3865 AffineMap permMap;
3866 if (!permMapAttr) {
3867 permMap = getTransferMinorIdentityMap(shapedType, vectorType);
3868 result.attributes.set(permMapAttrName, AffineMapAttr::get(permMap));
3869 } else {
3870 permMap = permMapAttr.cast<AffineMapAttr>().getValue();
3871 }
3872 if (parser.resolveOperand(vectorInfo, vectorType, result.operands) ||
3873 parser.resolveOperand(sourceInfo, shapedType, result.operands) ||
3874 parser.resolveOperands(indexInfo, indexType, result.operands))
3875 return failure();
3876 if (hasMask.succeeded()) {
3877 if (shapedType.getElementType().dyn_cast<VectorType>())
3878 return parser.emitError(
3879 maskInfo.location, "does not support masks with vector element type");
3880 auto maskType = inferTransferWriteMaskType(vectorType, permMap);
3881 if (parser.resolveOperand(maskInfo, maskType, result.operands))
3882 return failure();
3883 }
3884 result.addAttribute(TransferWriteOp::getOperandSegmentSizeAttr(),
3885 builder.getDenseI32ArrayAttr(
3886 {1, 1, static_cast<int32_t>(indexInfo.size()),
3887 static_cast<int32_t>(hasMask.succeeded())}));
3888 return failure(shapedType.isa<RankedTensorType>() &&
3889 parser.addTypeToList(shapedType, result.types));
3890}
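// The custom assembly parsed above looks like, e.g. (a sketch with
// hypothetical values and types):
//   %r = vector.transfer_write %v, %t[%i, %j], %mask
//       {in_bounds = [true, true]} : vector<4x8xf32>, tensor<16x16xf32>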
3891
3892void TransferWriteOp::print(OpAsmPrinter &p) {
3893 p << " " << getVector() << ", " << getSource() << "[" << getIndices() << "]";
3894 if (getMask())
3895 p << ", " << getMask();
3896 printTransferAttrs(p, *this);
3897 p << " : " << getVectorType() << ", " << getShapedType();
3898}
3899
3900LogicalResult TransferWriteOp::verify() {
3901 // Consistency of elemental types in shape and vector.
3902 ShapedType shapedType = getShapedType();
3903 VectorType vectorType = getVectorType();
3904 VectorType maskType = getMaskType();
3905 auto permutationMap = getPermutationMap();
3906 VectorType inferredMaskType =
3907 maskType ? inferTransferWriteMaskType(vectorType, permutationMap)
3908 : VectorType();
3909
3910 if (llvm::size(getIndices()) != shapedType.getRank())
3911 return emitOpError("requires ") << shapedType.getRank() << " indices";
3912
3913 // We do not allow broadcast dimensions on TransferWriteOps for the moment,
3914 // as the semantics is unclear. This can be revisited later if necessary.
3915 if (hasBroadcastDim())
3916 return emitOpError("should not have broadcast dimensions");
3917
3918 if (failed(verifyTransferOp(cast<VectorTransferOpInterface>(getOperation()),
3919 shapedType, vectorType, maskType,
3920 inferredMaskType, permutationMap,
3921 getInBounds() ? *getInBounds() : ArrayAttr())))
3922 return failure();
3923
3924 return verifyPermutationMap(permutationMap,
3925 [&](Twine t) { return emitOpError(t); });
3926}
3927
3928// MaskableOpInterface methods.
3929
3930/// Returns the mask type expected by this operation. Mostly used for
3931/// verification purposes.
3932Type TransferWriteOp::getExpectedMaskType() {
3933 return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
3934}
3935
3936/// Fold:
3937/// ```
3938/// %t1 = ...
3939/// %v = vector.transfer_read %t0[%c0...], {in_bounds = [true...]} :
3940/// tensor<static_sizesxf32>, vector<static_sizesxf32>
3941/// %t2 = vector.transfer_write %v, %t1[%c0...] {in_bounds = [true...]} :
3942/// vector<static_sizesxf32>, tensor<static_sizesxf32>
3943/// ```
3944///
3945/// into:
3946///
3947/// ```
3948/// %t0
3949/// ```
3950///
3951/// The producer of t1 may or may not be DCE'd depending on whether it is a
3952/// block argument or has side effects.
3953static LogicalResult foldReadInitWrite(TransferWriteOp write,
3954 ArrayRef<Attribute>,
3955 SmallVectorImpl<OpFoldResult> &results) {
3956 // TODO: support 0-d corner case.
3957 if (write.getTransferRank() == 0)
3958 return failure();
3959 auto rankedTensorType =
3960 write.getSource().getType().dyn_cast<RankedTensorType>();
3961 // If not operating on tensors, bail.
3962 if (!rankedTensorType)
3963 return failure();
3964 // If no read, bail.
3965 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
3966 if (!read)
3967 return failure();
3968 // TODO: support 0-d corner case.
3969 if (read.getTransferRank() == 0)
3970 return failure();
3971 // For now, only accept minor identity. Future: accept cases where the
3972 if (!read.getPermutationMap().isMinorIdentity() ||
3973 !write.getPermutationMap().isMinorIdentity())
3974 return failure();
3975 // Bail on mismatching ranks.
3976 if (read.getTransferRank() != write.getTransferRank())
3977 return failure();
3978 // Bail on potential out-of-bounds accesses.
3979 if (read.hasOutOfBoundsDim() || write.hasOutOfBoundsDim())
3980 return failure();
3981 // Tensor types must be the same.
3982 if (read.getSource().getType() != rankedTensorType)
3983 return failure();
3984 // Vector types must be the same.
3985 if (read.getVectorType() != write.getVectorType())
3986 return failure();
3987 // Vector and Tensor shapes must match.
3988 if (read.getVectorType().getShape() != rankedTensorType.getShape())
3989 return failure();
3990 // Bail if any index is nonzero.
3991 auto isNotConstantZero = [](Value v) {
3992 auto cstOp = v.getDefiningOp<arith::ConstantIndexOp>();
3993 return !cstOp || cstOp.value() != 0;
3994 };
3995 if (llvm::any_of(read.getIndices(), isNotConstantZero) ||
3996 llvm::any_of(write.getIndices(), isNotConstantZero))
3997 return failure();
3998 // Success.
3999 results.push_back(read.getSource());
4000 return success();
4001}
4002
4003static bool checkSameValueWAR(vector::TransferReadOp read,
4004 vector::TransferWriteOp write) {
4005 return read.getSource() == write.getSource() &&
4006 read.getIndices() == write.getIndices() &&
4007 read.getPermutationMap() == write.getPermutationMap() &&
4008 read.getVectorType() == write.getVectorType() && !read.getMask() &&
4009 !write.getMask();
4010}
4011/// Fold transfer_write write after read:
4012/// ```
4013/// %t0 = ...
4014/// %v = vector.transfer_read %t0[%c0...] :
4015/// tensor<static_sizesxf32>, vector<static_sizesxf32>
4016/// %t1 = vector.transfer_write %v, %t0[%c0...] :
4017/// vector<static_sizesxf32>, tensor<static_sizesxf32>
4018/// ```
4019///
4020/// into:
4021///
4022/// ```
4023/// %t0
4024/// ```
4025static LogicalResult foldWAR(TransferWriteOp write,
4026 SmallVectorImpl<OpFoldResult> &results) {
4027 if (!write.getSource().getType().isa<RankedTensorType>())
4028 return failure();
4029 auto read = write.getVector().getDefiningOp<vector::TransferReadOp>();
4030 if (!read)
4031 return failure();
4032
4033 if (!checkSameValueWAR(read, write))
4034 return failure();
4035 results.push_back(read.getSource());
4036 return success();
4037}
4038
4039LogicalResult TransferWriteOp::fold(ArrayRef<Attribute> operands,
4040 SmallVectorImpl<OpFoldResult> &results) {
4041 if (succeeded(foldReadInitWrite(*this, operands, results)))
4042 return success();
4043 if (succeeded(foldWAR(*this, results)))
4044 return success();
4045 if (succeeded(foldTransferInBoundsAttribute(*this)))
4046 return success();
4047 return memref::foldMemRefCast(*this);
4048}
4049
4050std::optional<SmallVector<int64_t, 4>> TransferWriteOp::getShapeForUnroll() {
4051 return llvm::to_vector<4>(getVectorType().getShape());
4052}
4053
4054void TransferWriteOp::getEffects(
4055 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
4056 &effects) {
4057 if (getShapedType().isa<MemRefType>())
4058 effects.emplace_back(MemoryEffects::Write::get(), getSource(),
4059 SideEffects::DefaultResource::get());
4060}
4061
4062namespace {
4063/// Remove a dead transfer write from the SSA chain so that it can be
4064/// eliminated by DCE.
4065/// ```
4066/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4067/// : vector<1x4xf32>, tensor<4x4xf32>
4068/// %w1 = vector.transfer_write %v0, %w0[%c2, %c0] {in_bounds = [true, true]}
4069/// : vector<1x4xf32>, tensor<4x4xf32>
4070/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4071/// : vector<1x4xf32>, tensor<4x4xf32>
4072/// ```
4073///
4074/// into:
4075///
4076/// ```
4077/// %w0 = vector.transfer_write %v0, %arg0[%c1, %c0] {in_bounds = [true, true]}
4078/// : vector<1x4xf32>, tensor<4x4xf32>
4079/// %w1 = vector.transfer_write %v0, %arg0[%c2, %c0] {in_bounds = [true, true]}
4080/// : vector<1x4xf32>, tensor<4x4xf32>
4081/// %w2 = vector.transfer_write %v1, %w1[%c1, %c0] {in_bounds = [true, true]}
4082/// : vector<1x4xf32>, tensor<4x4xf32>
4083/// ```
4084///
4085/// `%w0 = vector.transfer_write` op will be removed by DCE if it doesn't have
4086/// any other uses.
4087class FoldWaw final : public OpRewritePattern<TransferWriteOp> {
4088public:
4089 using OpRewritePattern::OpRewritePattern;
4090 LogicalResult matchAndRewrite(TransferWriteOp writeOp,
4091 PatternRewriter &rewriter) const override {
4092 if (!writeOp.getShapedType().isa<RankedTensorType>())
4093 return failure();
4094 vector::TransferWriteOp writeToModify = writeOp;
4095
4096 auto defWrite =
4097 writeOp.getSource().getDefiningOp<vector::TransferWriteOp>();
4098 while (defWrite) {
4099 if (checkSameValueWAW(writeOp, defWrite)) {
4100 writeToModify.getSourceMutable().assign(defWrite.getSource());
4101 return success();
4102 }
4103 if (!isDisjointTransferIndices(
4104 cast<VectorTransferOpInterface>(defWrite.getOperation()),
4105 cast<VectorTransferOpInterface>(writeOp.getOperation())))
4106 break;
4107 // If the previous write op doesn't have any other use, we can safely
4108 // look at the previous store to see if it can be removed.
4109 if (!defWrite->hasOneUse())
4110 break;
4111 writeToModify = defWrite;
4112 defWrite = defWrite.getSource().getDefiningOp<vector::TransferWriteOp>();
4113 }
4114 return failure();
4115 }
4116};
4117
4118/// Fold tensor.insert_slice into vector.transfer_write if the transfer_write
4119/// could directly write to the insert_slice's destination. E.g.:
4120///
4121/// ```
4122/// %0 = vector.transfer_write %v, %t1[%c0, %c0] {in_bounds = [true, true]}
4123/// : vector<4x5xf32>, tensor<4x5xf32>
4124/// %1 = tensor.insert_slice %0 into %t2[%a, %b] [4, 5] [1, 1]
4125/// : tensor<4x5xf32> into tensor<?x?xf32>
4126/// ```
4127/// is rewritten to:
4128/// ```
4129/// %1 = vector.transfer_write %v, %t2[%a, %b] {in_bounds = [true, true]}
4130/// : vector<4x5xf32>, tensor<?x?xf32>
4131/// ```
4132struct FoldInsertSliceIntoTransferWrite
4133 : public OpRewritePattern<tensor::InsertSliceOp> {
4134public:
4135 using OpRewritePattern::OpRewritePattern;
4136
4137 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4138 PatternRewriter &rewriter) const override {
4139 if (!insertOp.hasUnitStride())
4140 return failure();
4141
4142 auto xferOp = insertOp.getSource().getDefiningOp<TransferWriteOp>();
4143 if (!xferOp)
4144 return failure();
4145 // TODO: support 0-d corner case.
4146 if (xferOp.getTransferRank() == 0)
4147 return failure();
4148
4149 if (xferOp.hasOutOfBoundsDim())
4150 return failure();
4151 if (xferOp.getVectorType().getRank() != xferOp.getShapedType().getRank())
4152 return failure();
4153 if (xferOp.getMask())
4154 return failure();
4155 // Fold only if the TransferWriteOp completely overwrites the `source` with
4156 // a vector. I.e., the result of the TransferWriteOp is a new tensor whose
4157 // content is the data of the vector.
4158 if (!llvm::equal(xferOp.getVectorType().getShape(),
4159 xferOp.getShapedType().getShape()))
4160 return failure();
4161 if (!xferOp.getPermutationMap().isIdentity())
4162 return failure();
4163
4164 // Bail on illegal rank-reduction: we need to check that the rank-reduced
4165 // dims are exactly the leading dims. I.e. the following is illegal:
4166 // ```
4167 // %0 = vector.transfer_write %v, %t[0,0], %cst :
4168 // vector<2x4xf32>, tensor<2x4xf32>
4169 // %1 = tensor.insert_slice %0 into %tt[0,0,0][2,1,4][1,1,1] :
4170 // tensor<2x4xf32> into tensor<2x1x4xf32>
4171 // ```
4172 //
4173 // Cannot fold into:
4174 // ```
4175 // %0 = vector.transfer_write %v, %t[0,0,0], %cst :
4176 // vector<2x4xf32>, tensor<2x1x4xf32>
4177 // ```
4178 // For this, check that the trailing `vectorRank` dims of the insert_slice
4179 // result tensor match the trailing dims of the inferred result tensor.
4180 int64_t rankReduced =
4181 insertOp.getType().getRank() - insertOp.getSourceType().getRank();
4182 int64_t vectorRank = xferOp.getVectorType().getRank();
4183 RankedTensorType inferredSourceTensorType =
4184 tensor::ExtractSliceOp::inferResultType(
4185 insertOp.getType(), insertOp.getMixedOffsets(),
4186 insertOp.getMixedSizes(), insertOp.getMixedStrides());
4187 auto actualSourceTensorShape = insertOp.getSourceType().getShape();
4188 if (rankReduced > 0 &&
4189 actualSourceTensorShape.take_back(vectorRank) !=
4190 inferredSourceTensorType.getShape().take_back(vectorRank))
4191 return failure();
4192
4193 SmallVector<Value> indices = getValueOrCreateConstantIndexOp(
4194 rewriter, insertOp.getLoc(), insertOp.getMixedOffsets());
4195 SmallVector<bool> inBounds(xferOp.getTransferRank(), true);
4196 rewriter.replaceOpWithNewOp<TransferWriteOp>(insertOp, xferOp.getVector(),
4197 insertOp.getDest(), indices,
4198 ArrayRef<bool>{inBounds});
4199 return success();
4200 }
4201};
4202
4203/// Rewrite tensor::ExtractSliceOp(vector::TransferWriteOp) to
4204/// vector::TransferWriteOp(tensor::ExtractSliceOp) if the full slice is
4205/// overwritten and inserted into another tensor. After this rewrite, the
4206/// operations bufferize in-place since all of them work on the same slice.
4207///
4208/// For example:
4209/// ```mlir
4210/// %0 = vector.transfer_write %vec, %init_tensor[%c0, %c0]
4211/// : vector<8x16xf32>, tensor<8x16xf32>
4212/// %1 = tensor.extract_slice %0[0, 0] [%sz0, %sz1] [1, 1]
4213/// : tensor<8x16xf32> to tensor<?x?xf32>
4214/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4215/// : tensor<?x?xf32> into tensor<27x37xf32>
4216/// ```
4217/// folds to
4218/// ```mlir
4219/// %0 = tensor.extract_slice %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4220/// : tensor<27x37xf32> to tensor<?x?xf32>
4221/// %1 = vector.transfer_write %vec, %0[%c0, %c0]
4222/// : vector<8x16xf32>, tensor<?x?xf32>
4223/// %r = tensor.insert_slice %1 into %iter_arg[%iv0, %iv1] [%sz0, %sz1] [1, 1]
4224/// : tensor<?x?xf32> into tensor<27x37xf32>
4225/// ```
4226struct SwapExtractSliceOfTransferWrite
4227 : public OpRewritePattern<tensor::InsertSliceOp> {
4228public:
4229 using OpRewritePattern::OpRewritePattern;
4230
4231 LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
4232 PatternRewriter &rewriter) const override {
4233 if (!insertOp.hasUnitStride())
4234 return failure();
4235 auto extractOp =
4236 insertOp.getSource().getDefiningOp<tensor::ExtractSliceOp>();
4237 if (!extractOp || !extractOp.hasUnitStride() || !extractOp->hasOneUse())
4238 return failure();
4239 auto transferOp = extractOp.getSource().getDefiningOp<TransferWriteOp>();
4240 if (!transferOp || !transferOp->hasOneUse())
4241 return failure();
4242
4243 // Fail if vector::TransferWriteOp or tensor::ExtractSliceOp is
4244 // rank-reducing.
4245 if (insertOp.getSourceType().getRank() != transferOp.getTransferRank()) {
4246 return rewriter.notifyMatchFailure(insertOp,
4247 "use-def chain is rank-reducing");
4248 }
4249
4250 // Fail if tensor::ExtractSliceOp has non-zero offset.
4251 if (!extractOp.hasZeroOffset()) {
4252 return rewriter.notifyMatchFailure(insertOp,
4253 "ExtractSliceOp has non-zero offset");
4254 }
4255
4256 // Fail if vector::TransferWriteOp has non-zero offset.
4257 if (!llvm::all_of(transferOp.getIndices(), [](Value value) {
4258 return getConstantIntValue(value) == static_cast<int64_t>(0);
4259 })) {
4260 return rewriter.notifyMatchFailure(insertOp,
4261 "TransferWriteOp has non-zero offset");
4262 }
4263
4264 // Fail if tensor::ExtractSliceOp and tensor::InsertSliceOp sizes differ.
4265 for (auto [insertSize, extractSize] :
4266 llvm::zip_equal(insertOp.getMixedSizes(), extractOp.getMixedSizes())) {
4267 if (!isEqualConstantIntOrValue(insertSize, extractSize)) {
4268 return rewriter.notifyMatchFailure(
4269 insertOp, "InsertSliceOp and ExtractSliceOp sizes differ");
4270 }
4271 }
4272
4273 // Fail if the vector::TransferWriteOp may not overwrite the full tensor.
4274 assert(transferOp.getVectorType().hasStaticShape() &&
4275 "expected vector to have a static shape");
4276 ArrayRef<int64_t> vectorShape = transferOp.getVectorType().getShape();
4277 SmallVector<int64_t> resultShape = applyPermutationMap(
4278 transferOp.getPermutationMap(), transferOp.getShapedType().getShape());
4279 if (transferOp.getMask() || !vectorShape.equals(resultShape)) {
4280 return rewriter.notifyMatchFailure(
4281 insertOp, "TransferWriteOp may not write the full tensor.");
4282 }
4283
4284 // Swap the tensor::ExtractSliceOp in front of the vector::TransferWriteOp.
4285 SmallVector<int64_t> newResultShape = applyPermutationMap(
4286 transferOp.getPermutationMap(), insertOp.getSourceType().getShape());
4287 SmallVector<bool> newInBounds;
4288 for (const auto &en : enumerate(newResultShape))
4289 newInBounds.push_back(en.value() == vectorShape[en.index()]);
4290 auto newExtractOp = rewriter.create<tensor::ExtractSliceOp>(
4291 extractOp.getLoc(), insertOp.getSourceType(), insertOp.getDest(),
4292 insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
4293 insertOp.getMixedStrides());
4294 auto newTransferWriteOp = rewriter.create<TransferWriteOp>(
4295 transferOp.getLoc(), transferOp.getVector(), newExtractOp.getResult(),
4296 transferOp.getIndices(), transferOp.getPermutationMapAttr(),
4297 rewriter.getBoolArrayAttr(newInBounds));
4298 rewriter.updateRootInPlace(insertOp, [&]() {
4299 insertOp.getSourceMutable().assign(newTransferWriteOp.getResult());
4300 });
4301 return success();
4302 }
4303};
4304
4305} // namespace
4306
4307void TransferWriteOp::getCanonicalizationPatterns(RewritePatternSet &results,
4308 MLIRContext *context) {
4309 results.add<FoldWaw, FoldInsertSliceIntoTransferWrite,
4310 SwapExtractSliceOfTransferWrite>(context);
4311}
4312
4313//===----------------------------------------------------------------------===//
4314// LoadOp
4315//===----------------------------------------------------------------------===//
4316
4317static LogicalResult verifyLoadStoreMemRefLayout(Operation *op,
4318 MemRefType memRefTy) {
4319 if (!isLastMemrefDimUnitStride(memRefTy))
4320 return op->emitOpError("most minor memref dim must have unit stride");
4321 return success();
4322}
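// For example (a sketch), loads/stores on memref<4x8xf32> or on
// memref<4x8xf32, strided<[8, 1], offset: ?>> pass this check, while
// memref<4x8xf32, strided<[16, 2]>> is rejected because its most minor
// dimension has stride 2.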
4323
4324LogicalResult vector::LoadOp::verify() {
4325 VectorType resVecTy = getVectorType();
4326 MemRefType memRefTy = getMemRefType();
4327
4328 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4329 return failure();
4330
4331 // Checks for vector memrefs.
4332 Type memElemTy = memRefTy.getElementType();
4333 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
4334 if (memVecTy != resVecTy)
4335 return emitOpError("base memref and result vector types should match");
4336 memElemTy = memVecTy.getElementType();
4337 }
4338
4339 if (resVecTy.getElementType() != memElemTy)
4340 return emitOpError("base and result element types should match");
4341 if (llvm::size(getIndices()) != memRefTy.getRank())
4342 return emitOpError("requires ") << memRefTy.getRank() << " indices";
4343 return success();
4344}
4345
4346OpFoldResult LoadOp::fold(ArrayRef<Attribute>) {
4347 if (succeeded(memref::foldMemRefCast(*this)))
4348 return getResult();
4349 return OpFoldResult();
4350}
4351
4352//===----------------------------------------------------------------------===//
4353// StoreOp
4354//===----------------------------------------------------------------------===//
4355
4356LogicalResult vector::StoreOp::verify() {
4357 VectorType valueVecTy = getVectorType();
4358 MemRefType memRefTy = getMemRefType();
4359
4360 if (failed(verifyLoadStoreMemRefLayout(*this, memRefTy)))
4361 return failure();
4362
4363 // Checks for vector memrefs.
4364 Type memElemTy = memRefTy.getElementType();
4365 if (auto memVecTy = memElemTy.dyn_cast<VectorType>()) {
4366 if (memVecTy != valueVecTy)
4367 return emitOpError(
4368 "base memref and valueToStore vector types should match");
4369 memElemTy = memVecTy.getElementType();
4370 }
4371
4372 if (valueVecTy.getElementType() != memElemTy)
4373 return emitOpError("base and valueToStore element type should match");
4374 if (llvm::size(getIndices()) != memRefTy.getRank())
4375 return emitOpError("requires ") << memRefTy.getRank() << " indices";
4376 return success();
4377}
4378
4379LogicalResult StoreOp::fold(ArrayRef<Attribute> operands,
4380 SmallVectorImpl<OpFoldResult> &results) {
4381 return memref::foldMemRefCast(*this);
4382}
4383
4384//===----------------------------------------------------------------------===//
4385// MaskedLoadOp
4386//===----------------------------------------------------------------------===//
4387
4388LogicalResult MaskedLoadOp::verify() {
4389 VectorType maskVType = getMaskVectorType();
4390 VectorType passVType = getPassThruVectorType();
4391 VectorType resVType = getVectorType();
4392 MemRefType memType = getMemRefType();
4393
4394 if (resVType.getElementType() != memType.getElementType())
4395 return emitOpError("base and result element type should match");
4396 if (llvm::size(getIndices()) != memType.getRank())
4397 return emitOpError("requires ") << memType.getRank() << " indices";
4398 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4399 return emitOpError("expected result dim to match mask dim");
4400 if (resVType != passVType)
4401 return emitOpError("expected pass_thru of same type as result type");
4402 return success();
4403}
4404
4405namespace {
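// Folds masked loads whose mask is a compile-time constant, e.g. (a sketch):
//   %l = vector.maskedload %base[%c0], %mask, %pass
//       : memref<16xf32>, vector<8xi1>, vector<8xf32> into vector<8xf32>
// becomes a plain vector.load when %mask is all-true and folds to %pass when
// %mask is all-false.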
4406class MaskedLoadFolder final : public OpRewritePattern<MaskedLoadOp> {
4407public:
4408 using OpRewritePattern::OpRewritePattern;
4409 LogicalResult matchAndRewrite(MaskedLoadOp load,
4410 PatternRewriter &rewriter) const override {
4411 switch (getMaskFormat(load.getMask())) {
4412 case MaskFormat::AllTrue:
4413 rewriter.replaceOpWithNewOp<vector::LoadOp>(
4414 load, load.getType(), load.getBase(), load.getIndices());
4415 return success();
4416 case MaskFormat::AllFalse:
4417 rewriter.replaceOp(load, load.getPassThru());
4418 return success();
4419 case MaskFormat::Unknown:
4420 return failure();
4421 }
4422 llvm_unreachable("Unexpected 1DMaskFormat on MaskedLoad");
4423 }
4424};
4425} // namespace
4426
4427void MaskedLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4428 MLIRContext *context) {
4429 results.add<MaskedLoadFolder>(context);
4430}
4431
4432OpFoldResult MaskedLoadOp::fold(ArrayRef<Attribute>) {
4433 if (succeeded(memref::foldMemRefCast(*this)))
4434 return getResult();
4435 return OpFoldResult();
4436}
4437
4438//===----------------------------------------------------------------------===//
4439// MaskedStoreOp
4440//===----------------------------------------------------------------------===//
4441
4442LogicalResult MaskedStoreOp::verify() {
4443 VectorType maskVType = getMaskVectorType();
4444 VectorType valueVType = getVectorType();
4445 MemRefType memType = getMemRefType();
4446
4447 if (valueVType.getElementType() != memType.getElementType())
4448 return emitOpError("base and valueToStore element type should match");
4449 if (llvm::size(getIndices()) != memType.getRank())
4450 return emitOpError("requires ") << memType.getRank() << " indices";
4451 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4452 return emitOpError("expected valueToStore dim to match mask dim");
4453 return success();
4454}
4455
4456namespace {
4457class MaskedStoreFolder final : public OpRewritePattern<MaskedStoreOp> {
4458public:
4459 using OpRewritePattern::OpRewritePattern;
4460 LogicalResult matchAndRewrite(MaskedStoreOp store,
4461 PatternRewriter &rewriter) const override {
4462 switch (getMaskFormat(store.getMask())) {
4463 case MaskFormat::AllTrue:
4464 rewriter.replaceOpWithNewOp<vector::StoreOp>(
4465 store, store.getValueToStore(), store.getBase(), store.getIndices());
4466 return success();
4467 case MaskFormat::AllFalse:
4468 rewriter.eraseOp(store);
4469 return success();
4470 case MaskFormat::Unknown:
4471 return failure();
4472 }
4473 llvm_unreachable("Unexpected 1DMaskFormat on MaskedStore");
4474 }
4475};
4476} // namespace
4477
4478void MaskedStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4479 MLIRContext *context) {
4480 results.add<MaskedStoreFolder>(context);
4481}
4482
4483LogicalResult MaskedStoreOp::fold(ArrayRef<Attribute> operands,
4484 SmallVectorImpl<OpFoldResult> &results) {
4485 return memref::foldMemRefCast(*this);
4486}
4487
4488//===----------------------------------------------------------------------===//
4489// GatherOp
4490//===----------------------------------------------------------------------===//
4491
4492LogicalResult GatherOp::verify() {
4493 VectorType indVType = getIndexVectorType();
4494 VectorType maskVType = getMaskVectorType();
4495 VectorType resVType = getVectorType();
4496 ShapedType baseType = getBaseType();
4497
4498 if (!baseType.isa<MemRefType, RankedTensorType>())
4499 return emitOpError("requires base to be a memref or ranked tensor type");
4500
4501 if (resVType.getElementType() != baseType.getElementType())
4502 return emitOpError("base and result element type should match");
4503 if (llvm::size(getIndices()) != baseType.getRank())
4504 return emitOpError("requires ") << baseType.getRank() << " indices";
4505 if (resVType.getShape() != indVType.getShape())
4506 return emitOpError("expected result dim to match indices dim");
4507 if (resVType.getShape() != maskVType.getShape())
4508 return emitOpError("expected result dim to match mask dim");
4509 if (resVType != getPassThruVectorType())
4510 return emitOpError("expected pass_thru of same type as result type");
4511 return success();
4512}
4513
4514namespace {
4515class GatherFolder final : public OpRewritePattern<GatherOp> {
4516public:
4517 using OpRewritePattern::OpRewritePattern;
4518 LogicalResult matchAndRewrite(GatherOp gather,
4519 PatternRewriter &rewriter) const override {
4520 switch (getMaskFormat(gather.getMask())) {
4521 case MaskFormat::AllTrue:
4522 return failure(); // no unmasked equivalent
4523 case MaskFormat::AllFalse:
4524 rewriter.replaceOp(gather, gather.getPassThru());
4525 return success();
4526 case MaskFormat::Unknown:
4527 return failure();
4528 }
4529 llvm_unreachable("Unexpected 1DMaskFormat on GatherFolder");
4530 }
4531};
4532} // namespace
4533
4534void GatherOp::getCanonicalizationPatterns(RewritePatternSet &results,
4535 MLIRContext *context) {
4536 results.add<GatherFolder>(context);
4537}
4538
4539//===----------------------------------------------------------------------===//
4540// ScatterOp
4541//===----------------------------------------------------------------------===//
4542
4543LogicalResult ScatterOp::verify() {
4544 VectorType indVType = getIndexVectorType();
4545 VectorType maskVType = getMaskVectorType();
4546 VectorType valueVType = getVectorType();
4547 MemRefType memType = getMemRefType();
4548
4549 if (valueVType.getElementType() != memType.getElementType())
4550 return emitOpError("base and valueToStore element type should match");
4551 if (llvm::size(getIndices()) != memType.getRank())
4552 return emitOpError("requires ") << memType.getRank() << " indices";
4553 if (valueVType.getDimSize(0) != indVType.getDimSize(0))
4554 return emitOpError("expected valueToStore dim to match indices dim");
4555 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4556 return emitOpError("expected valueToStore dim to match mask dim");
4557 return success();
4558}
4559
4560namespace {
4561class ScatterFolder final : public OpRewritePattern<ScatterOp> {
4562public:
4563 using OpRewritePattern::OpRewritePattern;
4564 LogicalResult matchAndRewrite(ScatterOp scatter,
4565 PatternRewriter &rewriter) const override {
4566 switch (getMaskFormat(scatter.getMask())) {
4567 case MaskFormat::AllTrue:
4568 return failure(); // no unmasked equivalent
4569 case MaskFormat::AllFalse:
4570 rewriter.eraseOp(scatter);
4571 return success();
4572 case MaskFormat::Unknown:
4573 return failure();
4574 }
4575 llvm_unreachable("Unexpected 1DMaskFormat on ScatterFolder");
4576 }
4577};
4578} // namespace
4579
4580void ScatterOp::getCanonicalizationPatterns(RewritePatternSet &results,
4581 MLIRContext *context) {
4582 results.add<ScatterFolder>(context);
4583}
4584
4585//===----------------------------------------------------------------------===//
4586// ExpandLoadOp
4587//===----------------------------------------------------------------------===//
4588
4589LogicalResult ExpandLoadOp::verify() {
4590 VectorType maskVType = getMaskVectorType();
4591 VectorType passVType = getPassThruVectorType();
4592 VectorType resVType = getVectorType();
4593 MemRefType memType = getMemRefType();
4594
4595 if (resVType.getElementType() != memType.getElementType())
4596 return emitOpError("base and result element type should match");
4597 if (llvm::size(getIndices()) != memType.getRank())
4598 return emitOpError("requires ") << memType.getRank() << " indices";
4599 if (resVType.getDimSize(0) != maskVType.getDimSize(0))
4600 return emitOpError("expected result dim to match mask dim");
4601 if (resVType != passVType)
4602 return emitOpError("expected pass_thru of same type as result type");
4603 return success();
4604}
4605
4606namespace {
4607class ExpandLoadFolder final : public OpRewritePattern<ExpandLoadOp> {
4608public:
4609 using OpRewritePattern::OpRewritePattern;
4610 LogicalResult matchAndRewrite(ExpandLoadOp expand,
4611 PatternRewriter &rewriter) const override {
4612 switch (getMaskFormat(expand.getMask())) {
4613 case MaskFormat::AllTrue:
4614 rewriter.replaceOpWithNewOp<vector::LoadOp>(
4615 expand, expand.getType(), expand.getBase(), expand.getIndices());
4616 return success();
4617 case MaskFormat::AllFalse:
4618 rewriter.replaceOp(expand, expand.getPassThru());
4619 return success();
4620 case MaskFormat::Unknown:
4621 return failure();
4622 }
4623 llvm_unreachable("Unexpected 1DMaskFormat on ExpandLoadFolder");
4624 }
4625};
4626} // namespace
4627
4628void ExpandLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
4629 MLIRContext *context) {
4630 results.add<ExpandLoadFolder>(context);
4631}
4632
4633//===----------------------------------------------------------------------===//
4634// CompressStoreOp
4635//===----------------------------------------------------------------------===//
4636
4637LogicalResult CompressStoreOp::verify() {
4638 VectorType maskVType = getMaskVectorType();
4639 VectorType valueVType = getVectorType();
4640 MemRefType memType = getMemRefType();
4641
4642 if (valueVType.getElementType() != memType.getElementType())
4643 return emitOpError("base and valueToStore element type should match");
4644 if (llvm::size(getIndices()) != memType.getRank())
4645 return emitOpError("requires ") << memType.getRank() << " indices";
4646 if (valueVType.getDimSize(0) != maskVType.getDimSize(0))
4647 return emitOpError("expected valueToStore dim to match mask dim");
4648 return success();
4649}
4650
4651namespace {
4652class CompressStoreFolder final : public OpRewritePattern<CompressStoreOp> {
4653public:
4654 using OpRewritePattern::OpRewritePattern;
4655 LogicalResult matchAndRewrite(CompressStoreOp compress,
4656 PatternRewriter &rewriter) const override {
4657 switch (getMaskFormat(compress.getMask())) {
4658 case MaskFormat::AllTrue:
4659 rewriter.replaceOpWithNewOp<vector::StoreOp>(
4660 compress, compress.getValueToStore(), compress.getBase(),
4661 compress.getIndices());
4662 return success();
4663 case MaskFormat::AllFalse:
4664 rewriter.eraseOp(compress);
4665 return success();
4666 case MaskFormat::Unknown:
4667 return failure();
4668 }
4669 llvm_unreachable("Unexpected 1DMaskFormat on CompressStoreFolder");
4670 }
4671};
4672} // namespace
4673
4674void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
4675 MLIRContext *context) {
4676 results.add<CompressStoreFolder>(context);
4677}
4678
4679//===----------------------------------------------------------------------===//
4680// ShapeCastOp
4681//===----------------------------------------------------------------------===//
4682
4683/// Returns true if each element of 'a' is equal to the product of a contiguous
4684/// sequence of the elements of 'b'. Returns false otherwise.
4685static bool isValidShapeCast(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
4686 unsigned rankA = a.size();
4687 unsigned rankB = b.size();
4688 assert(rankA < rankB);
4689
4690 unsigned i = 0;
4691 unsigned j = 0;
4692 while (i < rankA && j < rankB) {
4693 int64_t dimA = a[i];
4694 int64_t dimB = 1;
4695 while (dimB < dimA && j < rankB)
4696 dimB *= b[j++];
4697 if (dimA != dimB)
4698 break;
4699 ++i;
4700
4701 // Handle the case when trailing dimensions are of size 1.
4702 // Include them into the contiguous sequence.
4703 auto isOne = [](int64_t v) { return v == 1; };
4704 if (i < rankA && llvm::all_of(a.slice(i), isOne))
4705 i = rankA;
4706 if (j < rankB && llvm::all_of(b.slice(j), isOne))
4707 j = rankB;
4708 }
4709
4710 return i == rankA && j == rankB;
4711}
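// For instance (a worked example): isValidShapeCast({4, 6}, {2, 2, 3, 2}) is
// true because 4 = 2 * 2 and 6 = 3 * 2, whereas
// isValidShapeCast({4, 6}, {2, 3, 2, 2}) is false because the running
// product 2 * 3 = 6 overshoots the first element 4.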
4712
4713static LogicalResult verifyVectorShapeCast(Operation *op,
4714 VectorType sourceVectorType,
4715 VectorType resultVectorType) {
4716 // Check that element type is the same.
4717 if (sourceVectorType.getElementType() != resultVectorType.getElementType())
4718 return op->emitOpError("source/result vectors must have same element type");
4719 auto sourceShape = sourceVectorType.getShape();
4720 auto resultShape = resultVectorType.getShape();
4721
4722 // Check that product of source dim sizes matches product of result dim sizes.
4723 int64_t sourceDimProduct = std::accumulate(
4724 sourceShape.begin(), sourceShape.end(), 1LL, std::multiplies<int64_t>{});
4725 int64_t resultDimProduct = std::accumulate(
4726 resultShape.begin(), resultShape.end(), 1LL, std::multiplies<int64_t>{});
4727 if (sourceDimProduct != resultDimProduct)
4728 return op->emitOpError("source/result number of elements must match");
4729
4730 // Check the expanding/contracting rank cases.
4731 unsigned sourceRank = sourceVectorType.getRank();
4732 unsigned resultRank = resultVectorType.getRank();
4733 if (sourceRank < resultRank) {
4734 if (!isValidShapeCast(sourceShape, resultShape))
4735 return op->emitOpError("invalid shape cast");
4736 } else if (sourceRank > resultRank) {
4737 if (!isValidShapeCast(resultShape, sourceShape))
4738 return op->emitOpError("invalid shape cast");
4739 }
4740 return success();
4741}
4742
4743LogicalResult ShapeCastOp::verify() {
4744 auto sourceVectorType = getSource().getType().dyn_cast_or_null<VectorType>();
4745 auto resultVectorType = getResult().getType().dyn_cast_or_null<VectorType>();
4746
4747 // Check if source/result are of vector type.
4748 if (sourceVectorType && resultVectorType)
4749 return verifyVectorShapeCast(*this, sourceVectorType, resultVectorType);
4750
4751 return success();
4752}
4753
4754OpFoldResult ShapeCastOp::fold(ArrayRef<Attribute> operands) {
4755 // No-op shape cast.
4756 if (getSource().getType() == getResult().getType())
4757 return getSource();
4758
4759 // Canceling shape casts.
4760 if (auto otherOp = getSource().getDefiningOp<ShapeCastOp>()) {
4761 if (getResult().getType() == otherOp.getSource().getType())
4762 return otherOp.getSource();
4763
4764 // Only allows valid transitive folding.
4765 VectorType srcType = otherOp.getSource().getType().cast<VectorType>();
4766 VectorType resultType = getResult().getType().cast<VectorType>();
4767 if (srcType.getRank() < resultType.getRank()) {
4768 if (!isValidShapeCast(srcType.getShape(), resultType.getShape()))
4769 return {};
4770 } else if (srcType.getRank() > resultType.getRank()) {
4771 if (!isValidShapeCast(resultType.getShape(), srcType.getShape()))
4772 return {};
4773 } else {
4774 return {};
4775 }
4776
4777 setOperand(otherOp.getSource());
4778 return getResult();
4779 }
4780
4781 // Canceling broadcast and shape cast ops.
4782 if (auto bcastOp = getSource().getDefiningOp<BroadcastOp>()) {
4783 if (bcastOp.getSourceType() == getType())
4784 return bcastOp.getSource();
4785 }
4786
4787 return {};
4788}
4789
4790namespace {
4791// Pattern to rewrite a ShapeCast(splat ConstantOp) -> ConstantOp.
4792class ShapeCastConstantFolder final : public OpRewritePattern<ShapeCastOp> {
4793public:
4794 using OpRewritePattern::OpRewritePattern;
4795
4796 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4797 PatternRewriter &rewriter) const override {
4798 auto constantOp =
4799 shapeCastOp.getSource().getDefiningOp<arith::ConstantOp>();
4800 if (!constantOp)
4801 return failure();
4802 // Only handle splat for now.
4803 auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
4804 if (!dense)
4805 return failure();
4806 auto newAttr =
4807 DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
4808 dense.getSplatValue<Attribute>());
4809 rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
4810 return success();
4811 }
4812};
4813
4814/// Pattern to rewrite a ShapeCast(Broadcast) -> Broadcast.
4815/// This only applies when the shape of the broadcast source is a suffix of the
4816/// shape of the result (i.e. when broadcast without reshape is expressive
4817/// enough to capture the result in a single op).
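/// For example (a minimal sketch):
/// %b = vector.broadcast %v : vector<4xf32> to vector<2x3x4xf32>
/// %s = vector.shape_cast %b : vector<2x3x4xf32> to vector<6x4xf32>
/// folds to:
/// %s = vector.broadcast %v : vector<4xf32> to vector<6x4xf32>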
4818class ShapeCastBroadcastFolder final : public OpRewritePattern<ShapeCastOp> {
4819public:
4820 using OpRewritePattern::OpRewritePattern;
4821
4822 LogicalResult matchAndRewrite(ShapeCastOp shapeCastOp,
4823 PatternRewriter &rewriter) const override {
4824 auto broadcastOp =
4825 shapeCastOp.getSource().getDefiningOp<vector::BroadcastOp>();
4826 if (!broadcastOp)
4827 return failure();
4828
4829 auto broadcastSourceVectorType =
4830 broadcastOp.getSourceType().dyn_cast<VectorType>();
4831 auto broadcastSourceShape = broadcastSourceVectorType
4832 ? broadcastSourceVectorType.getShape()
4833 : ArrayRef<int64_t>{};
4834 auto shapeCastTargetShape = shapeCastOp.getResultVectorType().getShape();
4835
4836 // Bail if `broadcastSourceShape` is not a suffix of the result.
4837 bool isSuffix = (broadcastSourceShape == shapeCastTargetShape.take_back(
4838 broadcastSourceShape.size()));
4839 if (!isSuffix)
4840 return failure();
4841
4842 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
4843 shapeCastOp, shapeCastOp.getResultVectorType(),
4844 broadcastOp.getSource());
4845 return success();
4846 }
4847};
4848
4849} // namespace
4850
4851void ShapeCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
4852 MLIRContext *context) {
4853 results.add<ShapeCastConstantFolder, ShapeCastBroadcastFolder>(context);
4854}
4855
4856//===----------------------------------------------------------------------===//
4857// VectorBitCastOp
4858//===----------------------------------------------------------------------===//
4859
4860LogicalResult BitCastOp::verify() {
4861 auto sourceVectorType = getSourceVectorType();
4862 auto resultVectorType = getResultVectorType();
4863
4864 for (int64_t i = 0, e = sourceVectorType.getRank() - 1; i < e; i++) {
4865 if (sourceVectorType.getDimSize(i) != resultVectorType.getDimSize(i))
4866 return emitOpError("dimension size mismatch at: ") << i;
4867 }
4868
4869 DataLayout dataLayout = DataLayout::closest(*this);
4870 auto sourceElementBits =
4871 dataLayout.getTypeSizeInBits(sourceVectorType.getElementType());
4872 auto resultElementBits =
4873 dataLayout.getTypeSizeInBits(resultVectorType.getElementType());
4874
4875 if (sourceVectorType.getRank() == 0) {
4876 if (sourceElementBits != resultElementBits)
4877 return emitOpError("source/result bitwidth of the 0-D vector element "
4878 "types must be equal");
4879 } else if (sourceElementBits * sourceVectorType.getShape().back() !=
4880 resultElementBits * resultVectorType.getShape().back()) {
4881 return emitOpError(
4882 "source/result bitwidth of the minor 1-D vectors must be equal");
4883 }
4884
4885 return success();
4886}
4887
4888OpFoldResult BitCastOp::fold(ArrayRef<Attribute> operands) {
4889 // No-op cast.
4890 if (getSource().getType() == getResult().getType())
4891 return getSource();
4892
4893 // Canceling bitcasts.
4894 if (auto otherOp = getSource().getDefiningOp<BitCastOp>()) {
4895 if (getResult().getType() == otherOp.getSource().getType())
4896 return otherOp.getSource();
4897
4898 setOperand(otherOp.getSource());
4899 return getResult();
4900 }
4901
4902 Attribute sourceConstant = operands.front();
4903 if (!sourceConstant)
4904 return {};
4905
4906 Type srcElemType = getSourceVectorType().getElementType();
4907 Type dstElemType = getResultVectorType().getElementType();
4908
4909 if (auto floatPack = sourceConstant.dyn_cast<DenseFPElementsAttr>()) {
4910 if (floatPack.isSplat()) {
4911 auto splat = floatPack.getSplatValue<FloatAttr>();
4912
4913 // Casting fp16 into fp32.
4914 if (srcElemType.isF16() && dstElemType.isF32()) {
4915 uint32_t bits = static_cast<uint32_t>(
4916 splat.getValue().bitcastToAPInt().getZExtValue());
4917 // Duplicate the 16-bit pattern.
4918 bits = (bits << 16) | (bits & 0xffff);
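// For instance, a splat of f16 1.0 (bit pattern 0x3C00) yields f32
// lanes holding 0x3C003C00.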
4919 APInt intBits(32, bits);
4920 APFloat floatBits(llvm::APFloat::IEEEsingle(), intBits);
4921 return DenseElementsAttr::get(getResultVectorType(), floatBits);
4922 }
4923 }
4924 }
4925
4926 return {};
4927}
4928
4929//===----------------------------------------------------------------------===//
4930// TypeCastOp
4931//===----------------------------------------------------------------------===//
4932
4933static SmallVector<int64_t, 8> extractShape(MemRefType memRefType) {
4934 auto vectorType = memRefType.getElementType().dyn_cast<VectorType>();
4935 SmallVector<int64_t, 8> res(memRefType.getShape().begin(),
4936 memRefType.getShape().end());
4937 if (vectorType)
4938 res.append(vectorType.getShape().begin(), vectorType.getShape().end());
4939 return res;
4940}
4941
4942/// Build the canonical memRefType with a single vector.
4943/// E.g. memref<4 x 5 x vector<6 x f32>> -> memref<vector<4 x 5 x 6 x f32>>.
4944void TypeCastOp::build(OpBuilder &builder, OperationState &result,
4945 Value source) {
4946 result.addOperands(source);
4947 MemRefType memRefType = source.getType().cast<MemRefType>();
4948 VectorType vectorType =
4949 VectorType::get(extractShape(memRefType),
4950 getElementTypeOrSelf(getElementTypeOrSelf(memRefType)));
4951 result.addTypes(MemRefType::get({}, vectorType, MemRefLayoutAttrInterface(),
4952 memRefType.getMemorySpace()));
4953}
4954
4955LogicalResult TypeCastOp::verify() {
4956 MemRefType canonicalType = canonicalizeStridedLayout(getMemRefType());
4957 if (!canonicalType.getLayout().isIdentity())
4958 return emitOpError("expects operand to be a memref with identity layout");
4959 if (!getResultMemRefType().getLayout().isIdentity())
4960 return emitOpError("expects result to be a memref with identity layout");
4961 if (getResultMemRefType().getMemorySpace() !=
4962 getMemRefType().getMemorySpace())
4963 return emitOpError("expects result in same memory space");
4964
4965 auto sourceType = getMemRefType();
4966 auto resultType = getResultMemRefType();
4967 if (getElementTypeOrSelf(getElementTypeOrSelf(sourceType)) !=
4968 getElementTypeOrSelf(getElementTypeOrSelf(resultType)))
4969 return emitOpError(
4970 "expects result and operand with same underlying scalar type: ")
4971 << resultType;
4972 if (extractShape(sourceType) != extractShape(resultType))
4973 return emitOpError(
4974 "expects concatenated result and operand shapes to be equal: ")
4975 << resultType;
4976 return success();
4977}
4978
4979//===----------------------------------------------------------------------===//
4980// TransposeOp
4981//===----------------------------------------------------------------------===//
4982
4983void vector::TransposeOp::build(OpBuilder &builder, OperationState &result,
4984 Value vector, ArrayRef<int64_t> transp) {
4985 VectorType vt = vector.getType().cast<VectorType>();
4986 SmallVector<int64_t, 4> transposedShape(vt.getRank());
4987 for (unsigned i = 0; i < transp.size(); ++i)
4988 transposedShape[i] = vt.getShape()[transp[i]];
4989
4990 result.addOperands(vector);
4991 result.addTypes(VectorType::get(transposedShape, vt.getElementType()));
4992 result.addAttribute(getTranspAttrStrName(), builder.getI64ArrayAttr(transp));
4993}
4994
4995OpFoldResult vector::TransposeOp::fold(ArrayRef<Attribute> operands) {
4996 // Eliminate splat constant transpose ops.
4997 if (auto attr = operands.front().dyn_cast_or_null<DenseElementsAttr>())
4998 if (attr.isSplat())
4999 return attr.reshape(getResultType());
5000
5001 // Eliminate identity transpose ops. This happens when the dimensions of the
5002 // input vector remain in their original order after the transpose operation.
5003 SmallVector<int64_t, 4> transp;
5004 getTransp(transp);
5005
5006 // Check if the permutation of the dimensions contains sequential values:
5007 // {0, 1, 2, ...}.
5008 for (int64_t i = 0, e = transp.size(); i < e; i++) {
5009 if (transp[i] != i)
5010 return {};
5011 }
5012
5013 return getVector();
5014}
5015
5016LogicalResult vector::TransposeOp::verify() {
5017 VectorType vectorType = getVectorType();
5018 VectorType resultType = getResultType();
5019 int64_t rank = resultType.getRank();
5020 if (vectorType.getRank() != rank)
5021 return emitOpError("vector result rank mismatch: ") << rank;
5022 // Verify transposition array.
5023 auto transpAttr = getTransp().getValue();
5024 int64_t size = transpAttr.size();
5025 if (rank != size)
5026 return emitOpError("transposition length mismatch: ") << size;
5027 SmallVector<bool, 8> seen(rank, false);
5028 for (const auto &ta : llvm::enumerate(transpAttr)) {
5029 int64_t i = ta.value().cast<IntegerAttr>().getInt();
5030 if (i < 0 || i >= rank)
5031 return emitOpError("transposition index out of range: ") << i;
5032 if (seen[i])
5033 return emitOpError("duplicate position index: ") << i;
5034 seen[i] = true;
5035 if (resultType.getDimSize(ta.index()) != vectorType.getDimSize(i))
5036 return emitOpError("dimension size mismatch at: ") << i;
5037 }
5038 return success();
5039}
5040
5041std::optional<SmallVector<int64_t, 4>> TransposeOp::getShapeForUnroll() {
5042 return llvm::to_vector<4>(getResultType().getShape());
5043}
5044
5045namespace {
5046
5047// Rewrites two back-to-back TransposeOp operations into a single TransposeOp.
5048class TransposeFolder final : public OpRewritePattern<vector::TransposeOp> {
5049public:
5050 using OpRewritePattern::OpRewritePattern;
5051
5052 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5053 PatternRewriter &rewriter) const override {
5054 // Wrapper around vector::TransposeOp::getTransp() for cleaner code.
5055 auto getPermutation = [](vector::TransposeOp transpose) {
5056 SmallVector<int64_t, 4> permutation;
5057 transpose.getTransp(permutation);
5058 return permutation;
5059 };
5060
5061 // Composes two permutations: result[i] = permutation1[permutation2[i]].
5062 auto composePermutations = [](ArrayRef<int64_t> permutation1,
5063 ArrayRef<int64_t> permutation2) {
5064 SmallVector<int64_t, 4> result;
5065 for (auto index : permutation2)
5066 result.push_back(permutation1[index]);
5067 return result;
5068 };
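// For example, composing permutation1 = [1, 2, 0] with permutation2 =
// [2, 0, 1] gives [permutation1[2], permutation1[0], permutation1[1]] =
// [0, 1, 2], i.e. the two transposes cancel into the identity.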
5069
5070 // Return if the input of 'transposeOp' is not defined by another transpose.
5071 vector::TransposeOp parentTransposeOp =
5072 transposeOp.getVector().getDefiningOp<vector::TransposeOp>();
5073 if (!parentTransposeOp)
5074 return failure();
5075
5076 SmallVector<int64_t, 4> permutation = composePermutations(
5077 getPermutation(parentTransposeOp), getPermutation(transposeOp));
5078 // Replace 'transposeOp' with a new transpose operation.
5079 rewriter.replaceOpWithNewOp<vector::TransposeOp>(
5080 transposeOp, transposeOp.getResult().getType(),
5081 parentTransposeOp.getVector(),
5082 vector::getVectorSubscriptAttr(rewriter, permutation));
5083 return success();
5084 }
5085};
5086
5087// Folds transpose(broadcast(<scalar>)) into broadcast(<scalar>).
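// For example (a sketch), transposing the result of
// vector.broadcast %s : f32 to vector<4x1xf32> with permutation [1, 0]
// rewrites to vector.broadcast %s : f32 to vector<1x4xf32>.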
5088struct FoldTransposedScalarBroadcast final
5089 : public OpRewritePattern<vector::TransposeOp> {
5090 using OpRewritePattern::OpRewritePattern;
5091
5092 LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
5093 PatternRewriter &rewriter) const override {
5094 auto bcastOp = transposeOp.getVector().getDefiningOp<vector::BroadcastOp>();
5095 if (!bcastOp)
5096 return failure();
5097
5098 auto srcVectorType = bcastOp.getSourceType().dyn_cast<VectorType>();
5099 if (!srcVectorType || srcVectorType.getNumElements() == 1) {
5100 rewriter.replaceOpWithNewOp<vector::BroadcastOp>(
5101 transposeOp, transposeOp.getResultType(), bcastOp.getSource());
5102 return success();
5103 }
5104
5105 return failure();
5106 }
5107};
5108
5109// Folds transpose(splat x : src_type) : res_type into splat x : res_type.
5110class FoldTransposeSplat final : public OpRewritePattern<TransposeOp> {
5111public:
5112 using OpRewritePattern::OpRewritePattern;
5113
5114 LogicalResult matchAndRewrite(TransposeOp transposeOp,
5115 PatternRewriter &rewriter) const override {
5116 auto splatOp = transposeOp.getVector().getDefiningOp<vector::SplatOp>();
5117 if (!splatOp)
5118 return failure();
5119
5120 rewriter.replaceOpWithNewOp<vector::SplatOp>(
5121 transposeOp, transposeOp.getResultType(), splatOp.getInput());
5122 return success();
5123 }
5124};
5125
5126} // namespace
5127
5128void vector::TransposeOp::getCanonicalizationPatterns(
5129 RewritePatternSet &results, MLIRContext *context) {
5130 results
5131 .add<FoldTransposedScalarBroadcast, TransposeFolder, FoldTransposeSplat>(
5132 context);
5133}
5134
5135void vector::TransposeOp::getTransp(SmallVectorImpl<int64_t> &results) {
5136 populateFromInt64AttrArray(getTransp(), results);
5137}
5138
5139//===----------------------------------------------------------------------===//
5140// ConstantMaskOp
5141//===----------------------------------------------------------------------===//
5142
5143LogicalResult ConstantMaskOp::verify() {
5144 auto resultType = getResult().getType().cast<VectorType>();
5145 // Check the corner case of 0-D vectors first.
5146 if (resultType.getRank() == 0) {
5147 if (getMaskDimSizes().size() != 1)
5148 return emitError("array attr must have length 1 for 0-D vectors");
5149 auto dim = getMaskDimSizes()[0].cast<IntegerAttr>().getInt();
5150 if (dim != 0 && dim != 1)
5151 return emitError("mask dim size must be either 0 or 1 for 0-D vectors");
5152 return success();
5153 }
5154
5155 // Verify that array attr size matches the rank of the vector result.
5156 if (static_cast<int64_t>(getMaskDimSizes().size()) != resultType.getRank())
5157 return emitOpError(
5158 "must specify array attr of size equal vector result rank");
5159 // Verify that each array attr element is in bounds of corresponding vector
5160 // result dimension size.
5161 auto resultShape = resultType.getShape();
5162 SmallVector<int64_t, 4> maskDimSizes;
5163 for (const auto &it : llvm::enumerate(getMaskDimSizes())) {
5164 int64_t attrValue = it.value().cast<IntegerAttr>().getInt();
5165 if (attrValue < 0 || attrValue > resultShape[it.index()])
5166 return emitOpError(
5167 "array attr of size out of bounds of vector result dimension size");
5168 maskDimSizes.push_back(attrValue);
5169 }
5170 // Verify that if one mask dim size is zero, they all should be zero (because
5171 // the mask region is a conjunction of each mask dimension interval).
5172 bool anyZeros = llvm::is_contained(maskDimSizes, 0);
5173 bool allZeros = llvm::all_of(maskDimSizes, [](int64_t s) { return s == 0; });
5174 if (anyZeros && !allZeros)
5175 return emitOpError("expected all mask dim sizes to be zeros, "
5176 "as a result of conjunction with zero mask dim");
5177 // Verify that if the mask type is scalable, dimensions should be zero because
5178 // constant scalable masks can only be defined for the "none set" or "all set"
5179 // cases, and there is no VLA way to define an "all set" case for
5180 // `vector.constant_mask`. In the future, a convention could be established
5181 // to decide if a specific dimension value could be considered as "all set".
5182 if (resultType.isScalable() &&
5183 getMaskDimSizes()[0].cast<IntegerAttr>().getInt() != 0)
5184 return emitOpError("expected mask dim sizes for scalable masks to be 0");
5185 return success();
5186}
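The value rules checked above (per-dimension bounds plus the zero-conjunction rule) can be summarized in a standalone predicate; a sketch, not the MLIR implementation:

  #include <cstdint>
  #include <vector>

  bool validMaskDims(const std::vector<int64_t> &dims,
                     const std::vector<int64_t> &shape) {
    bool anyZero = false, allZero = true;
    for (size_t i = 0; i < dims.size(); ++i) {
      if (dims[i] < 0 || dims[i] > shape[i])
        return false;                    // out of bounds for dimension i
      anyZero |= dims[i] == 0;
      allZero &= dims[i] == 0;
    }
    // A zero in any dimension must zero the whole mask, since the masked
    // region is the conjunction of the per-dimension intervals.
    return !anyZero || allZero;
  }

For a vector<4x3xi1> result, for example, [3, 2] and [0, 0] pass, while [0, 2] (zero must propagate) and [5, 2] (out of bounds) are rejected.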
5187
5188//===----------------------------------------------------------------------===//
5189// CreateMaskOp
5190//===----------------------------------------------------------------------===//
5191
5192LogicalResult CreateMaskOp::verify() {
5193 auto vectorType = getResult().getType().cast<VectorType>();
5194  // Verify that an operand was specified for each result vector dimension.
5195 if (vectorType.getRank() == 0) {
5196 if (getNumOperands() != 1)
5197 return emitOpError(
5198 "must specify exactly one operand for 0-D create_mask");
5199 } else if (getNumOperands() !=
5200 getResult().getType().cast<VectorType>().getRank()) {
5201 return emitOpError(
5202 "must specify an operand for each result vector dimension");
5203 }
5204 return success();
5205}
5206
5207namespace {
5208
5209// Pattern to rewrite a CreateMaskOp with a ConstantMaskOp.
5210class CreateMaskFolder final : public OpRewritePattern<CreateMaskOp> {
5211public:
5212 using OpRewritePattern::OpRewritePattern;
5213
5214 LogicalResult matchAndRewrite(CreateMaskOp createMaskOp,
5215 PatternRewriter &rewriter) const override {
5216 // Return if any of 'createMaskOp' operands are not defined by a constant.
5217 auto isNotDefByConstant = [](Value operand) {
5218 return !isa_and_nonnull<arith::ConstantIndexOp>(operand.getDefiningOp());
5219 };
5220 if (llvm::any_of(createMaskOp.getOperands(), isNotDefByConstant))
5221 return failure();
5222
5223 // CreateMaskOp for scalable vectors can be folded only if all dimensions
5224 // are negative or zero.
5225 if (auto vType = createMaskOp.getType().dyn_cast<VectorType>()) {
5226 if (vType.isScalable())
5227 for (auto opDim : createMaskOp.getOperands()) {
5228 APInt intVal;
5229 if (matchPattern(opDim, m_ConstantInt(&intVal)) &&
5230 intVal.isStrictlyPositive())
5231 return failure();
5232 }
5233 }
5234
5235 // Gather constant mask dimension sizes.
5236 SmallVector<int64_t, 4> maskDimSizes;
5237 maskDimSizes.reserve(createMaskOp->getNumOperands());
5238 for (auto [operand, maxDimSize] : llvm::zip_equal(
5239 createMaskOp.getOperands(), createMaskOp.getType().getShape())) {
5240 Operation *defOp = operand.getDefiningOp();
5241 int64_t dimSize = cast<arith::ConstantIndexOp>(defOp).value();
5242 dimSize = std::min(dimSize, maxDimSize);
5243      // If any dim size is non-positive, set all dims to zero.
5244 if (dimSize <= 0) {
5245 maskDimSizes.assign(createMaskOp.getType().getRank(), 0);
5246 break;
5247 }
5248 maskDimSizes.push_back(dimSize);
5249 }
5250 // Replace 'createMaskOp' with ConstantMaskOp.
5251 rewriter.replaceOpWithNewOp<ConstantMaskOp>(
5252 createMaskOp, createMaskOp.getResult().getType(),
5253 vector::getVectorSubscriptAttr(rewriter, maskDimSizes));
5254 return success();
5255 }
5256};
5257
5258} // namespace
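The folder's clamping rule is simple to state on its own: each constant operand is clamped to the corresponding dimension size, and any non-positive operand empties the whole mask. A standalone sketch (hypothetical helper mirroring the loop above):

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  std::vector<int64_t> foldMaskDims(const std::vector<int64_t> &operands,
                                    const std::vector<int64_t> &shape) {
    std::vector<int64_t> dims;
    for (size_t i = 0; i < operands.size(); ++i) {
      int64_t d = std::min(operands[i], shape[i]); // clamp to the dim size
      if (d <= 0)                       // one empty dim empties the mask
        return std::vector<int64_t>(shape.size(), 0);
      dims.push_back(d);
    }
    return dims;  // e.g. operands {5, 2} with shape {4, 3} yield {4, 2}
  }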
5259
5260void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
5261 MLIRContext *context) {
5262 results.add<CreateMaskFolder>(context);
5263}
5264
5265//===----------------------------------------------------------------------===//
5266// MaskOp
5267//===----------------------------------------------------------------------===//
5268
5269void MaskOp::build(
5270 OpBuilder &builder, OperationState &result, Value mask,
5271 function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5272  assert(maskRegionBuilder &&
5273         "builder callback for 'maskRegion' must be present");
5274
5275 result.addOperands(mask);
5276 OpBuilder::InsertionGuard guard(builder);
5277 Region *maskRegion = result.addRegion();
5278 builder.createBlock(maskRegion);
5279 maskRegionBuilder(builder, result.location);
5280}
5281
5282void MaskOp::build(
5283 OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5284 function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5285 build(builder, result, resultType, mask, /*passthru=*/Value(),
5286 maskRegionBuilder);
5287}
5288
5289void MaskOp::build(
5290 OpBuilder &builder, OperationState &result, Type resultType, Value mask,
5291 Value passthru,
5292 function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
5293 build(builder, result, mask, maskRegionBuilder);
5294 if (passthru)
5295 result.addOperands(passthru);
5296 result.addTypes(resultType);
5297}
5298
5299ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {
5300 // Create the op region.
5301 result.regions.reserve(1);
5302 Region &maskRegion = *result.addRegion();
5303
5304 auto &builder = parser.getBuilder();
5305
5306 // Parse all the operands.
5307 OpAsmParser::UnresolvedOperand mask;
5308 if (parser.parseOperand(mask))
5309 return failure();
5310
5311 // Optional passthru operand.
5312 OpAsmParser::UnresolvedOperand passthru;
5313 ParseResult parsePassthru = parser.parseOptionalComma();
5314 if (parsePassthru.succeeded() && parser.parseOperand(passthru))
5315 return failure();
5316
5317 // Parse op region.
5318 if (parser.parseRegion(maskRegion, /*arguments=*/{}, /*argTypes=*/{}))
5319 return failure();
5320
5321 MaskOp::ensureTerminator(maskRegion, builder, result.location);
5322
5323 // Parse the optional attribute list.
5324 if (parser.parseOptionalAttrDict(result.attributes))
5325 return failure();
5326
5327 // Parse all the types.
5328 Type maskType;
5329 if (parser.parseColonType(maskType))
5330 return failure();
5331
5332 SmallVector<Type> resultTypes;
5333 if (parser.parseOptionalArrowTypeList(resultTypes))
5334 return failure();
5335 result.types.append(resultTypes);
5336
5337 // Resolve operands.
5338 if (parser.resolveOperand(mask, maskType, result.operands))
5339 return failure();
5340
5341 if (parsePassthru.succeeded())
5342 if (parser.resolveOperand(passthru, resultTypes[0], result.operands))
5343 return failure();
5344
5345 return success();
5346}
5347
5348void mlir::vector::MaskOp::print(OpAsmPrinter &p) {
5349 p << " " << getMask();
5350 if (getPassthru())
5351 p << ", " << getPassthru();
5352
5353 // Print single masked operation and skip terminator.
5354 p << " { ";
5355 Block *singleBlock = &getMaskRegion().getBlocks().front();
5356 if (singleBlock && singleBlock->getOperations().size() > 1)
5357 p.printCustomOrGenericOp(&singleBlock->front());
5358 p << " }";
5359
5360 p.printOptionalAttrDict(getOperation()->getAttrs());
5361
5362 p << " : " << getMask().getType();
5363 if (getNumResults() > 0)
5364 p << " -> " << getResultTypes();
5365}
5366
5367void MaskOp::ensureTerminator(Region &region, Builder &builder, Location loc) {
5368 OpTrait::SingleBlockImplicitTerminator<vector::YieldOp>::Impl<
5369 MaskOp>::ensureTerminator(region, builder, loc);
5370  // Keep the default yield terminator if the number of masked operations is
5371  // not the expected one. This case will trigger a verification failure.
5372 if (region.front().getOperations().size() != 2)
5373 return;
5374
5375 // Replace default yield terminator with a new one that returns the results
5376 // from the masked operation.
5377 OpBuilder opBuilder(builder.getContext());
5378 Operation *maskedOp = &region.front().front();
5379 Operation *oldYieldOp = &region.front().back();
5380  assert(isa<vector::YieldOp>(oldYieldOp) && "Expected vector::YieldOp");
5381
5382 opBuilder.setInsertionPoint(oldYieldOp);
5383 opBuilder.create<vector::YieldOp>(maskedOp->getLoc(), maskedOp->getResults());
5384 oldYieldOp->dropAllReferences();
5385 oldYieldOp->erase();
5386}
5387
5388LogicalResult MaskOp::verify() {
5389 // Structural checks.
5390 Block &block = getMaskRegion().getBlocks().front();
5391 if (block.getOperations().size() < 2)
5392 return emitOpError("expects an operation to mask");
5393 if (block.getOperations().size() > 2)
5394 return emitOpError("expects only one operation to mask");
5395
5396 auto maskableOp = dyn_cast<MaskableOpInterface>(block.front());
5397 if (!maskableOp)
5398 return emitOpError("expects a maskable operation");
5399
5400 // Result checks.
5401 if (maskableOp->getNumResults() != getNumResults())
5402 return emitOpError("expects number of results to match maskable operation "
5403 "number of results");
5404
5405 if (!llvm::equal(maskableOp->getResultTypes(), getResultTypes()))
5406 return emitOpError(
5407 "expects result type to match maskable operation result type");
5408
5409 // Mask checks.
5410 Type expectedMaskType = maskableOp.getExpectedMaskType();
5411 if (getMask().getType() != expectedMaskType)
5412 return emitOpError("expects a ")
5413 << expectedMaskType << " mask for the maskable operation";
5414
5415 // Passthru checks.
5416 Value passthru = getPassthru();
5417 if (passthru) {
5418 if (!maskableOp.supportsPassthru())
5419 return emitOpError(
5420 "doesn't expect a passthru argument for this maskable operation");
5421
5422 if (maskableOp->getNumResults() != 1)
5423 return emitOpError("expects result when passthru argument is provided");
5424
5425 if (passthru.getType() != maskableOp->getResultTypes()[0])
5426 return emitOpError("expects passthru type to match result type");
5427 }
5428
5429 return success();
5430}
5431
5432// MaskingOpInterface definitions.
5433
5434/// Returns the operation masked by this 'vector.mask'.
5435Operation *MaskOp::getMaskableOp() { return &getMaskRegion().front().front(); }
5436
5437/// Returns true if 'vector.mask' has a passthru value.
5438bool MaskOp::hasPassthru() { return getPassthru() != Value(); }
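Putting the builders, parser, and verifier together: a sketch of constructing a vector.mask programmatically with the region-builder overload defined above (assumes an OpBuilder b, a Location loc, a result Type resultType, and values mask, passthru, lhs, rhs of matching vector types):

  auto maskOp = b.create<vector::MaskOp>(
      loc, resultType, mask, passthru,
      [&](OpBuilder &nested, Location nestedLoc) {
        // The single maskable operation; ensureTerminator() replaces the
        // implicit empty yield with one forwarding this op's results.
        nested.create<arith::DivSIOp>(nestedLoc, lhs, rhs);
      });
  Value masked = maskOp->getResult(0);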
5439
5440//===----------------------------------------------------------------------===//
5441// ScanOp
5442//===----------------------------------------------------------------------===//
5443
5444LogicalResult ScanOp::verify() {
5445 VectorType srcType = getSourceType();
5446 VectorType initialType = getInitialValueType();
5447 // Check reduction dimension < rank.
5448 int64_t srcRank = srcType.getRank();
5449 int64_t reductionDim = getReductionDim();
5450 if (reductionDim >= srcRank)
5451 return emitOpError("reduction dimension ")
5452 << reductionDim << " has to be less than " << srcRank;
5453
5454 // Check that rank(initial_value) = rank(src) - 1.
5455 int64_t initialValueRank = initialType.getRank();
5456 if (initialValueRank != srcRank - 1)
5457 return emitOpError("initial value rank ")
5458 << initialValueRank << " has to be equal to " << srcRank - 1;
5459
5460 // Check shapes of initial value and src.
5461 ArrayRef<int64_t> srcShape = srcType.getShape();
5462 ArrayRef<int64_t> initialValueShapes = initialType.getShape();
5463 SmallVector<int64_t> expectedShape;
5464 for (int i = 0; i < srcRank; i++) {
5465 if (i != reductionDim)
5466 expectedShape.push_back(srcShape[i]);
5467 }
5468 if (!llvm::equal(initialValueShapes, expectedShape)) {
5469 return emitOpError("incompatible input/initial value shapes");
5470 }
5471
5472 // Verify supported reduction kind.
5473 Type eltType = getDestType().getElementType();
5474 if (!isSupportedCombiningKind(getKind(), eltType))
5475 return emitOpError("unsupported reduction type ")
5476 << eltType << " for kind '" << stringifyCombiningKind(getKind())
5477 << "'";
5478
5479 return success();
5480}
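The shape rule above reduces to dropping the reduction dimension from the source shape; a sketch:

  #include <cstdint>
  #include <vector>

  std::vector<int64_t> expectedInitShape(std::vector<int64_t> srcShape,
                                         int64_t reductionDim) {
    srcShape.erase(srcShape.begin() + reductionDim);
    return srcShape;  // e.g. {4, 8, 16} with reductionDim 1 -> {4, 16}
  }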
5481
5482void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
5483 RewritePatternSet &patterns, PatternBenefit benefit) {
5484 patterns
5485 .add<CreateMaskFolder, MaskedLoadFolder, MaskedStoreFolder, GatherFolder,
5486 ScatterFolder, ExpandLoadFolder, CompressStoreFolder,
5487 StridedSliceConstantMaskFolder, TransposeFolder>(
5488 patterns.getContext(), benefit);
5489}
5490
5491//===----------------------------------------------------------------------===//
5492// SplatOp
5493//===----------------------------------------------------------------------===//
5494
5495OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
5496 auto constOperand = operands.front();
5497 if (!constOperand.isa_and_nonnull<IntegerAttr, FloatAttr>())
5498 return {};
5499
5500  // SplatElementsAttr::get treats a single value for the second arg as a splat.
5501 return SplatElementsAttr::get(getType(), {constOperand});
5502}
5503
5504//===----------------------------------------------------------------------===//
5505// WarpExecuteOnLane0Op
5506//===----------------------------------------------------------------------===//
5507
5508void WarpExecuteOnLane0Op::print(OpAsmPrinter &p) {
5509 p << "(" << getLaneid() << ")";
5510
5511 SmallVector<StringRef> coreAttr = {getWarpSizeAttrName()};
5512 auto warpSizeAttr = getOperation()->getAttr(getWarpSizeAttrName());
5513 p << "[" << warpSizeAttr.cast<IntegerAttr>().getInt() << "]";
5514
5515 if (!getArgs().empty())
5516 p << " args(" << getArgs() << " : " << getArgs().getTypes() << ")";
5517 if (!getResults().empty())
5518 p << " -> (" << getResults().getTypes() << ')';
5519 p << " ";
5520 p.printRegion(getRegion(),
5521 /*printEntryBlockArgs=*/true,
5522 /*printBlockTerminators=*/!getResults().empty());
5523 p.printOptionalAttrDict(getOperation()->getAttrs(), coreAttr);
5524}
5525
5526ParseResult WarpExecuteOnLane0Op::parse(OpAsmParser &parser,
5527 OperationState &result) {
5528 // Create the region.
5529 result.regions.reserve(1);
5530 Region *warpRegion = result.addRegion();
5531
5532 auto &builder = parser.getBuilder();
5533 OpAsmParser::UnresolvedOperand laneId;
5534
5535 // Parse predicate operand.
5536 if (parser.parseLParen() ||
  [1] Taking false branch
5537 parser.parseOperand(laneId, /*allowResultNumber=*/false) ||
5538 parser.parseRParen())
5539 return failure();
5540
5541 int64_t warpSize;
  [2] 'warpSize' declared without an initial value
5542 if (parser.parseLSquare() || parser.parseInteger(warpSize) ||
  [3] Calling 'AsmParser::parseInteger'
  [11] Returning from 'AsmParser::parseInteger'
  [12] Taking false branch
5543 parser.parseRSquare())
5544 return failure();
5545 result.addAttribute(getWarpSizeAttrName(OperationName(getOperationName(),
5546 builder.getContext())),
5547 builder.getI64IntegerAttr(warpSize));
  [13] 1st function call argument is an uninitialized value
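This is the warning from the summary. On the path the analyzer picks, parseInteger fails without storing to warpSize (events 4-10 in the OpImplementation.h excerpt below), the analyzer then assumes the if at line 5542 takes its false branch because it cannot prove that the ParseResult produced from emitError always signals failure, and the uninitialized value reaches getI64IntegerAttr. A standalone model of that reasoning (simplified, not the MLIR API):

  #include <optional>

  bool emitErrorStub();              // deliberately opaque: the analyzer
                                     // cannot prove it always reports failure
  std::optional<bool> parseOptionalInt(long &result) {
    return std::nullopt;             // "no integer here": result unwritten
  }
  bool parseInt(long &result) {      // models AsmParser::parseInteger
    std::optional<bool> r = parseOptionalInt(result);
    if (!r.has_value())
      return emitErrorStub();        // returns without writing 'result'
    return *r;
  }
  void use(long);                    // stands in for getI64IntegerAttr

  void caller() {
    long warpSize;                   // event 2: no initial value
    if (parseInt(warpSize))          // event 12: false branch assumed
      return;
    use(warpSize);                   // event 13: uninitialized argument
  }

At runtime the diagnostic path always converts to failure, so the false branch appears unreachable in practice; still, initializing the variable (int64_t warpSize = 0;) is the conventional defensive fix and silences the report.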
5548
5549 if (parser.resolveOperand(laneId, builder.getIndexType(), result.operands))
5550 return failure();
5551
5552 llvm::SMLoc inputsOperandsLoc;
5553 SmallVector<OpAsmParser::UnresolvedOperand> inputsOperands;
5554 SmallVector<Type> inputTypes;
5555 if (succeeded(parser.parseOptionalKeyword("args"))) {
5556 if (parser.parseLParen())
5557 return failure();
5558
5559 inputsOperandsLoc = parser.getCurrentLocation();
5560 if (parser.parseOperandList(inputsOperands) ||
5561 parser.parseColonTypeList(inputTypes) || parser.parseRParen())
5562 return failure();
5563 }
5564 if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
5565 result.operands))
5566 return failure();
5567
5568 // Parse optional results type list.
5569 if (parser.parseOptionalArrowTypeList(result.types))
5570 return failure();
5571 // Parse the region.
5572 if (parser.parseRegion(*warpRegion, /*arguments=*/{},
5573 /*argTypes=*/{}))
5574 return failure();
5575 WarpExecuteOnLane0Op::ensureTerminator(*warpRegion, builder, result.location);
5576
5577 // Parse the optional attribute list.
5578 if (parser.parseOptionalAttrDict(result.attributes))
5579 return failure();
5580 return success();
5581}
5582
5583void WarpExecuteOnLane0Op::getSuccessorRegions(
5584 std::optional<unsigned> index, ArrayRef<Attribute> operands,
5585 SmallVectorImpl<RegionSuccessor> &regions) {
5586 if (index) {
5587 regions.push_back(RegionSuccessor(getResults()));
5588 return;
5589 }
5590
5591  // The warp region is always executed.
5592 regions.push_back(RegionSuccessor(&getWarpRegion()));
5593}
5594
5595void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
5596 TypeRange resultTypes, Value laneId,
5597 int64_t warpSize) {
5598 build(builder, result, resultTypes, laneId, warpSize,
5599 /*operands=*/std::nullopt, /*argTypes=*/std::nullopt);
5600}
5601
5602void WarpExecuteOnLane0Op::build(OpBuilder &builder, OperationState &result,
5603 TypeRange resultTypes, Value laneId,
5604 int64_t warpSize, ValueRange args,
5605 TypeRange blockArgTypes) {
5606 result.addOperands(laneId);
5607 result.addAttribute(getAttributeNames()[0],
5608 builder.getI64IntegerAttr(warpSize));
5609 result.addTypes(resultTypes);
5610 result.addOperands(args);
5611  assert(args.size() == blockArgTypes.size());
5612 OpBuilder::InsertionGuard guard(builder);
5613 Region *warpRegion = result.addRegion();
5614 Block *block = builder.createBlock(warpRegion);
5615 for (auto [type, arg] : llvm::zip_equal(blockArgTypes, args))
5616 block->addArgument(type, arg.getLoc());
5617}
5618
5619/// Helper that checks whether the distributed vector type is consistent with
5620/// the expanded type and the distribution (warp) size.
5621static LogicalResult verifyDistributedType(Type expanded, Type distributed,
5622 int64_t warpSize, Operation *op) {
5623  // If the types match, there is no distribution.
5624 if (expanded == distributed)
5625 return success();
5626 auto expandedVecType = expanded.dyn_cast<VectorType>();
5627 auto distributedVecType = distributed.dyn_cast<VectorType>();
5628 if (!expandedVecType || !distributedVecType)
5629 return op->emitOpError("expected vector type for distributed operands.");
5630 if (expandedVecType.getRank() != distributedVecType.getRank() ||
5631 expandedVecType.getElementType() != distributedVecType.getElementType())
5632 return op->emitOpError(
5633 "expected distributed vectors to have same rank and element type.");
5634 bool foundDistributedDim = false;
5635 for (int64_t i = 0, e = expandedVecType.getRank(); i < e; i++) {
5636 if (expandedVecType.getDimSize(i) == distributedVecType.getDimSize(i))
5637 continue;
5638 if (expandedVecType.getDimSize(i) ==
5639 distributedVecType.getDimSize(i) * warpSize) {
5640 if (foundDistributedDim)
5641 return op->emitOpError()
5642 << "expected only one dimension to be distributed from "
5643 << expandedVecType << " to " << distributedVecType;
5644 foundDistributedDim = true;
5645 continue;
5646 }
5647 return op->emitOpError() << "incompatible distribution dimensions from "
5648 << expandedVecType << " to " << distributedVecType;
5649 }
5650 return success();
5651}
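Concretely (hypothetical shapes): with warpSize = 32, expanded vector<64x8xf32> against distributed vector<2x8xf32> passes, because exactly one dimension shrinks by the warp size:

  // dim 0: 64 == 2 * 32  -> the single distributed dimension
  // dim 1:  8 == 8       -> replicated across lanes
  static_assert(64 == 2 * 32, "one dimension shrinks by exactly warpSize");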
5652
5653LogicalResult WarpExecuteOnLane0Op::verify() {
5654 if (getArgs().size() != getWarpRegion().getNumArguments())
5655 return emitOpError(
5656 "expected same number op arguments and block arguments.");
5657 auto yield =
5658 cast<YieldOp>(getWarpRegion().getBlocks().begin()->getTerminator());
5659 if (yield.getNumOperands() != getNumResults())
5660 return emitOpError(
5661 "expected same number of yield operands and return values.");
5662 int64_t warpSize = getWarpSize();
5663 for (auto [regionArg, arg] :
5664 llvm::zip_equal(getWarpRegion().getArguments(), getArgs())) {
5665 if (failed(verifyDistributedType(regionArg.getType(), arg.getType(),
5666 warpSize, getOperation())))
5667 return failure();
5668 }
5669 for (auto [yieldOperand, result] :
5670 llvm::zip_equal(yield.getOperands(), getResults())) {
5671 if (failed(verifyDistributedType(yieldOperand.getType(), result.getType(),
5672 warpSize, getOperation())))
5673 return failure();
5674 }
5675 return success();
5676}
5677
5678bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
5679 return succeeded(
5680 verifyDistributedType(lhs, rhs, getWarpSize(), getOperation()));
5681}
5682
5683Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
5684 CombiningKind kind, Value v1, Value v2) {
5685 Type t1 = getElementTypeOrSelf(v1.getType());
5686 Type t2 = getElementTypeOrSelf(v2.getType());
5687 switch (kind) {
5688 case CombiningKind::ADD:
5689 if (t1.isIntOrIndex() && t2.isIntOrIndex())
5690 return b.createOrFold<arith::AddIOp>(loc, v1, v2);
5691 else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5692 return b.createOrFold<arith::AddFOp>(loc, v1, v2);
5693 llvm_unreachable("invalid value types for ADD reduction")::llvm::llvm_unreachable_internal("invalid value types for ADD reduction"
, "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5693)
;
5694 case CombiningKind::AND:
5695    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5696 return b.createOrFold<arith::AndIOp>(loc, v1, v2);
5697 case CombiningKind::MAXF:
5698    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5699           "expected float values");
5700 return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
5701 case CombiningKind::MINF:
5702    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
5703           "expected float values");
5704 return b.createOrFold<arith::MinFOp>(loc, v1, v2);
5705 case CombiningKind::MAXSI:
5706    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5707 return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
5708 case CombiningKind::MINSI:
5709    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5710 return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
5711 case CombiningKind::MAXUI:
5712    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5713 return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
5714 case CombiningKind::MINUI:
5715    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5716 return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
5717 case CombiningKind::MUL:
5718 if (t1.isIntOrIndex() && t2.isIntOrIndex())
5719 return b.createOrFold<arith::MulIOp>(loc, v1, v2);
5720 else if (t1.isa<FloatType>() && t2.isa<FloatType>())
5721 return b.createOrFold<arith::MulFOp>(loc, v1, v2);
5722 llvm_unreachable("invalid value types for MUL reduction")::llvm::llvm_unreachable_internal("invalid value types for MUL reduction"
, "mlir/lib/Dialect/Vector/IR/VectorOps.cpp", 5722)
;
5723 case CombiningKind::OR:
5724    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5725 return b.createOrFold<arith::OrIOp>(loc, v1, v2);
5726 case CombiningKind::XOR:
5727    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
5728 return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
5729 };
5730 llvm_unreachable("unknown CombiningKind")::llvm::llvm_unreachable_internal("unknown CombiningKind", "mlir/lib/Dialect/Vector/IR/VectorOps.cpp"
, 5730)
;
5731}
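A hypothetical call site (assumes an OpBuilder b, a Location loc, and two values lhs/rhs of the same float or integer type):

  Value sum = mlir::vector::makeArithReduction(
      b, loc, mlir::vector::CombiningKind::ADD, lhs, rhs);
  // createOrFold means this may yield an existing/constant value directly
  // instead of materializing a new arith op when the operands fold.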
5732
5733//===----------------------------------------------------------------------===//
5734// TableGen'd op method definitions
5735//===----------------------------------------------------------------------===//
5736
5737#define GET_ATTRDEF_CLASSES
5738#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.cpp.inc"
5739
5740#define GET_OP_CLASSES
5741#include "mlir/Dialect/Vector/IR/VectorOps.cpp.inc"

/build/source/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// These are the classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defining derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73    assert(handle.getTypeID() == TypeID::get<DerivedT>());
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 virtual void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overridden.
218 AsmPrinter() = default;
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a loc(...) specifier if printing debug info is enabled.
331 virtual void printOptionalLocationSpecifier(Location loc) = 0;
332
333 /// Print a newline and indent the printer to the start of the current
334 /// operation.
335 virtual void printNewline() = 0;
336
337 /// Increase indentation.
338 virtual void increaseIndent() = 0;
339
340 /// Decrease indentation.
341 virtual void decreaseIndent() = 0;
342
343 /// Print a block argument in the usual format of:
344 /// %ssaName : type {attr1=42} loc("here")
345 /// where location printing is controlled by the standard internal option.
346 /// You may pass omitType=true to not print a type, and pass an empty
347 /// attribute list if you don't care for attributes.
348 virtual void printRegionArgument(BlockArgument arg,
349 ArrayRef<NamedAttribute> argAttrs = {},
350 bool omitType = false) = 0;
351
352 /// Print implementations for various things an operation contains.
353 virtual void printOperand(Value value) = 0;
354 virtual void printOperand(Value value, raw_ostream &os) = 0;
355
356 /// Print a comma separated list of operands.
357 template <typename ContainerType>
358 void printOperands(const ContainerType &container) {
359 printOperands(container.begin(), container.end());
360 }
361
362 /// Print a comma separated list of operands.
363 template <typename IteratorType>
364 void printOperands(IteratorType it, IteratorType end) {
365 llvm::interleaveComma(llvm::make_range(it, end), getStream(),
366 [this](Value value) { printOperand(value); });
367 }
368
369 /// Print the given successor.
370 virtual void printSuccessor(Block *successor) = 0;
371
372 /// Print the successor and its operands.
373 virtual void printSuccessorAndUseList(Block *successor,
374 ValueRange succOperands) = 0;
375
376 /// If the specified operation has attributes, print out an attribute
377 /// dictionary with their values. elidedAttrs allows the client to ignore
378 /// specific well known attributes, commonly used if the attribute value is
379 /// printed some other way (like as a fixed operand).
380 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
381 ArrayRef<StringRef> elidedAttrs = {}) = 0;
382
383 /// If the specified operation has attributes, print out an attribute
384 /// dictionary prefixed with 'attributes'.
385 virtual void
386 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
387 ArrayRef<StringRef> elidedAttrs = {}) = 0;
388
389 /// Prints the entire operation with the custom assembly form, if available,
390 /// or the generic assembly form, otherwise.
391 virtual void printCustomOrGenericOp(Operation *op) = 0;
392
393 /// Print the entire operation with the default generic assembly form.
394 /// If `printOpName` is true, then the operation name is printed (the default)
395 /// otherwise it is omitted and the print will start with the operand list.
396 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
397
398 /// Prints a region.
399 /// If 'printEntryBlockArgs' is false, the arguments of the
400 /// block are not printed. If 'printBlockTerminator' is false, the terminator
401 /// operation of the block is not printed. If printEmptyBlock is true, then
402 /// the block header is printed even if the block is empty.
403 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
404 bool printBlockTerminators = true,
405 bool printEmptyBlock = false) = 0;
406
407 /// Renumber the arguments for the specified region to the same names as the
408 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
409 /// operations. If any entry in namesToUse is null, the corresponding
410 /// argument name is left alone.
411 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
412
413 /// Prints an affine map of SSA ids, where SSA id names are used in place
414 /// of dims/symbols.
415 /// Operand values must come from single-result sources, and be valid
416 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
417 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
418 ValueRange operands) = 0;
419
420 /// Prints an affine expression of SSA ids with SSA id names used instead of
421 /// dims and symbols.
422 /// Operand values must come from single-result sources, and be valid
423 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
424 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
425 ValueRange symOperands) = 0;
426
427 /// Print the complete type of an operation in functional form.
428 void printFunctionalType(Operation *op);
429 using AsmPrinter::printFunctionalType;
430};
431
432// Make the implementations convenient to use.
433inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
434 p.printOperand(value);
435 return p;
436}
437
438template <typename T,
439 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
440 !std::is_convertible<T &, Value &>::value,
441 T> * = nullptr>
442inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
443 p.printOperands(values);
444 return p;
445}
446
447inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
448 p.printSuccessor(value);
449 return p;
450}
451
452//===----------------------------------------------------------------------===//
453// AsmParser
454//===----------------------------------------------------------------------===//
455
456/// This base class exposes generic asm parser hooks, usable across the various
457/// derived parsers.
458class AsmParser {
459public:
460 AsmParser() = default;
461 virtual ~AsmParser();
462
463 MLIRContext *getContext() const;
464
465 /// Return the location of the original name token.
466 virtual SMLoc getNameLoc() const = 0;
467
468 //===--------------------------------------------------------------------===//
469 // Utilities
470 //===--------------------------------------------------------------------===//
471
472 /// Emit a diagnostic at the specified location and return failure.
473 virtual InFlightDiagnostic emitError(SMLoc loc,
474 const Twine &message = {}) = 0;
475
476 /// Return a builder which provides useful access to MLIRContext, global
477 /// objects like types and attributes.
478 virtual Builder &getBuilder() const = 0;
479
480 /// Get the location of the next token and store it into the argument. This
481 /// always succeeds.
482 virtual SMLoc getCurrentLocation() = 0;
483 ParseResult getCurrentLocation(SMLoc *loc) {
484 *loc = getCurrentLocation();
485 return success();
486 }
487
488 /// Re-encode the given source location as an MLIR location and return it.
489 /// Note: This method should only be used when a `Location` is necessary, as
490 /// the encoding process is not efficient.
491 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
492
493 //===--------------------------------------------------------------------===//
494 // Token Parsing
495 //===--------------------------------------------------------------------===//
496
497 /// Parse a '->' token.
498 virtual ParseResult parseArrow() = 0;
499
500 /// Parse a '->' token if present
501 virtual ParseResult parseOptionalArrow() = 0;
502
503 /// Parse a `{` token.
504 virtual ParseResult parseLBrace() = 0;
505
506 /// Parse a `{` token if present.
507 virtual ParseResult parseOptionalLBrace() = 0;
508
509 /// Parse a `}` token.
510 virtual ParseResult parseRBrace() = 0;
511
512 /// Parse a `}` token if present.
513 virtual ParseResult parseOptionalRBrace() = 0;
514
515 /// Parse a `:` token.
516 virtual ParseResult parseColon() = 0;
517
518 /// Parse a `:` token if present.
519 virtual ParseResult parseOptionalColon() = 0;
520
521 /// Parse a `,` token.
522 virtual ParseResult parseComma() = 0;
523
524 /// Parse a `,` token if present.
525 virtual ParseResult parseOptionalComma() = 0;
526
527 /// Parse a `=` token.
528 virtual ParseResult parseEqual() = 0;
529
530 /// Parse a `=` token if present.
531 virtual ParseResult parseOptionalEqual() = 0;
532
533 /// Parse a '<' token.
534 virtual ParseResult parseLess() = 0;
535
536 /// Parse a '<' token if present.
537 virtual ParseResult parseOptionalLess() = 0;
538
539 /// Parse a '>' token.
540 virtual ParseResult parseGreater() = 0;
541
542 /// Parse a '>' token if present.
543 virtual ParseResult parseOptionalGreater() = 0;
544
545 /// Parse a '?' token.
546 virtual ParseResult parseQuestion() = 0;
547
548 /// Parse a '?' token if present.
549 virtual ParseResult parseOptionalQuestion() = 0;
550
551 /// Parse a '+' token.
552 virtual ParseResult parsePlus() = 0;
553
554 /// Parse a '+' token if present.
555 virtual ParseResult parseOptionalPlus() = 0;
556
557 /// Parse a '*' token.
558 virtual ParseResult parseStar() = 0;
559
560 /// Parse a '*' token if present.
561 virtual ParseResult parseOptionalStar() = 0;
562
563 /// Parse a '|' token.
564 virtual ParseResult parseVerticalBar() = 0;
565
566 /// Parse a '|' token if present.
567 virtual ParseResult parseOptionalVerticalBar() = 0;
568
569 /// Parse a quoted string token.
570 ParseResult parseString(std::string *string) {
571 auto loc = getCurrentLocation();
572 if (parseOptionalString(string))
573 return emitError(loc, "expected string");
574 return success();
575 }
576
577 /// Parse a quoted string token if present.
578 virtual ParseResult parseOptionalString(std::string *string) = 0;
579
580 /// Parses a Base64 encoded string of bytes.
581 virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0;
582
583 /// Parse a `(` token.
584 virtual ParseResult parseLParen() = 0;
585
586 /// Parse a `(` token if present.
587 virtual ParseResult parseOptionalLParen() = 0;
588
589 /// Parse a `)` token.
590 virtual ParseResult parseRParen() = 0;
591
592 /// Parse a `)` token if present.
593 virtual ParseResult parseOptionalRParen() = 0;
594
595 /// Parse a `[` token.
596 virtual ParseResult parseLSquare() = 0;
597
598 /// Parse a `[` token if present.
599 virtual ParseResult parseOptionalLSquare() = 0;
600
601 /// Parse a `]` token.
602 virtual ParseResult parseRSquare() = 0;
603
604 /// Parse a `]` token if present.
605 virtual ParseResult parseOptionalRSquare() = 0;
606
607 /// Parse a `...` token.
608 virtual ParseResult parseEllipsis() = 0;
609
610 /// Parse a `...` token if present;
611 virtual ParseResult parseOptionalEllipsis() = 0;
612
613 /// Parse a floating point value from the stream.
614 virtual ParseResult parseFloat(double &result) = 0;
615
616 /// Parse an integer value from the stream.
617 template <typename IntT>
618 ParseResult parseInteger(IntT &result) {
619 auto loc = getCurrentLocation();
620 OptionalParseResult parseResult = parseOptionalInteger(result);
  [4] Calling 'AsmParser::parseOptionalInteger'
  [8] Returning from 'AsmParser::parseOptionalInteger'
621 if (!parseResult.has_value())
  [9] Taking true branch
622 return emitError(loc, "expected integer value");
  [10] Returning without writing to 'result'
623 return *parseResult;
624 }
625
626 /// Parse an optional integer value from the stream.
627 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
628
629 template <typename IntT>
630 OptionalParseResult parseOptionalInteger(IntT &result) {
631 auto loc = getCurrentLocation();
632
633 // Parse the unsigned variant.
634 APInt uintResult;
635 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
636 if (!parseResult.has_value() || failed(*parseResult))
  [5] Assuming the condition is true
  [6] Taking true branch
637 return parseResult;
  [7] Returning without writing to 'result'
638
639 // Try to convert to the provided integer type. sextOrTrunc is correct even
640 // for unsigned types because parseOptionalInteger ensures the sign bit is
641 // zero for non-negated integers.
642 result =
643        (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT).getLimitedValue();
644 if (APInt(uintResult.getBitWidth(), result) != uintResult)
645 return emitError(loc, "integer value too large");
646 return success();
647 }
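Because both overloads can return without storing through 'result' (the paths marked by events 7 and 10 above), a caller that wants a defined value on every path should initialize first; a minimal sketch for a custom op parser:

  int64_t size = 0;               // defined even if parsing fails
  if (parser.parseInteger(size))
    return failure();             // the parser already emitted a diagnostic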
648
649 /// These are the supported delimiters around operand lists and region
650 /// argument lists, used by parseOperandList.
651 enum class Delimiter {
652 /// Zero or more operands with no delimiters.
653 None,
654 /// Parens surrounding zero or more operands.
655 Paren,
656 /// Square brackets surrounding zero or more operands.
657 Square,
658 /// <> brackets surrounding zero or more operands.
659 LessGreater,
660 /// {} brackets surrounding zero or more operands.
661 Braces,
662 /// Parens supporting zero or more operands, or nothing.
663 OptionalParen,
664 /// Square brackets supporting zero or more ops, or nothing.
665 OptionalSquare,
666 /// <> brackets supporting zero or more ops, or nothing.
667 OptionalLessGreater,
668 /// {} brackets surrounding zero or more operands, or nothing.
669 OptionalBraces,
670 };
671
672 /// Parse a list of comma-separated items with an optional delimiter. If a
673 /// delimiter is provided, then an empty list is allowed. If not, then at
674 /// least one element will be parsed.
675 ///
676 /// contextMessage is an optional message appended to "expected '('" sorts of
677 /// diagnostics when parsing the delimiters.
678 virtual ParseResult
679 parseCommaSeparatedList(Delimiter delimiter,
680 function_ref<ParseResult()> parseElementFn,
681 StringRef contextMessage = StringRef()) = 0;
682
683 /// Parse a comma separated list of elements that must have at least one entry
684 /// in it.
685 ParseResult
686 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
687 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
688 }
689
690 //===--------------------------------------------------------------------===//
691 // Keyword Parsing
692 //===--------------------------------------------------------------------===//
693
694 /// This class represents a StringSwitch like class that is useful for parsing
695 /// expected keywords. On construction, it invokes `parseKeyword` and
696 /// processes each of the provided case statements until a match is hit. The
697 /// provided `ResultT` must be assignable from `failure()`.
698 template <typename ResultT = ParseResult>
699 class KeywordSwitch {
700 public:
701 KeywordSwitch(AsmParser &parser)
702 : parser(parser), loc(parser.getCurrentLocation()) {
703 if (failed(parser.parseKeywordOrCompletion(&keyword)))
704 result = failure();
705 }
706
707 /// Case that uses the provided value when true.
708 KeywordSwitch &Case(StringLiteral str, ResultT value) {
709 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
710 }
711 KeywordSwitch &Default(ResultT value) {
712 return Default([&](StringRef, SMLoc) { return std::move(value); });
713 }
714 /// Case that invokes the provided functor when true. The parameters passed
715 /// to the functor are the keyword, and the location of the keyword (in case
716 /// any errors need to be emitted).
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Case(StringLiteral str, FnT &&fn) {
720 if (result)
721 return *this;
722
723 // If the word was empty, record this as a completion.
724 if (keyword.empty())
725 parser.codeCompleteExpectedTokens(str);
726 else if (keyword == str)
727 result.emplace(std::move(fn(keyword, loc)));
728 return *this;
729 }
730 template <typename FnT>
731 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
732 Default(FnT &&fn) {
733 if (!result)
734 result.emplace(fn(keyword, loc));
735 return *this;
736 }
737
738 /// Returns true if this switch has a value yet.
739 bool hasValue() const { return result.has_value(); }
740
741 /// Return the result of the switch.
742 [[nodiscard]] operator ResultT() {
743 if (!result)
744 return parser.emitError(loc, "unexpected keyword: ") << keyword;
745 return std::move(*result);
746 }
747
748 private:
749 /// The parser used to construct this switch.
750 AsmParser &parser;
751
752 /// The location of the keyword, used to emit errors as necessary.
753 SMLoc loc;
754
755 /// The parsed keyword itself.
756 StringRef keyword;
757
758 /// The result of the switch statement or none if currently unknown.
759 Optional<ResultT> result;
760 };
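A hypothetical use inside a custom parser (names assumed; requires FailureOr from mlir/Support/LogicalResult.h and an AsmParser &parser in scope), selecting an enum case from a keyword:

  enum class Kind { Add, Mul };
  FailureOr<Kind> kind = AsmParser::KeywordSwitch<FailureOr<Kind>>(parser)
                             .Case("add", Kind::Add)
                             .Case("mul", Kind::Mul);
  if (failed(kind))               // unmatched keywords emit
    return failure();             // "unexpected keyword: ..."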
761
762 /// Parse a given keyword.
763 ParseResult parseKeyword(StringRef keyword) {
764 return parseKeyword(keyword, "");
765 }
766 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
767
768 /// Parse a keyword into 'keyword'.
769 ParseResult parseKeyword(StringRef *keyword) {
770 auto loc = getCurrentLocation();
771 if (parseOptionalKeyword(keyword))
772 return emitError(loc, "expected valid keyword");
773 return success();
774 }
775
776 /// Parse the given keyword if present.
777 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
778
779 /// Parse a keyword, if present, into 'keyword'.
780 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
781
782 /// Parse a keyword, if present, and if one of the 'allowedValues',
783 /// into 'keyword'
784 virtual ParseResult
785 parseOptionalKeyword(StringRef *keyword,
786 ArrayRef<StringRef> allowedValues) = 0;
787
788 /// Parse a keyword or a quoted string.
789 ParseResult parseKeywordOrString(std::string *result) {
790 if (failed(parseOptionalKeywordOrString(result)))
791 return emitError(getCurrentLocation())
792 << "expected valid keyword or string";
793 return success();
794 }
795
796 /// Parse an optional keyword or string.
797 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
798
799 //===--------------------------------------------------------------------===//
800 // Attribute/Type Parsing
801 //===--------------------------------------------------------------------===//
802
803 /// Invoke the `getChecked` method of the given Attribute or Type class, using
804 /// the provided location to emit errors in the case of failure. Note that
805 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
806 /// context parameter.
807 template <typename T, typename... ParamsT>
808 auto getChecked(SMLoc loc, ParamsT &&...params) {
809 return T::getChecked([&] { return emitError(loc); },
810 std::forward<ParamsT>(params)...);
811 }
812 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
813 /// errors.
814 template <typename T, typename... ParamsT>
815 auto getChecked(ParamsT &&...params) {
816 return T::getChecked([&] { return emitError(getNameLoc()); },
817 std::forward<ParamsT>(params)...);
818 }
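  // Editorial sketch: route `getChecked` diagnostics through the parser when
  // materializing a verified type during parsing. `MyIntType` and `width`
  // are hypothetical and stand in for any type with a `getChecked` factory.
  //
  //   MyIntType ty = getChecked<MyIntType>(loc, getContext(), width);
  //   if (!ty)
  //     return failure();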
819
820 //===--------------------------------------------------------------------===//
821 // Attribute Parsing
822 //===--------------------------------------------------------------------===//
823
824 /// Parse an arbitrary attribute of a given type and return it in result.
825 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
826
827 /// Parse a custom attribute with the provided callback, unless the next
828 /// token is `#`, in which case the generic parser is invoked.
829 virtual ParseResult parseCustomAttributeWithFallback(
830 Attribute &result, Type type,
831 function_ref<ParseResult(Attribute &result, Type type)>
832 parseAttribute) = 0;
833
834 /// Parse an attribute of a specific kind and type.
835 template <typename AttrType>
836 ParseResult parseAttribute(AttrType &result, Type type = {}) {
837 SMLoc loc = getCurrentLocation();
838
839 // Parse any kind of attribute.
840 Attribute attr;
841 if (parseAttribute(attr, type))
842 return failure();
843
844 // Check for the right kind of attribute.
845 if (!(result = attr.dyn_cast<AttrType>()))
846 return emitError(loc, "invalid kind of attribute specified");
847
848 return success();
849 }
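  // Editorial sketch: the typed overload both parses and checks the kind, so
  // callers get the "invalid kind of attribute specified" error for free.
  // `parser` is an AsmParser assumed to be in scope.
  //
  //   IntegerAttr alignment;
  //   if (parser.parseAttribute(alignment))
  //     return failure();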
850
851 /// Parse an arbitrary attribute and return it in result. This also adds the
852 /// attribute to the specified attribute list with the specified name.
853 ParseResult parseAttribute(Attribute &result, StringRef attrName,
854 NamedAttrList &attrs) {
855 return parseAttribute(result, Type(), attrName, attrs);
856 }
857
858 /// Parse an attribute of a specific kind and type.
859 template <typename AttrType>
860 ParseResult parseAttribute(AttrType &result, StringRef attrName,
861 NamedAttrList &attrs) {
862 return parseAttribute(result, Type(), attrName, attrs);
863 }
864
865 /// Parse an arbitrary attribute of a given type and populate it in `result`.
866 /// This also adds the attribute to the specified attribute list with the
867 /// specified name.
868 template <typename AttrType>
869 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
870 NamedAttrList &attrs) {
871 SMLoc loc = getCurrentLocation();
872
873 // Parse any kind of attribute.
874 Attribute attr;
875 if (parseAttribute(attr, type))
876 return failure();
877
878 // Check for the right kind of attribute.
879 result = attr.dyn_cast<AttrType>();
880 if (!result)
881 return emitError(loc, "invalid kind of attribute specified");
882
883 attrs.append(attrName, result);
884 return success();
885 }
886
887 /// Trait to check if `AttrType` provides a `parse` method.
888 template <typename AttrType>
889 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
890 std::declval<Type>()));
891 template <typename AttrType>
892 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
893
894 /// Parse a custom attribute of a given type unless the next token is `#`, in
895 /// which case the generic parser is invoked. The parsed attribute is
896 /// populated in `result` and also added to the specified attribute list with
897 /// the specified name.
898 template <typename AttrType>
899 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
900 parseCustomAttributeWithFallback(AttrType &result, Type type,
901 StringRef attrName, NamedAttrList &attrs) {
902 SMLoc loc = getCurrentLocation();
903
904 // Parse any kind of attribute.
905 Attribute attr;
906 if (parseCustomAttributeWithFallback(
907 attr, type, [&](Attribute &result, Type type) -> ParseResult {
908 result = AttrType::parse(*this, type);
909 if (!result)
910 return failure();
911 return success();
912 }))
913 return failure();
914
915 // Check for the right kind of attribute.
916 result = attr.dyn_cast<AttrType>();
917 if (!result)
918 return emitError(loc, "invalid kind of attribute specified");
919
920 attrs.append(attrName, result);
921 return success();
922 }
923
924 /// SFINAE parsing method for Attribute types that don't implement a parse method.
925 template <typename AttrType>
926 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
927 parseCustomAttributeWithFallback(AttrType &result, Type type,
928 StringRef attrName, NamedAttrList &attrs) {
929 return parseAttribute(result, type, attrName, attrs);
930 }
931
932 /// Parse a custom attribute of a given type unless the next token is `#`, in
933 /// which case the generic parser is invoked. The parsed attribute is
934 /// populated in `result`.
935 template <typename AttrType>
936 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
937 parseCustomAttributeWithFallback(AttrType &result) {
938 SMLoc loc = getCurrentLocation();
939
940 // Parse any kind of attribute.
941 Attribute attr;
942 if (parseCustomAttributeWithFallback(
943 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
944 result = AttrType::parse(*this, type);
945 return success(!!result);
946 }))
947 return failure();
948
949 // Check for the right kind of attribute.
950 result = attr.dyn_cast<AttrType>();
951 if (!result)
952 return emitError(loc, "invalid kind of attribute specified");
953 return success();
954 }
955
956 /// SFINAE parsing method for Attribute types that don't implement a parse method.
957 template <typename AttrType>
958 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
959 parseCustomAttributeWithFallback(AttrType &result) {
960 return parseAttribute(result);
961 }
962
963 /// Parse an arbitrary optional attribute of a given type and return it in
964 /// result.
965 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
966 Type type = {}) = 0;
967
968 /// Parse an optional array attribute and return it in result.
969 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
970 Type type = {}) = 0;
971
972 /// Parse an optional string attribute and return it in result.
973 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
974 Type type = {}) = 0;
975
976 /// Parse an optional symbol ref attribute and return it in result.
977 virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
978 Type type = {}) = 0;
979
980 /// Parse an optional attribute of a specific type and add it to the list with
981 /// the specified name.
982 template <typename AttrType>
983 OptionalParseResult parseOptionalAttribute(AttrType &result,
984 StringRef attrName,
985 NamedAttrList &attrs) {
986 return parseOptionalAttribute(result, Type(), attrName, attrs);
987 }
988
989 /// Parse an optional attribute of a specific type and add it to the list with
990 /// the specified name.
991 template <typename AttrType>
992 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
993 StringRef attrName,
994 NamedAttrList &attrs) {
995 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
996 if (parseResult.has_value() && succeeded(*parseResult))
997 attrs.append(attrName, result);
998 return parseResult;
999 }
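  // Editorial sketch of the three-way OptionalParseResult protocol: the
  // attribute may be absent, present and parsed, or present but malformed.
  //
  //   StringAttr name;
  //   OptionalParseResult res = parser.parseOptionalAttribute(name);
  //   if (res.has_value()) {
  //     if (failed(*res))
  //       return failure(); // present but malformed
  //     // `name` now holds the parsed attribute.
  //   }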
1000
1001 /// Parse a named dictionary into 'result' if it is present.
1002 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
1003
1004 /// Parse a named dictionary into 'result' if the `attributes` keyword is
1005 /// present.
1006 virtual ParseResult
1007 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
1008
1009 /// Parse an affine map instance into 'map'.
1010 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
1011
1012 /// Parse an integer set instance into 'set'.
1013 virtual ParseResult parseIntegerSet(IntegerSet &set) = 0;
1014
1015 //===--------------------------------------------------------------------===//
1016 // Identifier Parsing
1017 //===--------------------------------------------------------------------===//
1018
1019 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1020 /// attribute.
1021 ParseResult parseSymbolName(StringAttr &result) {
1022 if (failed(parseOptionalSymbolName(result)))
1023 return emitError(getCurrentLocation())
1024 << "expected valid '@'-identifier for symbol name";
1025 return success();
1026 }
1027
1028 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1029 /// attribute named 'attrName'.
1030 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1031 NamedAttrList &attrs) {
1032 if (parseSymbolName(result))
1033 return failure();
1034 attrs.append(attrName, result);
1035 return success();
1036 }
1037
1038 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1039 /// string attribute.
1040 virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0;
1041
1042 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1043 /// string attribute named 'attrName'.
1044 ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
1045 NamedAttrList &attrs) {
1046 if (succeeded(parseOptionalSymbolName(result))) {
1047 attrs.append(attrName, result);
1048 return success();
1049 }
1050 return failure();
1051 }
1052
1053 //===--------------------------------------------------------------------===//
1054 // Resource Parsing
1055 //===--------------------------------------------------------------------===//
1056
1057 /// Parse a handle to a resource within the assembly format.
1058 template <typename ResourceT>
1059 FailureOr<ResourceT> parseResourceHandle() {
1060 SMLoc handleLoc = getCurrentLocation();
1061
1062 // Try to load the dialect that owns the handle.
1063 auto *dialect =
1064 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1065 if (!dialect) {
1066 return emitError(handleLoc)
1067 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1068 << "' is unknown";
1069 }
1070
1071 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1072 if (failed(handle))
1073 return failure();
1074 if (auto *result = dyn_cast<ResourceT>(&*handle))
1075 return std::move(*result);
1076 return emitError(handleLoc) << "provided resource handle differs from the "
1077 "expected resource type";
1078 }
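  // Editorial sketch; `BlobHandle` is a hypothetical dialect-defined
  // resource handle type (it must expose a `Dialect` typedef, per above).
  //
  //   FailureOr<BlobHandle> handle = parseResourceHandle<BlobHandle>();
  //   if (failed(handle))
  //     return failure();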
1079
1080 //===--------------------------------------------------------------------===//
1081 // Type Parsing
1082 //===--------------------------------------------------------------------===//
1083
1084 /// Parse a type.
1085 virtual ParseResult parseType(Type &result) = 0;
1086
1087 /// Parse a custom type with the provided callback, unless the next
1088 /// token is `!`, in which case the generic parser is invoked.
1089 virtual ParseResult parseCustomTypeWithFallback(
1090 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1091
1092 /// Parse an optional type.
1093 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1094
1095 /// Parse a type of a specific kind.
1096 template <typename TypeT>
1097 ParseResult parseType(TypeT &result) {
1098 SMLoc loc = getCurrentLocation();
1099
1100 // Parse any kind of type.
1101 Type type;
1102 if (parseType(type))
1103 return failure();
1104
1105 // Check for the right kind of type.
1106 result = type.dyn_cast<TypeT>();
1107 if (!result)
1108 return emitError(loc, "invalid kind of type specified");
1109
1110 return success();
1111 }
1112
1113 /// Trait to check if `TypeT` provides a `parse` method.
1114 template <typename TypeT>
1115 using type_has_parse_method =
1116 decltype(TypeT::parse(std::declval<AsmParser &>()));
1117 template <typename TypeT>
1118 using detect_type_has_parse_method =
1119 llvm::is_detected<type_has_parse_method, TypeT>;
1120
1121 /// Parse a custom type of a given kind unless the next token is `!`, in
1122 /// which case the generic parser is invoked. The parsed type is
1123 /// populated in `result`.
1124 template <typename TypeT>
1125 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1126 parseCustomTypeWithFallback(TypeT &result) {
1127 SMLoc loc = getCurrentLocation();
1128
1129 // Parse any kind of type.
1130 Type type;
1131 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1132 result = TypeT::parse(*this);
1133 return success(!!result);
1134 }))
1135 return failure();
1136
1137 // Check for the right kind of type.
1138 result = type.dyn_cast<TypeT>();
1139 if (!result)
1140 return emitError(loc, "invalid kind of type specified");
1141 return success();
1142 }
1143
1144 /// SFINAE parsing method for types that don't implement a parse method.
1145 template <typename TypeT>
1146 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1147 parseCustomTypeWithFallback(TypeT &result) {
1148 return parseType(result);
1149 }
1150
1151 /// Parse a type list.
1152 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1153 return parseCommaSeparatedList(
1154 [&]() { return parseType(result.emplace_back()); });
1155 }
1156
1157 /// Parse an arrow followed by a type list.
1158 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1159
1160 /// Parse an optional arrow followed by a type list.
1161 virtual ParseResult
1162 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1163
1164 /// Parse a colon followed by a type.
1165 virtual ParseResult parseColonType(Type &result) = 0;
1166
1167 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1168 template <typename TypeType>
1169 ParseResult parseColonType(TypeType &result) {
1170 SMLoc loc = getCurrentLocation();
1171
1172 // Parse any kind of type.
1173 Type type;
1174 if (parseColonType(type))
1175 return failure();
1176
1177 // Check for the right kind of type.
1178 result = type.dyn_cast<TypeType>();
1179 if (!result)
1180 return emitError(loc, "invalid kind of type specified");
1181
1182 return success();
1183 }
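  // Editorial sketch: parse a trailing `: memref<...>` and require a
  // MemRefType in one call. `parser` is an AsmParser assumed to be in scope.
  //
  //   MemRefType memrefType;
  //   if (parser.parseColonType(memrefType))
  //     return failure();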
1184
1185 /// Parse a colon followed by a type list, which must have at least one type.
1186 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1187
1188 /// Parse an optional colon followed by a type list, which if present must
1189 /// have at least one type.
1190 virtual ParseResult
1191 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1192
1193 /// Parse a keyword followed by a type.
1194 ParseResult parseKeywordType(const char *keyword, Type &result) {
1195 return failure(parseKeyword(keyword) || parseType(result));
1196 }
1197
1198 /// Add the specified type to the end of the specified type list and return
1199 /// success. This is a helper designed to allow parse methods to be simple
1200 /// and chain through || operators.
1201 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1202 result.push_back(type);
1203 return success();
1204 }
1205
1206 /// Add the specified types to the end of the specified type list and return
1207 /// success. This is a helper designed to allow parse methods to be simple
1208 /// and chain through || operators.
1209 ParseResult addTypesToList(ArrayRef<Type> types,
1210 SmallVectorImpl<Type> &result) {
1211 result.append(types.begin(), types.end());
1212 return success();
1213 }
1214
1215 /// Parse a dimension list of a tensor or memref type. This populates the
1216 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1217 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1218 ///
1219 /// dimension-list ::= eps | dimension (`x` dimension)*
1220 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1221 /// dimension ::= `?` | decimal-literal
1222 ///
1223 /// When `allowDynamic` is not set, this is used to parse:
1224 ///
1225 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1226 /// static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
1227 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1228 bool allowDynamic = true,
1229 bool withTrailingX = true) = 0;
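  // Editorial sketch: with the defaults (`allowDynamic` and `withTrailingX`
  // both true), parsing the prefix of "4x?x8xf32" yields {4, -1, 8} and
  // leaves "f32" as the next token.
  //
  //   SmallVector<int64_t> dims;
  //   if (parser.parseDimensionList(dims))
  //     return failure();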
1230
1231 /// Parse an 'x' token in a dimension list, handling the case where the x is
1232 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1233 /// next token.
1234 virtual ParseResult parseXInDimensionList() = 0;
1235
1236protected:
1237 /// Parse a handle to a resource within the assembly format for the given
1238 /// dialect.
1239 virtual FailureOr<AsmDialectResourceHandle>
1240 parseResourceHandle(Dialect *dialect) = 0;
1241
1242 //===--------------------------------------------------------------------===//
1243 // Code Completion
1244 //===--------------------------------------------------------------------===//
1245
1246 /// Parse a keyword, or an empty string if the current location signals a code
1247 /// completion.
1248 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1249
1250 /// Signal the code completion of a set of expected tokens.
1251 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1252
1253private:
1254 AsmParser(const AsmParser &) = delete;
1255 void operator=(const AsmParser &) = delete;
1256};
1257
1258//===----------------------------------------------------------------------===//
1259// OpAsmParser
1260//===----------------------------------------------------------------------===//
1261
1262/// The OpAsmParser has methods for interacting with the asm parser: parsing
1263/// things from it, emitting errors, etc. It has an intentionally high-level API
1264/// that is designed to reduce/constrain syntax innovation in individual
1265/// operations.
1266///
1267/// For example, consider an op like this:
1268///
1269/// %x = load %p[%1, %2] : memref<...>
1270///
1271/// The "%x = load" tokens are already parsed and therefore invisible to the
1272/// custom op parser. This can be supported by calling `parseOperandList` to
1273/// parse the %p, then calling `parseOperandList` with `Delimiter::Square` to
1274/// parse the indices, then calling `parseColonTypeList` to parse the result
1275/// type.
1276///
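/// An editorial sketch of such a parse method (the variable names and the
/// final result type are assumptions, not part of this header):
///
///   OpAsmParser::UnresolvedOperand memref;
///   SmallVector<OpAsmParser::UnresolvedOperand> indices;
///   MemRefType type;
///   Type indexType = parser.getBuilder().getIndexType();
///   if (parser.parseOperand(memref) ||
///       parser.parseOperandList(indices, OpAsmParser::Delimiter::Square) ||
///       parser.parseColonType(type) ||
///       parser.resolveOperand(memref, type, result.operands) ||
///       parser.resolveOperands(indices, indexType, result.operands))
///     return failure();
///   result.addTypes(type.getElementType());
///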
1277class OpAsmParser : public AsmParser {
1278public:
1279 using AsmParser::AsmParser;
1280 ~OpAsmParser() override;
1281
1282 /// Parse a loc(...) specifier if present, filling in result if so.
1283 /// Location for BlockArgument and Operation may be deferred with an alias, in
1284 /// which case an OpaqueLoc is set and will be resolved when parsing
1285 /// completes.
1286 virtual ParseResult
1287 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1288
1289 /// Return the name of the specified result in the specified syntax, as well
1290 /// as the sub-element in the name. It returns an empty string and ~0U for
1291 /// invalid result numbers. For example, in this operation:
1292 ///
1293 /// %x, %y:2, %z = foo.op
1294 ///
1295 /// getResultName(0) == {"x", 0 }
1296 /// getResultName(1) == {"y", 0 }
1297 /// getResultName(2) == {"y", 1 }
1298 /// getResultName(3) == {"z", 0 }
1299 /// getResultName(4) == {"", ~0U }
1300 virtual std::pair<StringRef, unsigned>
1301 getResultName(unsigned resultNo) const = 0;
1302
1303 /// Return the number of declared SSA results. This returns 4 for the foo.op
1304 /// example in the comment for `getResultName`.
1305 virtual size_t getNumResults() const = 0;
1306
1307 // These methods emit an error and return failure or success. This allows
1308 // these to be chained together into a linear sequence of || expressions in
1309 // many cases.
1310
1311 /// Parse an operation in its generic form.
1312 /// The parsed operation is parsed in the current context and inserted in the
1313 /// provided block and insertion point. The results produced by this operation
1314 /// aren't mapped to any named value in the parser. Returns nullptr on
1315 /// failure.
1316 virtual Operation *parseGenericOperation(Block *insertBlock,
1317 Block::iterator insertPt) = 0;
1318
1319 /// Parse the name of an operation, in the custom form. On success, return
1320 /// an object of type 'OperationName'. Otherwise, failure is returned.
1321 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1322
1323 //===--------------------------------------------------------------------===//
1324 // Operand Parsing
1325 //===--------------------------------------------------------------------===//
1326
1327 /// This is the representation of an operand reference.
1328 struct UnresolvedOperand {
1329 SMLoc location; // Location of the token.
1330 StringRef name; // Value name, e.g. %42 or %abc
1331 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1332 };
1333
1334 /// Parse different components, viz., use-info of operand(s), successor(s),
1335 /// region(s), attribute(s) and function-type, of the generic form of an
1336 /// operation instance and populate the input operation-state 'result' with
1337 /// those components. If any of the components is explicitly provided, then
1338 /// skip parsing that component.
1339 virtual ParseResult parseGenericOperationAfterOpName(
1340 OperationState &result,
1341 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt,
1342 Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt,
1343 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1344 std::nullopt,
1345 Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
1346 Optional<FunctionType> parsedFnType = std::nullopt) = 0;
1347
1348 /// Parse a single SSA value operand name along with a result number if
1349 /// `allowResultNumber` is true.
1350 virtual ParseResult parseOperand(UnresolvedOperand &result,
1351 bool allowResultNumber = true) = 0;
1352
1353 /// Parse a single operand if present.
1354 virtual OptionalParseResult
1355 parseOptionalOperand(UnresolvedOperand &result,
1356 bool allowResultNumber = true) = 0;
1357
1358 /// Parse zero or more SSA comma-separated operand references with a specified
1359 /// surrounding delimiter, and an optional required operand count.
1360 virtual ParseResult
1361 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1362 Delimiter delimiter = Delimiter::None,
1363 bool allowResultNumber = true,
1364 int requiredOperandCount = -1) = 0;
1365
1366 /// Parse a specified number of comma-separated operands.
1367 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1368 int requiredOperandCount,
1369 Delimiter delimiter = Delimiter::None) {
1370 return parseOperandList(result, delimiter,
1371 /*allowResultNumber=*/true, requiredOperandCount);
1372 }
1373
1374 /// Parse zero or more trailing SSA comma-separated operand
1375 /// references with a specified surrounding delimiter, and an optional
1376 /// required operand count. A leading comma is expected before the
1377 /// operands.
1378 ParseResult
1379 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1380 Delimiter delimiter = Delimiter::None) {
1381 if (failed(parseOptionalComma()))
1382 return success(); // The comma is optional.
1383 return parseOperandList(result, delimiter);
1384 }
1385
1386 /// Resolve an operand to an SSA value, emitting an error on failure.
1387 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1388 Type type,
1389 SmallVectorImpl<Value> &result) = 0;
1390
1391 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1392 /// appending the results to the list on success. This method should be used
1393 /// when all operands have the same type.
1394 template <typename Operands = ArrayRef<UnresolvedOperand>>
1395 ParseResult resolveOperands(Operands &&operands, Type type,
1396 SmallVectorImpl<Value> &result) {
1397 for (const UnresolvedOperand &operand : operands)
1398 if (resolveOperand(operand, type, result))
1399 return failure();
1400 return success();
1401 }
1402 template <typename Operands = ArrayRef<UnresolvedOperand>>
1403 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1404 SmallVectorImpl<Value> &result) {
1405 return resolveOperands(std::forward<Operands>(operands), type, result);
1406 }
1407
1408 /// Resolve a list of operands and a list of operand types to SSA values,
1409 /// emitting an error and returning failure, or appending the results
1410 /// to the list on success.
1411 template <typename Operands = ArrayRef<UnresolvedOperand>,
1412 typename Types = ArrayRef<Type>>
1413 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1414 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1415 SmallVectorImpl<Value> &result) {
1416 size_t operandSize = std::distance(operands.begin(), operands.end());
1417 size_t typeSize = std::distance(types.begin(), types.end());
1418 if (operandSize != typeSize)
1419 return emitError(loc)
1420 << operandSize << " operands present, but expected " << typeSize;
1421
1422 for (auto [operand, type] : llvm::zip(operands, types))
1423 if (resolveOperand(operand, type, result))
1424 return failure();
1425 return success();
1426 }
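  // Editorial sketch: the two-list overload verifies that the operand and
  // type counts agree before resolving, reporting any mismatch at `loc`.
  // `unresolved`, `types`, and `typesLoc` are assumed to have been produced
  // by earlier parse calls.
  //
  //   if (parser.resolveOperands(unresolved, types, typesLoc,
  //                              result.operands))
  //     return failure(); // e.g. "3 operands present, but expected 2"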
1427
1428 /// Parses an affine map attribute where dims and symbols are SSA operands.
1429 /// Operand values must come from single-result sources, and be valid
1430 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1431 virtual ParseResult
1432 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1433 Attribute &map, StringRef attrName,
1434 NamedAttrList &attrs,
1435 Delimiter delimiter = Delimiter::Square) = 0;
1436
1437 /// Parses an affine expression where dims and symbols are SSA operands.
1438 /// Operand values must come from single-result sources, and be valid
1439 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1440 virtual ParseResult
1441 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1442 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1443 AffineExpr &expr) = 0;
1444
1445 //===--------------------------------------------------------------------===//
1446 // Argument Parsing
1447 //===--------------------------------------------------------------------===//
1448
1449 struct Argument {
1450 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1451 Type type; // Type.
1452 DictionaryAttr attrs; // Attributes if present.
1453 Optional<Location> sourceLoc; // Source location specifier if present.
1454 };
1455
1456 /// Parse a single argument with the following syntax:
1457 ///
1458 /// `%ssaName : !type { optionalAttrDict } loc(optionalSourceLoc)`
1459 ///
1460 /// If `allowType` or `allowAttrs` is false, then the respective
1461 /// parts of the grammar are not parsed.
1462 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1463 bool allowAttrs = false) = 0;
1464
1465 /// Parse a single argument if present.
1466 virtual OptionalParseResult
1467 parseOptionalArgument(Argument &result, bool allowType = false,
1468 bool allowAttrs = false) = 0;
1469
1470 /// Parse zero or more arguments with a specified surrounding delimiter.
1471 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1472 Delimiter delimiter = Delimiter::None,
1473 bool allowType = false,
1474 bool allowAttrs = false) = 0;
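  // Editorial sketch: parse `(%arg0: i32, %arg1: f32)` style entry-block
  // arguments, types included.
  //
  //   SmallVector<OpAsmParser::Argument> args;
  //   if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
  //                                /*allowType=*/true))
  //     return failure();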
1475
1476 //===--------------------------------------------------------------------===//
1477 // Region Parsing
1478 //===--------------------------------------------------------------------===//
1479
1480 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1481 /// moved to the op regions after the op is created. The first block of the
1482 /// region takes 'arguments'.
1483 ///
1484 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1485 /// shadow the names of other existing SSA values defined above the region
1486 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1487 /// to operations that are 'IsolatedFromAbove'.
1488 virtual ParseResult parseRegion(Region &region,
1489 ArrayRef<Argument> arguments = {},
1490 bool enableNameShadowing = false) = 0;
1491
1492 /// Parses a region if present.
1493 virtual OptionalParseResult
1494 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1495 bool enableNameShadowing = false) = 0;
1496
1497 /// Parses a region if present. If the region is present, a new region is
1498 /// allocated and placed in `region`. If no region is present or on failure,
1499 /// `region` remains untouched.
1500 virtual OptionalParseResult
1501 parseOptionalRegion(std::unique_ptr<Region> &region,
1502 ArrayRef<Argument> arguments = {},
1503 bool enableNameShadowing = false) = 0;
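  // Editorial sketch: parse a region whose entry block takes previously
  // parsed arguments, and attach it to the op under construction. `args`
  // and the OperationState `result` are assumptions from surrounding code.
  //
  //   Region *body = result.addRegion();
  //   if (parser.parseRegion(*body, args))
  //     return failure();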
1504
1505 //===--------------------------------------------------------------------===//
1506 // Successor Parsing
1507 //===--------------------------------------------------------------------===//
1508
1509 /// Parse a single operation successor.
1510 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1511
1512 /// Parse an optional operation successor.
1513 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1514
1515 /// Parse a single operation successor and its operand list.
1516 virtual ParseResult
1517 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1518
1519 //===--------------------------------------------------------------------===//
1520 // Assignment List Parsing
1521 //===--------------------------------------------------------------------===//
1522
1523 /// Parse a list of assignments of the form
1524 /// (%x1 = %y1, %x2 = %y2, ...)
1525 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1526 SmallVectorImpl<UnresolvedOperand> &rhs) {
1527 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1528 if (!result.has_value())
1529 return emitError(getCurrentLocation(), "expected '('");
1530 return result.value();
1531 }
1532
1533 virtual OptionalParseResult
1534 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1535 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1536};
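// Editorial sketch: parseAssignmentList handles loop-style assignments such
// as `(%iv = %init)`, where each lhs becomes a region argument and each rhs
// an operand (the common iter_args-like pattern).
//
//   SmallVector<OpAsmParser::Argument> regionArgs;
//   SmallVector<OpAsmParser::UnresolvedOperand> initOperands;
//   if (parser.parseAssignmentList(regionArgs, initOperands))
//     return failure();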
1537
1538//===--------------------------------------------------------------------===//
1539// Dialect OpAsm interface.
1540//===--------------------------------------------------------------------===//
1541
1542/// A functor used to set the name of the start of a result group of an
1543/// operation. See 'getAsmResultNames' below for more details.
1544using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1545
1546/// A functor used to set the name of blocks in regions directly nested under
1547/// an operation.
1548using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1549
1550class OpAsmDialectInterface
1551 : public DialectInterface::Base<OpAsmDialectInterface> {
1552public:
1553 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1554
1555 //===------------------------------------------------------------------===//
1556 // Aliases
1557 //===------------------------------------------------------------------===//
1558
1559 /// Holds the result of the `getAlias` hook call.
1560 enum class AliasResult {
1561 /// The object (type or attribute) is not supported by the hook
1562 /// and an alias was not provided.
1563 NoAlias,
1564 /// An alias was provided, but it might be overridden by another hook.
1565 OverridableAlias,
1566 /// An alias was provided and it should be used
1567 /// (no other hooks will be checked).
1568 FinalAlias
1569 };
1570
1571 /// Hooks for getting an identifier alias for a given symbol, which is
1572 /// not necessarily a part of this dialect. The identifier is used in place of
1573 /// the symbol when printing textual IR. These aliases must not contain `.` or
1574 /// end with a numeric digit ([0-9]+).
1575 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1576 return AliasResult::NoAlias;
1577 }
1578 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1579 return AliasResult::NoAlias;
1580 }
1581
1582 //===--------------------------------------------------------------------===//
1583 // Resources
1584 //===--------------------------------------------------------------------===//
1585
1586 /// Declare a resource with the given key, returning a handle to use for any
1587 /// references of this resource key within the IR during parsing. The result
1588 /// of `getResourceKey` on the returned handle is permitted to be different
1589 /// from `key`.
1590 virtual FailureOr<AsmDialectResourceHandle>
1591 declareResource(StringRef key) const {
1592 return failure();
1593 }
1594
1595 /// Return a key to use for the given resource. This key should uniquely
1596 /// identify this resource within the dialect.
1597 virtual std::string
1598 getResourceKey(const AsmDialectResourceHandle &handle) const {
1599 llvm_unreachable(
1600 "Dialect must implement `getResourceKey` when defining resources");
1601 }
1602
1603 /// Hook for parsing resource entries. Returns failure if the entry was not
1604 /// valid, or could otherwise not be processed correctly. Any necessary errors
1605 /// can be emitted via the provided entry.
1606 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1607
1608 /// Hook for building resources to use during printing. The given `op` may be
1609 /// inspected to help determine what information to include.
1610 /// `referencedResources` contains all of the resources detected when printing
1611 /// 'op'.
1612 virtual void
1613 buildResources(Operation *op,
1614 const SetVector<AsmDialectResourceHandle> &referencedResources,
1615 AsmResourceBuilder &builder) const {}
1616};
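// Editorial sketch of an alias hook override; `MyDialectInterface` and
// `MyAttr` are hypothetical names used for illustration only.
//
//   struct MyDialectInterface : public OpAsmDialectInterface {
//     using OpAsmDialectInterface::OpAsmDialectInterface;
//     AliasResult getAlias(Attribute attr, raw_ostream &os) const override {
//       if (attr.isa<MyAttr>()) {
//         os << "my_alias"; // printed as #my_alias in the IR
//         return AliasResult::FinalAlias;
//       }
//       return AliasResult::NoAlias;
//     }
//   };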
1617} // namespace mlir
1618
1619//===--------------------------------------------------------------------===//
1620// Operation OpAsm interface.
1621//===--------------------------------------------------------------------===//
1622
1623/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1624#include "mlir/IR/OpAsmInterface.h.inc"
1625
1626namespace llvm {
1627template <>
1628struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1629 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1630 return {DenseMapInfo<void *>::getEmptyKey(),
1631 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1632 }
1633 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1634 return {DenseMapInfo<void *>::getTombstoneKey(),
1635 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1636 }
1637 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1638 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1639 }
1640 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1641 const mlir::AsmDialectResourceHandle &rhs) {
1642 return lhs.getResource() == rhs.getResource();
1643 }
1644};
1645} // namespace llvm
1646
1647#endif