Bug Summary

File: build/source/mlir/include/mlir/IR/TypeSupport.h
Warning: line 46, column 5
Address of stack memory associated with temporary object of type '(lambda at /build/source/mlir/include/mlir/IR/StorageUniquerSupport.h:133:12)' is still referred to by a temporary object on the stack upon returning to the caller. This will be a dangling reference.

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name BuiltinTypes.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-17/lib/clang/17 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GLIBCXX_ASSERTIONS -D _GNU_SOURCE -D _LIBCPP_ENABLE_ASSERTIONS -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/IR -I /build/source/mlir/lib/IR -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-17/lib/clang/17/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include 
-internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1683717183 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2023-05-10-133810-16478-1 -x c++ /build/source/mlir/lib/IR/BuiltinTypes.cpp

/build/source/mlir/lib/IR/BuiltinTypes.cpp

1//===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinTypes.h"
10#include "TypeDetail.h"
11#include "mlir/IR/AffineExpr.h"
12#include "mlir/IR/AffineMap.h"
13#include "mlir/IR/BuiltinAttributes.h"
14#include "mlir/IR/BuiltinDialect.h"
15#include "mlir/IR/Diagnostics.h"
16#include "mlir/IR/Dialect.h"
17#include "mlir/IR/FunctionInterfaces.h"
18#include "mlir/IR/OpImplementation.h"
19#include "mlir/IR/TensorEncoding.h"
20#include "llvm/ADT/APFloat.h"
21#include "llvm/ADT/BitVector.h"
22#include "llvm/ADT/Sequence.h"
23#include "llvm/ADT/Twine.h"
24#include "llvm/ADT/TypeSwitch.h"
25
26using namespace mlir;
27using namespace mlir::detail;
28
29//===----------------------------------------------------------------------===//
30/// Tablegen Type Definitions
31//===----------------------------------------------------------------------===//
32
33#define GET_TYPEDEF_CLASSES
34#include "mlir/IR/BuiltinTypes.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// BuiltinDialect
38//===----------------------------------------------------------------------===//
39
/// Registers the builtin type classes with the builtin dialect. The template
/// argument list passed to `addTypes` is not written out here: it comes from
/// the TableGen-generated BuiltinTypes.cpp.inc, included with GET_TYPEDEF_LIST
/// defined.
/// NOTE(review): the lines "1" and "Calling 'Dialect::addTypes'" below are
/// scan-build path annotations (step 1 of the reported path), not source code.
40void BuiltinDialect::registerTypes() {
41 addTypes<
1
Calling 'Dialect::addTypes'
42#define GET_TYPEDEF_LIST
43#include "mlir/IR/BuiltinTypes.cpp.inc"
44 >();
45}
46
47//===----------------------------------------------------------------------===//
48/// ComplexType
49//===----------------------------------------------------------------------===//
50
51/// Verify the construction of an integer type.
52LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError,
53 Type elementType) {
54 if (!elementType.isIntOrFloat())
55 return emitError() << "invalid element type for complex";
56 return success();
57}
58
59//===----------------------------------------------------------------------===//
60// Integer Type
61//===----------------------------------------------------------------------===//
62
63/// Verify the construction of an integer type.
64LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError,
65 unsigned width,
66 SignednessSemantics signedness) {
67 if (width > IntegerType::kMaxWidth) {
68 return emitError() << "integer bitwidth is limited to "
69 << IntegerType::kMaxWidth << " bits";
70 }
71 return success();
72}
73
74unsigned IntegerType::getWidth() const { return getImpl()->width; }
75
76IntegerType::SignednessSemantics IntegerType::getSignedness() const {
77 return getImpl()->signedness;
78}
79
80IntegerType IntegerType::scaleElementBitwidth(unsigned scale) {
81 if (!scale)
82 return IntegerType();
83 return IntegerType::get(getContext(), scale * getWidth(), getSignedness());
84}
85
86//===----------------------------------------------------------------------===//
87// Float Type
88//===----------------------------------------------------------------------===//
89
90unsigned FloatType::getWidth() {
91 if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
92 Float8E4M3FNUZType, Float8E4M3B11FNUZType>())
93 return 8;
94 if (isa<Float16Type, BFloat16Type>())
95 return 16;
96 if (isa<Float32Type>())
97 return 32;
98 if (isa<Float64Type>())
99 return 64;
100 if (isa<Float80Type>())
101 return 80;
102 if (isa<Float128Type>())
103 return 128;
104 llvm_unreachable("unexpected float type")::llvm::llvm_unreachable_internal("unexpected float type", "mlir/lib/IR/BuiltinTypes.cpp"
, 104)
;
105}
106
107/// Returns the floating semantics for the given type.
108const llvm::fltSemantics &FloatType::getFloatSemantics() {
109 if (isa<Float8E5M2Type>())
110 return APFloat::Float8E5M2();
111 if (isa<Float8E4M3FNType>())
112 return APFloat::Float8E4M3FN();
113 if (isa<Float8E5M2FNUZType>())
114 return APFloat::Float8E5M2FNUZ();
115 if (isa<Float8E4M3FNUZType>())
116 return APFloat::Float8E4M3FNUZ();
117 if (isa<Float8E4M3B11FNUZType>())
118 return APFloat::Float8E4M3B11FNUZ();
119 if (isa<BFloat16Type>())
120 return APFloat::BFloat();
121 if (isa<Float16Type>())
122 return APFloat::IEEEhalf();
123 if (isa<Float32Type>())
124 return APFloat::IEEEsingle();
125 if (isa<Float64Type>())
126 return APFloat::IEEEdouble();
127 if (isa<Float80Type>())
128 return APFloat::x87DoubleExtended();
129 if (isa<Float128Type>())
130 return APFloat::IEEEquad();
131 llvm_unreachable("non-floating point type used")::llvm::llvm_unreachable_internal("non-floating point type used"
, "mlir/lib/IR/BuiltinTypes.cpp", 131)
;
132}
133
134FloatType FloatType::scaleElementBitwidth(unsigned scale) {
135 if (!scale)
136 return FloatType();
137 MLIRContext *ctx = getContext();
138 if (isF16() || isBF16()) {
139 if (scale == 2)
140 return FloatType::getF32(ctx);
141 if (scale == 4)
142 return FloatType::getF64(ctx);
143 }
144 if (isF32())
145 if (scale == 2)
146 return FloatType::getF64(ctx);
147 return FloatType();
148}
149
150unsigned FloatType::getFPMantissaWidth() {
151 return APFloat::semanticsPrecision(getFloatSemantics());
152}
153
154//===----------------------------------------------------------------------===//
155// FunctionType
156//===----------------------------------------------------------------------===//
157
158unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; }
159
160ArrayRef<Type> FunctionType::getInputs() const {
161 return getImpl()->getInputs();
162}
163
164unsigned FunctionType::getNumResults() const { return getImpl()->numResults; }
165
166ArrayRef<Type> FunctionType::getResults() const {
167 return getImpl()->getResults();
168}
169
170FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const {
171 return get(getContext(), inputs, results);
172}
173
174/// Returns a new function type with the specified arguments and results
175/// inserted.
176FunctionType FunctionType::getWithArgsAndResults(
177 ArrayRef<unsigned> argIndices, TypeRange argTypes,
178 ArrayRef<unsigned> resultIndices, TypeRange resultTypes) {
179 SmallVector<Type> argStorage, resultStorage;
180 TypeRange newArgTypes = function_interface_impl::insertTypesInto(
181 getInputs(), argIndices, argTypes, argStorage);
182 TypeRange newResultTypes = function_interface_impl::insertTypesInto(
183 getResults(), resultIndices, resultTypes, resultStorage);
184 return clone(newArgTypes, newResultTypes);
185}
186
187/// Returns a new function type without the specified arguments and results.
188FunctionType
189FunctionType::getWithoutArgsAndResults(const BitVector &argIndices,
190 const BitVector &resultIndices) {
191 SmallVector<Type> argStorage, resultStorage;
192 TypeRange newArgTypes = function_interface_impl::filterTypesOut(
193 getInputs(), argIndices, argStorage);
194 TypeRange newResultTypes = function_interface_impl::filterTypesOut(
195 getResults(), resultIndices, resultStorage);
196 return clone(newArgTypes, newResultTypes);
197}
198
199//===----------------------------------------------------------------------===//
200// OpaqueType
201//===----------------------------------------------------------------------===//
202
203/// Verify the construction of an opaque type.
204LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError,
205 StringAttr dialect, StringRef typeData) {
206 if (!Dialect::isValidNamespace(dialect.strref()))
207 return emitError() << "invalid dialect namespace '" << dialect << "'";
208
209 // Check that the dialect is actually registered.
210 MLIRContext *context = dialect.getContext();
211 if (!context->allowsUnregisteredDialects() &&
212 !context->getLoadedDialect(dialect.strref())) {
213 return emitError()
214 << "`!" << dialect << "<\"" << typeData << "\">"
215 << "` type created with unregistered dialect. If this is "
216 "intended, please call allowUnregisteredDialects() on the "
217 "MLIRContext, or use -allow-unregistered-dialect with "
218 "the MLIR opt tool used";
219 }
220
221 return success();
222}
223
224//===----------------------------------------------------------------------===//
225// VectorType
226//===----------------------------------------------------------------------===//
227
228LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError,
229 ArrayRef<int64_t> shape, Type elementType,
230 unsigned numScalableDims) {
231 if (!isValidElementType(elementType))
232 return emitError()
233 << "vector elements must be int/index/float type but got "
234 << elementType;
235
236 if (any_of(shape, [](int64_t i) { return i <= 0; }))
237 return emitError()
238 << "vector types must have positive constant sizes but got "
239 << shape;
240
241 return success();
242}
243
244VectorType VectorType::scaleElementBitwidth(unsigned scale) {
245 if (!scale)
246 return VectorType();
247 if (auto et = getElementType().dyn_cast<IntegerType>())
248 if (auto scaledEt = et.scaleElementBitwidth(scale))
249 return VectorType::get(getShape(), scaledEt, getNumScalableDims());
250 if (auto et = getElementType().dyn_cast<FloatType>())
251 if (auto scaledEt = et.scaleElementBitwidth(scale))
252 return VectorType::get(getShape(), scaledEt, getNumScalableDims());
253 return VectorType();
254}
255
256VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
257 Type elementType) const {
258 return VectorType::get(shape.value_or(getShape()), elementType,
259 getNumScalableDims());
260}
261
262//===----------------------------------------------------------------------===//
263// TensorType
264//===----------------------------------------------------------------------===//
265
266Type TensorType::getElementType() const {
267 return llvm::TypeSwitch<TensorType, Type>(*this)
268 .Case<RankedTensorType, UnrankedTensorType>(
269 [](auto type) { return type.getElementType(); });
270}
271
272bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); }
273
274ArrayRef<int64_t> TensorType::getShape() const {
275 return cast<RankedTensorType>().getShape();
276}
277
278TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
279 Type elementType) const {
280 if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) {
281 if (shape)
282 return RankedTensorType::get(*shape, elementType);
283 return UnrankedTensorType::get(elementType);
284 }
285
286 auto rankedTy = cast<RankedTensorType>();
287 if (!shape)
288 return RankedTensorType::get(rankedTy.getShape(), elementType,
289 rankedTy.getEncoding());
290 return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType,
291 rankedTy.getEncoding());
292}
293
294// Check if "elementType" can be an element type of a tensor.
295static LogicalResult
296checkTensorElementType(function_ref<InFlightDiagnostic()> emitError,
297 Type elementType) {
298 if (!TensorType::isValidElementType(elementType))
299 return emitError() << "invalid tensor element type: " << elementType;
300 return success();
301}
302
303/// Return true if the specified element type is ok in a tensor.
304bool TensorType::isValidElementType(Type type) {
305 // Note: Non standard/builtin types are allowed to exist within tensor
306 // types. Dialects are expected to verify that tensor types have a valid
307 // element type within that dialect.
308 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
309 IndexType>() ||
310 !llvm::isa<BuiltinDialect>(type.getDialect());
311}
312
313//===----------------------------------------------------------------------===//
314// RankedTensorType
315//===----------------------------------------------------------------------===//
316
317LogicalResult
318RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
319 ArrayRef<int64_t> shape, Type elementType,
320 Attribute encoding) {
321 for (int64_t s : shape)
322 if (s < 0 && !ShapedType::isDynamic(s))
323 return emitError() << "invalid tensor dimension size";
324 if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>())
325 if (failed(v.verifyEncoding(shape, elementType, emitError)))
326 return failure();
327 return checkTensorElementType(emitError, elementType);
328}
329
330//===----------------------------------------------------------------------===//
331// UnrankedTensorType
332//===----------------------------------------------------------------------===//
333
334LogicalResult
335UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
336 Type elementType) {
337 return checkTensorElementType(emitError, elementType);
338}
339
340//===----------------------------------------------------------------------===//
341// BaseMemRefType
342//===----------------------------------------------------------------------===//
343
344Type BaseMemRefType::getElementType() const {
345 return llvm::TypeSwitch<BaseMemRefType, Type>(*this)
346 .Case<MemRefType, UnrankedMemRefType>(
347 [](auto type) { return type.getElementType(); });
348}
349
350bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); }
351
352ArrayRef<int64_t> BaseMemRefType::getShape() const {
353 return cast<MemRefType>().getShape();
354}
355
356BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
357 Type elementType) const {
358 if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) {
359 if (!shape)
360 return UnrankedMemRefType::get(elementType, getMemorySpace());
361 MemRefType::Builder builder(*shape, elementType);
362 builder.setMemorySpace(getMemorySpace());
363 return builder;
364 }
365
366 MemRefType::Builder builder(cast<MemRefType>());
367 if (shape)
368 builder.setShape(*shape);
369 builder.setElementType(elementType);
370 return builder;
371}
372
373Attribute BaseMemRefType::getMemorySpace() const {
374 if (auto rankedMemRefTy = dyn_cast<MemRefType>())
375 return rankedMemRefTy.getMemorySpace();
376 return cast<UnrankedMemRefType>().getMemorySpace();
377}
378
379unsigned BaseMemRefType::getMemorySpaceAsInt() const {
380 if (auto rankedMemRefTy = dyn_cast<MemRefType>())
381 return rankedMemRefTy.getMemorySpaceAsInt();
382 return cast<UnrankedMemRefType>().getMemorySpaceAsInt();
383}
384
385//===----------------------------------------------------------------------===//
386// MemRefType
387//===----------------------------------------------------------------------===//
388
389/// Given an `originalShape` and a `reducedShape` assumed to be a subset of
390/// `originalShape` with some `1` entries erased, return the set of indices
391/// that specifies which of the entries of `originalShape` are dropped to obtain
392/// `reducedShape`. The returned mask can be applied as a projection to
393/// `originalShape` to obtain the `reducedShape`. This mask is useful to track
394/// which dimensions must be kept when e.g. compute MemRef strides under
395/// rank-reducing operations. Return std::nullopt if reducedShape cannot be
396/// obtained by dropping only `1` entries in `originalShape`.
397std::optional<llvm::SmallDenseSet<unsigned>>
398mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape,
399 ArrayRef<int64_t> reducedShape) {
400 size_t originalRank = originalShape.size(), reducedRank = reducedShape.size();
401 llvm::SmallDenseSet<unsigned> unusedDims;
402 unsigned reducedIdx = 0;
403 for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) {
404 // Greedily insert `originalIdx` if match.
405 if (reducedIdx < reducedRank &&
406 originalShape[originalIdx] == reducedShape[reducedIdx]) {
407 reducedIdx++;
408 continue;
409 }
410
411 unusedDims.insert(originalIdx);
412 // If no match on `originalIdx`, the `originalShape` at this dimension
413 // must be 1, otherwise we bail.
414 if (originalShape[originalIdx] != 1)
415 return std::nullopt;
416 }
417 // The whole reducedShape must be scanned, otherwise we bail.
418 if (reducedIdx != reducedRank)
419 return std::nullopt;
420 return unusedDims;
421}
422
423SliceVerificationResult
424mlir::isRankReducedType(ShapedType originalType,
425 ShapedType candidateReducedType) {
426 if (originalType == candidateReducedType)
427 return SliceVerificationResult::Success;
428
429 ShapedType originalShapedType = originalType.cast<ShapedType>();
430 ShapedType candidateReducedShapedType =
431 candidateReducedType.cast<ShapedType>();
432
433 // Rank and size logic is valid for all ShapedTypes.
434 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
435 ArrayRef<int64_t> candidateReducedShape =
436 candidateReducedShapedType.getShape();
437 unsigned originalRank = originalShape.size(),
438 candidateReducedRank = candidateReducedShape.size();
439 if (candidateReducedRank > originalRank)
440 return SliceVerificationResult::RankTooLarge;
441
442 auto optionalUnusedDimsMask =
443 computeRankReductionMask(originalShape, candidateReducedShape);
444
445 // Sizes cannot be matched in case empty vector is returned.
446 if (!optionalUnusedDimsMask)
447 return SliceVerificationResult::SizeMismatch;
448
449 if (originalShapedType.getElementType() !=
450 candidateReducedShapedType.getElementType())
451 return SliceVerificationResult::ElemTypeMismatch;
452
453 return SliceVerificationResult::Success;
454}
455
456bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) {
457 // Empty attribute is allowed as default memory space.
458 if (!memorySpace)
459 return true;
460
461 // Supported built-in attributes.
462 if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>())
463 return true;
464
465 // Allow custom dialect attributes.
466 if (!isa<BuiltinDialect>(memorySpace.getDialect()))
467 return true;
468
469 return false;
470}
471
472Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace,
473 MLIRContext *ctx) {
474 if (memorySpace == 0)
475 return nullptr;
476
477 return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace);
478}
479
480Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) {
481 IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>();
482 if (intMemorySpace && intMemorySpace.getValue() == 0)
483 return nullptr;
484
485 return memorySpace;
486}
487
488unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) {
489 if (!memorySpace)
490 return 0;
491
492 assert(memorySpace.isa<IntegerAttr>() &&(static_cast <bool> (memorySpace.isa<IntegerAttr>
() && "Using `getMemorySpaceInteger` with non-Integer attribute"
) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\""
, "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__
))
493 "Using `getMemorySpaceInteger` with non-Integer attribute")(static_cast <bool> (memorySpace.isa<IntegerAttr>
() && "Using `getMemorySpaceInteger` with non-Integer attribute"
) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\""
, "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__
))
;
494
495 return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt());
496}
497
498unsigned MemRefType::getMemorySpaceAsInt() const {
499 return detail::getMemorySpaceAsInt(getMemorySpace());
500}
501
502MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
503 MemRefLayoutAttrInterface layout,
504 Attribute memorySpace) {
505 // Use default layout for empty attribute.
506 if (!layout)
507 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
508 shape.size(), elementType.getContext()));
509
510 // Drop default memory space value and replace it with empty attribute.
511 memorySpace = skipDefaultMemorySpace(memorySpace);
512
513 return Base::get(elementType.getContext(), shape, elementType, layout,
514 memorySpace);
515}
516
517MemRefType MemRefType::getChecked(
518 function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape,
519 Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) {
520
521 // Use default layout for empty attribute.
522 if (!layout)
523 layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap(
524 shape.size(), elementType.getContext()));
525
526 // Drop default memory space value and replace it with empty attribute.
527 memorySpace = skipDefaultMemorySpace(memorySpace);
528
529 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
530 elementType, layout, memorySpace);
531}
532
533MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
534 AffineMap map, Attribute memorySpace) {
535
536 // Use default layout for empty map.
537 if (!map)
538 map = AffineMap::getMultiDimIdentityMap(shape.size(),
539 elementType.getContext());
540
541 // Wrap AffineMap into Attribute.
542 auto layout = AffineMapAttr::get(map);
543
544 // Drop default memory space value and replace it with empty attribute.
545 memorySpace = skipDefaultMemorySpace(memorySpace);
546
547 return Base::get(elementType.getContext(), shape, elementType, layout,
548 memorySpace);
549}
550
551MemRefType
552MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
553 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
554 Attribute memorySpace) {
555
556 // Use default layout for empty map.
557 if (!map)
558 map = AffineMap::getMultiDimIdentityMap(shape.size(),
559 elementType.getContext());
560
561 // Wrap AffineMap into Attribute.
562 auto layout = AffineMapAttr::get(map);
563
564 // Drop default memory space value and replace it with empty attribute.
565 memorySpace = skipDefaultMemorySpace(memorySpace);
566
567 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
568 elementType, layout, memorySpace);
569}
570
571MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
572 AffineMap map, unsigned memorySpaceInd) {
573
574 // Use default layout for empty map.
575 if (!map)
576 map = AffineMap::getMultiDimIdentityMap(shape.size(),
577 elementType.getContext());
578
579 // Wrap AffineMap into Attribute.
580 auto layout = AffineMapAttr::get(map);
581
582 // Convert deprecated integer-like memory space to Attribute.
583 Attribute memorySpace =
584 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
585
586 return Base::get(elementType.getContext(), shape, elementType, layout,
587 memorySpace);
588}
589
590MemRefType
591MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn,
592 ArrayRef<int64_t> shape, Type elementType, AffineMap map,
593 unsigned memorySpaceInd) {
594
595 // Use default layout for empty map.
596 if (!map)
597 map = AffineMap::getMultiDimIdentityMap(shape.size(),
598 elementType.getContext());
599
600 // Wrap AffineMap into Attribute.
601 auto layout = AffineMapAttr::get(map);
602
603 // Convert deprecated integer-like memory space to Attribute.
604 Attribute memorySpace =
605 wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext());
606
607 return Base::getChecked(emitErrorFn, elementType.getContext(), shape,
608 elementType, layout, memorySpace);
609}
610
611LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
612 ArrayRef<int64_t> shape, Type elementType,
613 MemRefLayoutAttrInterface layout,
614 Attribute memorySpace) {
615 if (!BaseMemRefType::isValidElementType(elementType))
616 return emitError() << "invalid memref element type";
617
618 // Negative sizes are not allowed except for `kDynamic`.
619 for (int64_t s : shape)
620 if (s < 0 && !ShapedType::isDynamic(s))
621 return emitError() << "invalid memref size";
622
623 assert(layout && "missing layout specification")(static_cast <bool> (layout && "missing layout specification"
) ? void (0) : __assert_fail ("layout && \"missing layout specification\""
, "mlir/lib/IR/BuiltinTypes.cpp", 623, __extension__ __PRETTY_FUNCTION__
))
;
624 if (failed(layout.verifyLayout(shape, emitError)))
625 return failure();
626
627 if (!isSupportedMemorySpace(memorySpace))
628 return emitError() << "unsupported memory space Attribute";
629
630 return success();
631}
632
633//===----------------------------------------------------------------------===//
634// UnrankedMemRefType
635//===----------------------------------------------------------------------===//
636
637unsigned UnrankedMemRefType::getMemorySpaceAsInt() const {
638 return detail::getMemorySpaceAsInt(getMemorySpace());
639}
640
641LogicalResult
642UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError,
643 Type elementType, Attribute memorySpace) {
644 if (!BaseMemRefType::isValidElementType(elementType))
645 return emitError() << "invalid memref element type";
646
647 if (!isSupportedMemorySpace(memorySpace))
648 return emitError() << "unsupported memory space Attribute";
649
650 return success();
651}
652
653// Fallback cases for terminal dim/sym/cst that are not part of a binary op (
654// i.e. single term). Accumulate the AffineExpr into the existing one.
655static void extractStridesFromTerm(AffineExpr e,
656 AffineExpr multiplicativeFactor,
657 MutableArrayRef<AffineExpr> strides,
658 AffineExpr &offset) {
659 if (auto dim = e.dyn_cast<AffineDimExpr>())
660 strides[dim.getPosition()] =
661 strides[dim.getPosition()] + multiplicativeFactor;
662 else
663 offset = offset + e * multiplicativeFactor;
664}
665
666/// Takes a single AffineExpr `e` and populates the `strides` array with the
667/// strides expressions for each dim position.
668/// The convention is that the strides for dimensions d0, .. dn appear in
669/// order to make indexing intuitive into the result.
670static LogicalResult extractStrides(AffineExpr e,
671 AffineExpr multiplicativeFactor,
672 MutableArrayRef<AffineExpr> strides,
673 AffineExpr &offset) {
674 auto bin = e.dyn_cast<AffineBinaryOpExpr>();
675 if (!bin) {
676 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
677 return success();
678 }
679
680 if (bin.getKind() == AffineExprKind::CeilDiv ||
681 bin.getKind() == AffineExprKind::FloorDiv ||
682 bin.getKind() == AffineExprKind::Mod)
683 return failure();
684
685 if (bin.getKind() == AffineExprKind::Mul) {
686 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
687 if (dim) {
688 strides[dim.getPosition()] =
689 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
690 return success();
691 }
692 // LHS and RHS may both contain complex expressions of dims. Try one path
693 // and if it fails try the other. This is guaranteed to succeed because
694 // only one path may have a `dim`, otherwise this is not an AffineExpr in
695 // the first place.
696 if (bin.getLHS().isSymbolicOrConstant())
697 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
698 strides, offset);
699 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
700 strides, offset);
701 }
702
703 if (bin.getKind() == AffineExprKind::Add) {
704 auto res1 =
705 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
706 auto res2 =
707 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
708 return success(succeeded(res1) && succeeded(res2));
709 }
710
711 llvm_unreachable("unexpected binary operation")::llvm::llvm_unreachable_internal("unexpected binary operation"
, "mlir/lib/IR/BuiltinTypes.cpp", 711)
;
712}
713
714/// A stride specification is a list of integer values that are either static
715/// or dynamic (encoded with ShapedType::kDynamic). Strides encode
716/// the distance in the number of elements between successive entries along a
717/// particular dimension.
718///
719/// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a
720/// non-contiguous memory region of `42` by `16` `f32` elements in which the
721/// distance between two consecutive elements along the outer dimension is `1`
722/// and the distance between two consecutive elements along the inner dimension
723/// is `64`.
724///
725/// The convention is that the strides for dimensions d0, .. dn appear in
726/// order to make indexing intuitive into the result.
727static LogicalResult getStridesAndOffset(MemRefType t,
728 SmallVectorImpl<AffineExpr> &strides,
729 AffineExpr &offset) {
730 AffineMap m = t.getLayout().getAffineMap();
731
732 if (m.getNumResults() != 1 && !m.isIdentity())
733 return failure();
734
735 auto zero = getAffineConstantExpr(0, t.getContext());
736 auto one = getAffineConstantExpr(1, t.getContext());
737 offset = zero;
738 strides.assign(t.getRank(), zero);
739
740 // Canonical case for empty map.
741 if (m.isIdentity()) {
742 // 0-D corner case, offset is already 0.
743 if (t.getRank() == 0)
744 return success();
745 auto stridedExpr =
746 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
747 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
748 return success();
749 assert(false && "unexpected failure: extract strides in canonical layout")(static_cast <bool> (false && "unexpected failure: extract strides in canonical layout"
) ? void (0) : __assert_fail ("false && \"unexpected failure: extract strides in canonical layout\""
, "mlir/lib/IR/BuiltinTypes.cpp", 749, __extension__ __PRETTY_FUNCTION__
))
;
750 }
751
752 // Non-canonical case requires more work.
753 auto stridedExpr =
754 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
755 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
756 offset = AffineExpr();
757 strides.clear();
758 return failure();
759 }
760
761 // Simplify results to allow folding to constants and simple checks.
762 unsigned numDims = m.getNumDims();
763 unsigned numSymbols = m.getNumSymbols();
764 offset = simplifyAffineExpr(offset, numDims, numSymbols);
765 for (auto &stride : strides)
766 stride = simplifyAffineExpr(stride, numDims, numSymbols);
767
768 // In practice, a strided memref must be internally non-aliasing. Test
769 // against 0 as a proxy.
770 // TODO: static cases can have more advanced checks.
771 // TODO: dynamic cases would require a way to compare symbolic
772 // expressions and would probably need an affine set context propagated
773 // everywhere.
774 if (llvm::any_of(strides, [](AffineExpr e) {
775 return e == getAffineConstantExpr(0, e.getContext());
776 })) {
777 offset = AffineExpr();
778 strides.clear();
779 return failure();
780 }
781
782 return success();
783}
784
785LogicalResult mlir::getStridesAndOffset(MemRefType t,
786 SmallVectorImpl<int64_t> &strides,
787 int64_t &offset) {
788 // Happy path: the type uses the strided layout directly.
789 if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) {
790 llvm::append_range(strides, strided.getStrides());
791 offset = strided.getOffset();
792 return success();
793 }
794
795 // Otherwise, defer to the affine fallback as layouts are supposed to be
796 // convertible to affine maps.
797 AffineExpr offsetExpr;
798 SmallVector<AffineExpr, 4> strideExprs;
799 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
800 return failure();
801 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
802 offset = cst.getValue();
803 else
804 offset = ShapedType::kDynamic;
805 for (auto e : strideExprs) {
806 if (auto c = e.dyn_cast<AffineConstantExpr>())
807 strides.push_back(c.getValue());
808 else
809 strides.push_back(ShapedType::kDynamic);
810 }
811 return success();
812}
813
814std::pair<SmallVector<int64_t>, int64_t>
815mlir::getStridesAndOffset(MemRefType t) {
816 SmallVector<int64_t> strides;
817 int64_t offset;
818 LogicalResult status = getStridesAndOffset(t, strides, offset);
819 (void)status;
820 assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset")(static_cast <bool> (succeeded(status) && "Invalid use of check-free getStridesAndOffset"
) ? void (0) : __assert_fail ("succeeded(status) && \"Invalid use of check-free getStridesAndOffset\""
, "mlir/lib/IR/BuiltinTypes.cpp", 820, __extension__ __PRETTY_FUNCTION__
))
;
821 return {strides, offset};
822}
823
824//===----------------------------------------------------------------------===//
825/// TupleType
826//===----------------------------------------------------------------------===//
827
828/// Return the elements types for this tuple.
829ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
830
831/// Accumulate the types contained in this tuple and tuples nested within it.
832/// Note that this only flattens nested tuples, not any other container type,
833/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
834/// (i32, tensor<i32>, f32, i64)
835void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
836 for (Type type : getTypes()) {
837 if (auto nestedTuple = type.dyn_cast<TupleType>())
838 nestedTuple.getFlattenedTypes(types);
839 else
840 types.push_back(type);
841 }
842}
843
844/// Return the number of element types.
845size_t TupleType::size() const { return getImpl()->size(); }
846
847//===----------------------------------------------------------------------===//
848// Type Utilities
849//===----------------------------------------------------------------------===//
850
851/// Return a version of `t` with identity layout if it can be determined
852/// statically that the layout is the canonical contiguous strided layout.
853/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
854/// `t` with simplified layout.
855/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
856MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
857 AffineMap m = t.getLayout().getAffineMap();
858
859 // Already in canonical form.
860 if (m.isIdentity())
861 return t;
862
863 // Can't reduce to canonical identity form, return in canonical form.
864 if (m.getNumResults() > 1)
865 return t;
866
867 // Corner-case for 0-D affine maps.
868 if (m.getNumDims() == 0 && m.getNumSymbols() == 0) {
869 if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>())
870 if (cst.getValue() == 0)
871 return MemRefType::Builder(t).setLayout({});
872 return t;
873 }
874
875 // 0-D corner case for empty shape that still have an affine map. Example:
876 // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose
877 // offset needs to remain, just return t.
878 if (t.getShape().empty())
879 return t;
880
881 // If the canonical strided layout for the sizes of `t` is equal to the
882 // simplified layout of `t` we can just return an empty layout. Otherwise,
883 // just simplify the existing layout.
884 AffineExpr expr =
885 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
886 auto simplifiedLayoutExpr =
887 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
888 if (expr != simplifiedLayoutExpr)
889 return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get(
890 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)));
891 return MemRefType::Builder(t).setLayout({});
892}
893
894AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
895 ArrayRef<AffineExpr> exprs,
896 MLIRContext *context) {
897 // Size 0 corner case is useful for canonicalizations.
898 if (sizes.empty())
899 return getAffineConstantExpr(0, context);
900
901 assert(!exprs.empty() && "expected exprs")(static_cast <bool> (!exprs.empty() && "expected exprs"
) ? void (0) : __assert_fail ("!exprs.empty() && \"expected exprs\""
, "mlir/lib/IR/BuiltinTypes.cpp", 901, __extension__ __PRETTY_FUNCTION__
))
;
902 auto maps = AffineMap::inferFromExprList(exprs);
903 assert(!maps.empty() && "Expected one non-empty map")(static_cast <bool> (!maps.empty() && "Expected one non-empty map"
) ? void (0) : __assert_fail ("!maps.empty() && \"Expected one non-empty map\""
, "mlir/lib/IR/BuiltinTypes.cpp", 903, __extension__ __PRETTY_FUNCTION__
))
;
904 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
905
906 AffineExpr expr;
907 bool dynamicPoisonBit = false;
908 int64_t runningSize = 1;
909 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
910 int64_t size = std::get<1>(en);
911 AffineExpr dimExpr = std::get<0>(en);
912 AffineExpr stride = dynamicPoisonBit
913 ? getAffineSymbolExpr(nSymbols++, context)
914 : getAffineConstantExpr(runningSize, context);
915 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
916 if (size > 0) {
917 runningSize *= size;
918 assert(runningSize > 0 && "integer overflow in size computation")(static_cast <bool> (runningSize > 0 && "integer overflow in size computation"
) ? void (0) : __assert_fail ("runningSize > 0 && \"integer overflow in size computation\""
, "mlir/lib/IR/BuiltinTypes.cpp", 918, __extension__ __PRETTY_FUNCTION__
))
;
919 } else {
920 dynamicPoisonBit = true;
921 }
922 }
923 return simplifyAffineExpr(expr, numDims, nSymbols);
924}
925
926AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
927 MLIRContext *context) {
928 SmallVector<AffineExpr, 4> exprs;
929 exprs.reserve(sizes.size());
930 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
931 exprs.push_back(getAffineDimExpr(dim, context));
932 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
933}
934
935/// Return true if the layout for `t` is compatible with strided semantics.
936bool mlir::isStrided(MemRefType t) {
937 int64_t offset;
938 SmallVector<int64_t, 4> strides;
939 auto res = getStridesAndOffset(t, strides, offset);
940 return succeeded(res);
941}

/build/source/mlir/include/mlir/IR/Dialect.h

1//===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the 'dialect' abstraction.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_DIALECT_H
14#define MLIR_IR_DIALECT_H
15
16#include "mlir/IR/DialectRegistry.h"
17#include "mlir/IR/OperationSupport.h"
18#include "mlir/Support/TypeID.h"
19
20#include <map>
21#include <tuple>
22
23namespace mlir {
24class DialectAsmParser;
25class DialectAsmPrinter;
26class DialectInterface;
27class OpBuilder;
28class Type;
29
30//===----------------------------------------------------------------------===//
31// Dialect
32//===----------------------------------------------------------------------===//
33
34/// Dialects are groups of MLIR operations, types and attributes, as well as
35/// behavior associated with the entire group. For example, hooks into other
36/// systems for constant folding, interfaces, default named types for asm
37/// printing, etc.
38///
39/// Instances of the dialect object are loaded in a specific MLIRContext.
40///
41class Dialect {
42public:
43 /// Type for a callback provided by the dialect to parse a custom operation.
44 /// This is used for the dialect to provide an alternative way to parse custom
45 /// operations, including unregistered ones.
46 using ParseOpHook =
47 function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>;
48
49 virtual ~Dialect();
50
51 /// Utility function that returns if the given string is a valid dialect
52 /// namespace
53 static bool isValidNamespace(StringRef str);
54
55 MLIRContext *getContext() const { return context; }
56
57 StringRef getNamespace() const { return name; }
58
59 /// Returns the unique identifier that corresponds to this dialect.
60 TypeID getTypeID() const { return dialectID; }
61
62 /// Returns true if this dialect allows for unregistered operations, i.e.
63 /// operations prefixed with the dialect namespace but not registered with
64 /// addOperation.
65 bool allowsUnknownOperations() const { return unknownOpsAllowed; }
66
67 /// Return true if this dialect allows for unregistered types, i.e., types
68 /// prefixed with the dialect namespace but not registered with addType.
69 /// These are represented with OpaqueType.
70 bool allowsUnknownTypes() const { return unknownTypesAllowed; }
71
72 /// Register dialect-wide canonicalization patterns. This method should only
73 /// be used to register canonicalization patterns that do not conceptually
74 /// belong to any single operation in the dialect. (In that case, use the op's
75 /// canonicalizer.) E.g., canonicalization patterns for op interfaces should
76 /// be registered here.
77 virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {}
78
79 /// Registered hook to materialize a single constant operation from a given
80 /// attribute value with the desired resultant type. This method should use
81 /// the provided builder to create the operation without changing the
82 /// insertion position. The generated operation is expected to be constant
83 /// like, i.e. single result, zero operands, non side-effecting, etc. On
84 /// success, this hook should return the value generated to represent the
85 /// constant value. Otherwise, it should return null on failure.
86 virtual Operation *materializeConstant(OpBuilder &builder, Attribute value,
87 Type type, Location loc) {
88 return nullptr;
89 }
90
91 //===--------------------------------------------------------------------===//
92 // Parsing Hooks
93 //===--------------------------------------------------------------------===//
94
95 /// Parse an attribute registered to this dialect. If 'type' is nonnull, it
96 /// refers to the expected type of the attribute.
97 virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const;
98
99 /// Print an attribute registered to this dialect. Note: The type of the
100 /// attribute need not be printed by this method as it is always printed by
101 /// the caller.
102 virtual void printAttribute(Attribute, DialectAsmPrinter &) const {
103 llvm_unreachable("dialect has no registered attribute printing hook")::llvm::llvm_unreachable_internal("dialect has no registered attribute printing hook"
, "mlir/include/mlir/IR/Dialect.h", 103)
;
104 }
105
106 /// Parse a type registered to this dialect.
107 virtual Type parseType(DialectAsmParser &parser) const;
108
109 /// Print a type registered to this dialect.
110 virtual void printType(Type, DialectAsmPrinter &) const {
111 llvm_unreachable("dialect has no registered type printing hook")::llvm::llvm_unreachable_internal("dialect has no registered type printing hook"
, "mlir/include/mlir/IR/Dialect.h", 111)
;
112 }
113
114 /// Return the hook to parse an operation registered to this dialect, if any.
115 /// By default this will lookup for registered operations and return the
116 /// `parse()` method registered on the RegisteredOperationName. Dialects can
117 /// override this behavior and handle unregistered operations as well.
118 virtual std::optional<ParseOpHook>
119 getParseOperationHook(StringRef opName) const;
120
121 /// Print an operation registered to this dialect.
122 /// This hook is invoked for registered operation which don't override the
123 /// `print()` method to define their own custom assembly.
124 virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)>
125 getOperationPrinter(Operation *op) const;
126
127 //===--------------------------------------------------------------------===//
128 // Verification Hooks
129 //===--------------------------------------------------------------------===//
130
131 /// Verify an attribute from this dialect on the argument at 'argIndex' for
132 /// the region at 'regionIndex' on the given operation. Returns failure if
133 /// the verification failed, success otherwise. This hook may optionally be
134 /// invoked from any operation containing a region.
135 virtual LogicalResult verifyRegionArgAttribute(Operation *,
136 unsigned regionIndex,
137 unsigned argIndex,
138 NamedAttribute);
139
140 /// Verify an attribute from this dialect on the result at 'resultIndex' for
141 /// the region at 'regionIndex' on the given operation. Returns failure if
142 /// the verification failed, success otherwise. This hook may optionally be
143 /// invoked from any operation containing a region.
144 virtual LogicalResult verifyRegionResultAttribute(Operation *,
145 unsigned regionIndex,
146 unsigned resultIndex,
147 NamedAttribute);
148
149 /// Verify an attribute from this dialect on the given operation. Returns
150 /// failure if the verification failed, success otherwise.
151 virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) {
152 return success();
153 }
154
155 //===--------------------------------------------------------------------===//
156 // Interfaces
157 //===--------------------------------------------------------------------===//
158
159 /// Lookup an interface for the given ID if one is registered, otherwise
160 /// nullptr.
161 DialectInterface *getRegisteredInterface(TypeID interfaceID) {
162 auto it = registeredInterfaces.find(interfaceID);
163 return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr;
164 }
165 template <typename InterfaceT>
166 InterfaceT *getRegisteredInterface() {
167 return static_cast<InterfaceT *>(
168 getRegisteredInterface(InterfaceT::getInterfaceID()));
169 }
170
171 /// Lookup an op interface for the given ID if one is registered, otherwise
172 /// nullptr.
173 virtual void *getRegisteredInterfaceForOp(TypeID interfaceID,
174 OperationName opName) {
175 return nullptr;
176 }
177 template <typename InterfaceT>
178 typename InterfaceT::Concept *
179 getRegisteredInterfaceForOp(OperationName opName) {
180 return static_cast<typename InterfaceT::Concept *>(
181 getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName));
182 }
183
184 /// Register a dialect interface with this dialect instance.
185 void addInterface(std::unique_ptr<DialectInterface> interface);
186
187 /// Register a set of dialect interfaces with this dialect instance.
188 template <typename... Args>
189 void addInterfaces() {
190 (addInterface(std::make_unique<Args>(this)), ...);
191 }
192 template <typename InterfaceT, typename... Args>
193 InterfaceT &addInterface(Args &&...args) {
194 InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...);
195 addInterface(std::unique_ptr<DialectInterface>(interface));
196 return *interface;
197 }
198
199protected:
200 /// The constructor takes a unique namespace for this dialect as well as the
201 /// context to bind to.
202 /// Note: The namespace must not contain '.' characters.
203 /// Note: All operations belonging to this dialect must have names starting
204 /// with the namespace followed by '.'.
205 /// Example:
206 /// - "tf" for the TensorFlow ops like "tf.add".
207 Dialect(StringRef name, MLIRContext *context, TypeID id);
208
209 /// This method is used by derived classes to add their operations to the set.
210 ///
211 template <typename... Args>
212 void addOperations() {
213 // This initializer_list argument pack expansion is essentially equal to
214 // using a fold expression with a comma operator. Clang however, refuses
215 // to compile a fold expression with a depth of more than 256 by default.
216 // There seem to be no such limitations for initializer_list.
217 (void)std::initializer_list<int>{
218 0, (RegisteredOperationName::insert<Args>(*this), 0)...};
219 }
220
221 /// Register a set of type classes with this dialect.
222 template <typename... Args>
223 void addTypes() {
224 (addType<Args>(), ...);
2
Calling 'Dialect::addType'
225 }
226
227 /// Register a type instance with this dialect.
228 /// The use of this method is in general discouraged in favor of
229 /// 'addTypes<CustomType>()'.
230 void addType(TypeID typeID, AbstractType &&typeInfo);
231
232 /// Register a set of attribute classes with this dialect.
233 template <typename... Args>
234 void addAttributes() {
235 (addAttribute<Args>(), ...);
236 }
237
238 /// Register an attribute instance with this dialect.
239 /// The use of this method is in general discouraged in favor of
240 /// 'addAttributes<CustomAttr>()'.
241 void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo);
242
243 /// Enable support for unregistered operations.
244 void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; }
245
246 /// Enable support for unregistered types.
247 void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; }
248
249private:
250 Dialect(const Dialect &) = delete;
251 void operator=(Dialect &) = delete;
252
253 /// Register an attribute instance with this dialect.
254 template <typename T>
255 void addAttribute() {
256 // Add this attribute to the dialect and register it with the uniquer.
257 addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this));
258 detail::AttributeUniquer::registerAttribute<T>(context);
259 }
260
261 /// Register a type instance with this dialect.
262 template <typename T>
263 void addType() {
264 // Add this type to the dialect and register it with the uniquer.
265 addType(T::getTypeID(), AbstractType::get<T>(*this));
3
Calling 'AbstractType::get'
266 detail::TypeUniquer::registerType<T>(context);
267 }
268
269 /// The namespace of this dialect.
270 StringRef name;
271
272 /// The unique identifier of the derived Op class, this is used in the context
273 /// to allow registering multiple times the same dialect.
274 TypeID dialectID;
275
276 /// This is the context that owns this Dialect object.
277 MLIRContext *context;
278
279 /// Flag that specifies whether this dialect supports unregistered operations,
280 /// i.e. operations prefixed with the dialect namespace but not registered
281 /// with addOperation.
282 bool unknownOpsAllowed = false;
283
284 /// Flag that specifies whether this dialect allows unregistered types, i.e.
285 /// types prefixed with the dialect namespace but not registered with addType.
286 /// These types are represented with OpaqueType.
287 bool unknownTypesAllowed = false;
288
289 /// A collection of registered dialect interfaces.
290 DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
291
292 friend class DialectRegistry;
293 friend void registerDialect();
294 friend class MLIRContext;
295};
296
297} // namespace mlir
298
299namespace llvm {
300/// Provide isa functionality for Dialects.
301template <typename T>
302struct isa_impl<T, ::mlir::Dialect,
303 std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
304 static inline bool doit(const ::mlir::Dialect &dialect) {
305 return mlir::TypeID::get<T>() == dialect.getTypeID();
306 }
307};
308template <typename T>
309struct isa_impl<
310 T, ::mlir::Dialect,
311 std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
312 static inline bool doit(const ::mlir::Dialect &dialect) {
313 return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
314 }
315};
316template <typename T>
317struct cast_retty_impl<T, ::mlir::Dialect *> {
318 using ret_type = T *;
319};
320template <typename T>
321struct cast_retty_impl<T, ::mlir::Dialect> {
322 using ret_type = T &;
323};
324
325template <typename T>
326struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
327 template <typename To>
328 static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
329 doitImpl(::mlir::Dialect &dialect) {
330 return static_cast<To &>(dialect);
331 }
332 template <typename To>
333 static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
334 To &>
335 doitImpl(::mlir::Dialect &dialect) {
336 return *dialect.getRegisteredInterface<To>();
337 }
338
339 static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
340};
341template <class T>
342struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
343 static auto doit(::mlir::Dialect *dialect) {
344 return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
345 *dialect);
346 }
347};
348
349} // namespace llvm
350
351#endif

/build/source/mlir/include/mlir/IR/TypeSupport.h

1//===- TypeSupport.h --------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines support types for registering dialect extended types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_TYPESUPPORT_H
14#define MLIR_IR_TYPESUPPORT_H
15
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/IR/StorageUniquerSupport.h"
18#include "llvm/ADT/Twine.h"
19
20namespace mlir {
21class Dialect;
22class MLIRContext;
23
24//===----------------------------------------------------------------------===//
25// AbstractType
26//===----------------------------------------------------------------------===//
27
28/// This class contains all of the static information common to all instances of
29/// a registered Type.
30class AbstractType {
31public:
32 using HasTraitFn = llvm::unique_function<bool(TypeID) const>;
33 using WalkImmediateSubElementsFn = function_ref<void(
34 Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>;
35 using ReplaceImmediateSubElementsFn =
36 function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>;
37
38 /// Look up the specified abstract type in the MLIRContext and return a
39 /// reference to it.
40 static const AbstractType &lookup(TypeID typeID, MLIRContext *context);
41
42 /// This method is used by Dialect objects when they register the list of
43 /// types they contain.
44 template <typename T>
45 static AbstractType get(Dialect &dialect) {
46 return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(),
4
Address of stack memory associated with temporary object of type '(lambda at /build/source/mlir/include/mlir/IR/StorageUniquerSupport.h:133:12)' is still referred to by a temporary object on the stack upon returning to the caller. This will be a dangling reference
47 T::getWalkImmediateSubElementsFn(),
48 T::getReplaceImmediateSubElementsFn(), T::getTypeID());
49 }
50
51 /// This method is used by Dialect objects to register types with
52 /// custom TypeIDs.
53 /// The use of this method is in general discouraged in favor of
54 /// 'get<CustomType>(dialect)';
55 static AbstractType
56 get(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
57 HasTraitFn &&hasTrait,
58 WalkImmediateSubElementsFn walkImmediateSubElementsFn,
59 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
60 TypeID typeID) {
61 return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait),
62 walkImmediateSubElementsFn,
63 replaceImmediateSubElementsFn, typeID);
64 }
65
66 /// Return the dialect this type was registered to.
67 Dialect &getDialect() const { return const_cast<Dialect &>(dialect); }
68
69 /// Returns an instance of the concept object for the given interface if it
70 /// was registered to this type, null otherwise. This should not be used
71 /// directly.
72 template <typename T>
73 typename T::Concept *getInterface() const {
74 return interfaceMap.lookup<T>();
75 }
76
77 /// Returns true if the type has the interface with the given ID.
78 bool hasInterface(TypeID interfaceID) const {
79 return interfaceMap.contains(interfaceID);
80 }
81
82 /// Returns true if the type has a particular trait.
83 template <template <typename T> class Trait>
84 bool hasTrait() const {
85 return hasTraitFn(TypeID::get<Trait>());
86 }
87
88 /// Returns true if the type has a particular trait.
89 bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); }
90
91 /// Walk the immediate sub-elements of the given type.
92 void walkImmediateSubElements(Type type,
93 function_ref<void(Attribute)> walkAttrsFn,
94 function_ref<void(Type)> walkTypesFn) const;
95
96 /// Replace the immediate sub-elements of the given type.
97 Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs,
98 ArrayRef<Type> replTypes) const;
99
100 /// Return the unique identifier representing the concrete type class.
101 TypeID getTypeID() const { return typeID; }
102
103private:
104 AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap,
105 HasTraitFn &&hasTrait,
106 WalkImmediateSubElementsFn walkImmediateSubElementsFn,
107 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn,
108 TypeID typeID)
109 : dialect(dialect), interfaceMap(std::move(interfaceMap)),
110 hasTraitFn(std::move(hasTrait)),
111 walkImmediateSubElementsFn(walkImmediateSubElementsFn),
112 replaceImmediateSubElementsFn(replaceImmediateSubElementsFn),
113 typeID(typeID) {}
114
115 /// Give StorageUserBase access to the mutable lookup.
116 template <typename ConcreteT, typename BaseT, typename StorageT,
117 typename UniquerT, template <typename T> class... Traits>
118 friend class detail::StorageUserBase;
119
120 /// Look up the specified abstract type in the MLIRContext and return a
121 /// (mutable) pointer to it. Return a null pointer if the type could not
122 /// be found in the context.
123 static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context);
124
125 /// This is the dialect that this type was registered to.
126 const Dialect &dialect;
127
128 /// This is a collection of the interfaces registered to this type.
129 detail::InterfaceMap interfaceMap;
130
131 /// Function to check if the type has a particular trait.
132 HasTraitFn hasTraitFn;
133
134 /// Function to walk the immediate sub-elements of this type.
135 WalkImmediateSubElementsFn walkImmediateSubElementsFn;
136
137 /// Function to replace the immediate sub-elements of this type.
138 ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn;
139
140 /// The unique identifier of the derived Type class.
141 const TypeID typeID;
142};
143
144//===----------------------------------------------------------------------===//
145// TypeStorage
146//===----------------------------------------------------------------------===//
147
148namespace detail {
149struct TypeUniquer;
150} // namespace detail
151
152/// Base storage class appearing in a Type.
153class TypeStorage : public StorageUniquer::BaseStorage {
154 friend detail::TypeUniquer;
155 friend StorageUniquer;
156
157public:
158 /// Return the abstract type descriptor for this type.
159 const AbstractType &getAbstractType() {
160 assert(abstractType && "Malformed type storage object.")(static_cast <bool> (abstractType && "Malformed type storage object."
) ? void (0) : __assert_fail ("abstractType && \"Malformed type storage object.\""
, "mlir/include/mlir/IR/TypeSupport.h", 160, __extension__ __PRETTY_FUNCTION__
))
;
161 return *abstractType;
162 }
163
164protected:
165 /// This constructor is used by derived classes as part of the TypeUniquer.
166 TypeStorage() {}
167
168private:
169 /// Set the abstract type for this storage instance. This is used by the
170 /// TypeUniquer when initializing a newly constructed type storage object.
171 void initialize(const AbstractType &abstractTy) {
172 abstractType = const_cast<AbstractType *>(&abstractTy);
173 }
174
175 /// The abstract description for this type.
176 AbstractType *abstractType{nullptr};
177};
178
179/// Default storage type for types that require no additional initialization or
180/// storage.
181using DefaultTypeStorage = TypeStorage;
182
183//===----------------------------------------------------------------------===//
184// TypeStorageAllocator
185//===----------------------------------------------------------------------===//
186
187/// This is a utility allocator used to allocate memory for instances of derived
188/// Types.
189using TypeStorageAllocator = StorageUniquer::StorageAllocator;
190
191//===----------------------------------------------------------------------===//
192// TypeUniquer
193//===----------------------------------------------------------------------===//
194namespace detail {
195/// A utility class to get, or create, unique instances of types within an
196/// MLIRContext. This class manages all creation and uniquing of types.
197struct TypeUniquer {
198 /// Get an uniqued instance of a type T.
199 template <typename T, typename... Args>
200 static T get(MLIRContext *ctx, Args &&...args) {
201 return getWithTypeID<T, Args...>(ctx, T::getTypeID(),
202 std::forward<Args>(args)...);
203 }
204
205 /// Get an uniqued instance of a parametric type T.
206 /// The use of this method is in general discouraged in favor of
207 /// 'get<T, Args>(ctx, args)'.
208 template <typename T, typename... Args>
209 static std::enable_if_t<
210 !std::is_same<typename T::ImplType, TypeStorage>::value, T>
211 getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) {
212#ifndef NDEBUG
213 if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID))
214 llvm::report_fatal_error(
215 llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
216 "' because storage uniquer isn't initialized: the dialect was likely "
217 "not loaded, or the type wasn't added with addTypes<...>() "
218 "in the Dialect::initialize() method.");
219#endif
220 return ctx->getTypeUniquer().get<typename T::ImplType>(
221 [&, typeID](TypeStorage *storage) {
222 storage->initialize(AbstractType::lookup(typeID, ctx));
223 },
224 typeID, std::forward<Args>(args)...);
225 }
226 /// Get an uniqued instance of a singleton type T.
227 /// The use of this method is in general discouraged in favor of
228 /// 'get<T, Args>(ctx, args)'.
229 template <typename T>
230 static std::enable_if_t<
231 std::is_same<typename T::ImplType, TypeStorage>::value, T>
232 getWithTypeID(MLIRContext *ctx, TypeID typeID) {
233#ifndef NDEBUG
234 if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID))
235 llvm::report_fatal_error(
236 llvm::Twine("can't create type '") + llvm::getTypeName<T>() +
237 "' because storage uniquer isn't initialized: the dialect was likely "
238 "not loaded, or the type wasn't added with addTypes<...>() "
239 "in the Dialect::initialize() method.");
240#endif
241 return ctx->getTypeUniquer().get<typename T::ImplType>(typeID);
242 }
243
244 /// Change the mutable component of the given type instance in the provided
245 /// context.
246 template <typename T, typename... Args>
247 static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl,
248 Args &&...args) {
249 assert(impl && "cannot mutate null type")(static_cast <bool> (impl && "cannot mutate null type"
) ? void (0) : __assert_fail ("impl && \"cannot mutate null type\""
, "mlir/include/mlir/IR/TypeSupport.h", 249, __extension__ __PRETTY_FUNCTION__
))
;
250 return ctx->getTypeUniquer().mutate(T::getTypeID(), impl,
251 std::forward<Args>(args)...);
252 }
253
254 /// Register a type instance T with the uniquer.
255 template <typename T>
256 static void registerType(MLIRContext *ctx) {
257 registerType<T>(ctx, T::getTypeID());
258 }
259
260 /// Register a parametric type instance T with the uniquer.
261 /// The use of this method is in general discouraged in favor of
262 /// 'registerType<T>(ctx)'.
263 template <typename T>
264 static std::enable_if_t<
265 !std::is_same<typename T::ImplType, TypeStorage>::value>
266 registerType(MLIRContext *ctx, TypeID typeID) {
267 ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>(
268 typeID);
269 }
270 /// Register a singleton type instance T with the uniquer.
271 /// The use of this method is in general discouraged in favor of
272 /// 'registerType<T>(ctx)'.
273 template <typename T>
274 static std::enable_if_t<
275 std::is_same<typename T::ImplType, TypeStorage>::value>
276 registerType(MLIRContext *ctx, TypeID typeID) {
277 ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>(
278 typeID, [&ctx, typeID](TypeStorage *storage) {
279 storage->initialize(AbstractType::lookup(typeID, ctx));
280 });
281 }
282};
283} // namespace detail
284
285} // namespace mlir
286
287#endif