| File: | build/source/mlir/include/mlir/IR/TypeSupport.h |
| Warning: | line 46, column 5 Address of stack memory associated with temporary object of type '(lambda at /build/source/mlir/include/mlir/IR/StorageUniquerSupport.h:133:12)' is still referred to by a temporary object on the stack upon returning to the caller. This will be a dangling reference |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
| 1 | //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// | |||
| 2 | // | |||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
| 4 | // See https://llvm.org/LICENSE.txt for license information. | |||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
| 6 | // | |||
| 7 | //===----------------------------------------------------------------------===// | |||
| 8 | ||||
| 9 | #include "mlir/IR/BuiltinTypes.h" | |||
| 10 | #include "TypeDetail.h" | |||
| 11 | #include "mlir/IR/AffineExpr.h" | |||
| 12 | #include "mlir/IR/AffineMap.h" | |||
| 13 | #include "mlir/IR/BuiltinAttributes.h" | |||
| 14 | #include "mlir/IR/BuiltinDialect.h" | |||
| 15 | #include "mlir/IR/Diagnostics.h" | |||
| 16 | #include "mlir/IR/Dialect.h" | |||
| 17 | #include "mlir/IR/FunctionInterfaces.h" | |||
| 18 | #include "mlir/IR/OpImplementation.h" | |||
| 19 | #include "mlir/IR/TensorEncoding.h" | |||
| 20 | #include "llvm/ADT/APFloat.h" | |||
| 21 | #include "llvm/ADT/BitVector.h" | |||
| 22 | #include "llvm/ADT/Sequence.h" | |||
| 23 | #include "llvm/ADT/Twine.h" | |||
| 24 | #include "llvm/ADT/TypeSwitch.h" | |||
| 25 | ||||
| 26 | using namespace mlir; | |||
| 27 | using namespace mlir::detail; | |||
| 28 | ||||
| 29 | //===----------------------------------------------------------------------===// | |||
| 30 | /// Tablegen Type Definitions | |||
| 31 | //===----------------------------------------------------------------------===// | |||
| 32 | ||||
| 33 | #define GET_TYPEDEF_CLASSES | |||
| 34 | #include "mlir/IR/BuiltinTypes.cpp.inc" | |||
| 35 | ||||
| 36 | //===----------------------------------------------------------------------===// | |||
| 37 | // BuiltinDialect | |||
| 38 | //===----------------------------------------------------------------------===// | |||
| 39 | ||||
/// Registers the builtin types with this dialect. The template argument list
/// of addTypes<> is produced by TableGen: including BuiltinTypes.cpp.inc with
/// GET_TYPEDEF_LIST defined expands to the list of generated type classes.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
| 46 | ||||
| 47 | //===----------------------------------------------------------------------===// | |||
| 48 | /// ComplexType | |||
| 49 | //===----------------------------------------------------------------------===// | |||
| 50 | ||||
| 51 | /// Verify the construction of an integer type. | |||
| 52 | LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 53 | Type elementType) { | |||
| 54 | if (!elementType.isIntOrFloat()) | |||
| 55 | return emitError() << "invalid element type for complex"; | |||
| 56 | return success(); | |||
| 57 | } | |||
| 58 | ||||
| 59 | //===----------------------------------------------------------------------===// | |||
| 60 | // Integer Type | |||
| 61 | //===----------------------------------------------------------------------===// | |||
| 62 | ||||
| 63 | /// Verify the construction of an integer type. | |||
| 64 | LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 65 | unsigned width, | |||
| 66 | SignednessSemantics signedness) { | |||
| 67 | if (width > IntegerType::kMaxWidth) { | |||
| 68 | return emitError() << "integer bitwidth is limited to " | |||
| 69 | << IntegerType::kMaxWidth << " bits"; | |||
| 70 | } | |||
| 71 | return success(); | |||
| 72 | } | |||
| 73 | ||||
| 74 | unsigned IntegerType::getWidth() const { return getImpl()->width; } | |||
| 75 | ||||
| 76 | IntegerType::SignednessSemantics IntegerType::getSignedness() const { | |||
| 77 | return getImpl()->signedness; | |||
| 78 | } | |||
| 79 | ||||
| 80 | IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { | |||
| 81 | if (!scale) | |||
| 82 | return IntegerType(); | |||
| 83 | return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); | |||
| 84 | } | |||
| 85 | ||||
| 86 | //===----------------------------------------------------------------------===// | |||
| 87 | // Float Type | |||
| 88 | //===----------------------------------------------------------------------===// | |||
| 89 | ||||
| 90 | unsigned FloatType::getWidth() { | |||
| 91 | if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType, | |||
| 92 | Float8E4M3FNUZType, Float8E4M3B11FNUZType>()) | |||
| 93 | return 8; | |||
| 94 | if (isa<Float16Type, BFloat16Type>()) | |||
| 95 | return 16; | |||
| 96 | if (isa<Float32Type>()) | |||
| 97 | return 32; | |||
| 98 | if (isa<Float64Type>()) | |||
| 99 | return 64; | |||
| 100 | if (isa<Float80Type>()) | |||
| 101 | return 80; | |||
| 102 | if (isa<Float128Type>()) | |||
| 103 | return 128; | |||
| 104 | llvm_unreachable("unexpected float type")::llvm::llvm_unreachable_internal("unexpected float type", "mlir/lib/IR/BuiltinTypes.cpp" , 104); | |||
| 105 | } | |||
| 106 | ||||
| 107 | /// Returns the floating semantics for the given type. | |||
| 108 | const llvm::fltSemantics &FloatType::getFloatSemantics() { | |||
| 109 | if (isa<Float8E5M2Type>()) | |||
| 110 | return APFloat::Float8E5M2(); | |||
| 111 | if (isa<Float8E4M3FNType>()) | |||
| 112 | return APFloat::Float8E4M3FN(); | |||
| 113 | if (isa<Float8E5M2FNUZType>()) | |||
| 114 | return APFloat::Float8E5M2FNUZ(); | |||
| 115 | if (isa<Float8E4M3FNUZType>()) | |||
| 116 | return APFloat::Float8E4M3FNUZ(); | |||
| 117 | if (isa<Float8E4M3B11FNUZType>()) | |||
| 118 | return APFloat::Float8E4M3B11FNUZ(); | |||
| 119 | if (isa<BFloat16Type>()) | |||
| 120 | return APFloat::BFloat(); | |||
| 121 | if (isa<Float16Type>()) | |||
| 122 | return APFloat::IEEEhalf(); | |||
| 123 | if (isa<Float32Type>()) | |||
| 124 | return APFloat::IEEEsingle(); | |||
| 125 | if (isa<Float64Type>()) | |||
| 126 | return APFloat::IEEEdouble(); | |||
| 127 | if (isa<Float80Type>()) | |||
| 128 | return APFloat::x87DoubleExtended(); | |||
| 129 | if (isa<Float128Type>()) | |||
| 130 | return APFloat::IEEEquad(); | |||
| 131 | llvm_unreachable("non-floating point type used")::llvm::llvm_unreachable_internal("non-floating point type used" , "mlir/lib/IR/BuiltinTypes.cpp", 131); | |||
| 132 | } | |||
| 133 | ||||
| 134 | FloatType FloatType::scaleElementBitwidth(unsigned scale) { | |||
| 135 | if (!scale) | |||
| 136 | return FloatType(); | |||
| 137 | MLIRContext *ctx = getContext(); | |||
| 138 | if (isF16() || isBF16()) { | |||
| 139 | if (scale == 2) | |||
| 140 | return FloatType::getF32(ctx); | |||
| 141 | if (scale == 4) | |||
| 142 | return FloatType::getF64(ctx); | |||
| 143 | } | |||
| 144 | if (isF32()) | |||
| 145 | if (scale == 2) | |||
| 146 | return FloatType::getF64(ctx); | |||
| 147 | return FloatType(); | |||
| 148 | } | |||
| 149 | ||||
| 150 | unsigned FloatType::getFPMantissaWidth() { | |||
| 151 | return APFloat::semanticsPrecision(getFloatSemantics()); | |||
| 152 | } | |||
| 153 | ||||
| 154 | //===----------------------------------------------------------------------===// | |||
| 155 | // FunctionType | |||
| 156 | //===----------------------------------------------------------------------===// | |||
| 157 | ||||
| 158 | unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } | |||
| 159 | ||||
| 160 | ArrayRef<Type> FunctionType::getInputs() const { | |||
| 161 | return getImpl()->getInputs(); | |||
| 162 | } | |||
| 163 | ||||
| 164 | unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } | |||
| 165 | ||||
| 166 | ArrayRef<Type> FunctionType::getResults() const { | |||
| 167 | return getImpl()->getResults(); | |||
| 168 | } | |||
| 169 | ||||
| 170 | FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { | |||
| 171 | return get(getContext(), inputs, results); | |||
| 172 | } | |||
| 173 | ||||
| 174 | /// Returns a new function type with the specified arguments and results | |||
| 175 | /// inserted. | |||
| 176 | FunctionType FunctionType::getWithArgsAndResults( | |||
| 177 | ArrayRef<unsigned> argIndices, TypeRange argTypes, | |||
| 178 | ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { | |||
| 179 | SmallVector<Type> argStorage, resultStorage; | |||
| 180 | TypeRange newArgTypes = function_interface_impl::insertTypesInto( | |||
| 181 | getInputs(), argIndices, argTypes, argStorage); | |||
| 182 | TypeRange newResultTypes = function_interface_impl::insertTypesInto( | |||
| 183 | getResults(), resultIndices, resultTypes, resultStorage); | |||
| 184 | return clone(newArgTypes, newResultTypes); | |||
| 185 | } | |||
| 186 | ||||
| 187 | /// Returns a new function type without the specified arguments and results. | |||
| 188 | FunctionType | |||
| 189 | FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, | |||
| 190 | const BitVector &resultIndices) { | |||
| 191 | SmallVector<Type> argStorage, resultStorage; | |||
| 192 | TypeRange newArgTypes = function_interface_impl::filterTypesOut( | |||
| 193 | getInputs(), argIndices, argStorage); | |||
| 194 | TypeRange newResultTypes = function_interface_impl::filterTypesOut( | |||
| 195 | getResults(), resultIndices, resultStorage); | |||
| 196 | return clone(newArgTypes, newResultTypes); | |||
| 197 | } | |||
| 198 | ||||
| 199 | //===----------------------------------------------------------------------===// | |||
| 200 | // OpaqueType | |||
| 201 | //===----------------------------------------------------------------------===// | |||
| 202 | ||||
| 203 | /// Verify the construction of an opaque type. | |||
| 204 | LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 205 | StringAttr dialect, StringRef typeData) { | |||
| 206 | if (!Dialect::isValidNamespace(dialect.strref())) | |||
| 207 | return emitError() << "invalid dialect namespace '" << dialect << "'"; | |||
| 208 | ||||
| 209 | // Check that the dialect is actually registered. | |||
| 210 | MLIRContext *context = dialect.getContext(); | |||
| 211 | if (!context->allowsUnregisteredDialects() && | |||
| 212 | !context->getLoadedDialect(dialect.strref())) { | |||
| 213 | return emitError() | |||
| 214 | << "`!" << dialect << "<\"" << typeData << "\">" | |||
| 215 | << "` type created with unregistered dialect. If this is " | |||
| 216 | "intended, please call allowUnregisteredDialects() on the " | |||
| 217 | "MLIRContext, or use -allow-unregistered-dialect with " | |||
| 218 | "the MLIR opt tool used"; | |||
| 219 | } | |||
| 220 | ||||
| 221 | return success(); | |||
| 222 | } | |||
| 223 | ||||
| 224 | //===----------------------------------------------------------------------===// | |||
| 225 | // VectorType | |||
| 226 | //===----------------------------------------------------------------------===// | |||
| 227 | ||||
| 228 | LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 229 | ArrayRef<int64_t> shape, Type elementType, | |||
| 230 | unsigned numScalableDims) { | |||
| 231 | if (!isValidElementType(elementType)) | |||
| 232 | return emitError() | |||
| 233 | << "vector elements must be int/index/float type but got " | |||
| 234 | << elementType; | |||
| 235 | ||||
| 236 | if (any_of(shape, [](int64_t i) { return i <= 0; })) | |||
| 237 | return emitError() | |||
| 238 | << "vector types must have positive constant sizes but got " | |||
| 239 | << shape; | |||
| 240 | ||||
| 241 | return success(); | |||
| 242 | } | |||
| 243 | ||||
| 244 | VectorType VectorType::scaleElementBitwidth(unsigned scale) { | |||
| 245 | if (!scale) | |||
| 246 | return VectorType(); | |||
| 247 | if (auto et = getElementType().dyn_cast<IntegerType>()) | |||
| 248 | if (auto scaledEt = et.scaleElementBitwidth(scale)) | |||
| 249 | return VectorType::get(getShape(), scaledEt, getNumScalableDims()); | |||
| 250 | if (auto et = getElementType().dyn_cast<FloatType>()) | |||
| 251 | if (auto scaledEt = et.scaleElementBitwidth(scale)) | |||
| 252 | return VectorType::get(getShape(), scaledEt, getNumScalableDims()); | |||
| 253 | return VectorType(); | |||
| 254 | } | |||
| 255 | ||||
| 256 | VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
| 257 | Type elementType) const { | |||
| 258 | return VectorType::get(shape.value_or(getShape()), elementType, | |||
| 259 | getNumScalableDims()); | |||
| 260 | } | |||
| 261 | ||||
| 262 | //===----------------------------------------------------------------------===// | |||
| 263 | // TensorType | |||
| 264 | //===----------------------------------------------------------------------===// | |||
| 265 | ||||
| 266 | Type TensorType::getElementType() const { | |||
| 267 | return llvm::TypeSwitch<TensorType, Type>(*this) | |||
| 268 | .Case<RankedTensorType, UnrankedTensorType>( | |||
| 269 | [](auto type) { return type.getElementType(); }); | |||
| 270 | } | |||
| 271 | ||||
| 272 | bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } | |||
| 273 | ||||
| 274 | ArrayRef<int64_t> TensorType::getShape() const { | |||
| 275 | return cast<RankedTensorType>().getShape(); | |||
| 276 | } | |||
| 277 | ||||
| 278 | TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
| 279 | Type elementType) const { | |||
| 280 | if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { | |||
| 281 | if (shape) | |||
| 282 | return RankedTensorType::get(*shape, elementType); | |||
| 283 | return UnrankedTensorType::get(elementType); | |||
| 284 | } | |||
| 285 | ||||
| 286 | auto rankedTy = cast<RankedTensorType>(); | |||
| 287 | if (!shape) | |||
| 288 | return RankedTensorType::get(rankedTy.getShape(), elementType, | |||
| 289 | rankedTy.getEncoding()); | |||
| 290 | return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, | |||
| 291 | rankedTy.getEncoding()); | |||
| 292 | } | |||
| 293 | ||||
| 294 | // Check if "elementType" can be an element type of a tensor. | |||
| 295 | static LogicalResult | |||
| 296 | checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, | |||
| 297 | Type elementType) { | |||
| 298 | if (!TensorType::isValidElementType(elementType)) | |||
| 299 | return emitError() << "invalid tensor element type: " << elementType; | |||
| 300 | return success(); | |||
| 301 | } | |||
| 302 | ||||
| 303 | /// Return true if the specified element type is ok in a tensor. | |||
| 304 | bool TensorType::isValidElementType(Type type) { | |||
| 305 | // Note: Non standard/builtin types are allowed to exist within tensor | |||
| 306 | // types. Dialects are expected to verify that tensor types have a valid | |||
| 307 | // element type within that dialect. | |||
| 308 | return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, | |||
| 309 | IndexType>() || | |||
| 310 | !llvm::isa<BuiltinDialect>(type.getDialect()); | |||
| 311 | } | |||
| 312 | ||||
| 313 | //===----------------------------------------------------------------------===// | |||
| 314 | // RankedTensorType | |||
| 315 | //===----------------------------------------------------------------------===// | |||
| 316 | ||||
| 317 | LogicalResult | |||
| 318 | RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 319 | ArrayRef<int64_t> shape, Type elementType, | |||
| 320 | Attribute encoding) { | |||
| 321 | for (int64_t s : shape) | |||
| 322 | if (s < 0 && !ShapedType::isDynamic(s)) | |||
| 323 | return emitError() << "invalid tensor dimension size"; | |||
| 324 | if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) | |||
| 325 | if (failed(v.verifyEncoding(shape, elementType, emitError))) | |||
| 326 | return failure(); | |||
| 327 | return checkTensorElementType(emitError, elementType); | |||
| 328 | } | |||
| 329 | ||||
| 330 | //===----------------------------------------------------------------------===// | |||
| 331 | // UnrankedTensorType | |||
| 332 | //===----------------------------------------------------------------------===// | |||
| 333 | ||||
| 334 | LogicalResult | |||
| 335 | UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 336 | Type elementType) { | |||
| 337 | return checkTensorElementType(emitError, elementType); | |||
| 338 | } | |||
| 339 | ||||
| 340 | //===----------------------------------------------------------------------===// | |||
| 341 | // BaseMemRefType | |||
| 342 | //===----------------------------------------------------------------------===// | |||
| 343 | ||||
| 344 | Type BaseMemRefType::getElementType() const { | |||
| 345 | return llvm::TypeSwitch<BaseMemRefType, Type>(*this) | |||
| 346 | .Case<MemRefType, UnrankedMemRefType>( | |||
| 347 | [](auto type) { return type.getElementType(); }); | |||
| 348 | } | |||
| 349 | ||||
| 350 | bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } | |||
| 351 | ||||
| 352 | ArrayRef<int64_t> BaseMemRefType::getShape() const { | |||
| 353 | return cast<MemRefType>().getShape(); | |||
| 354 | } | |||
| 355 | ||||
| 356 | BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
| 357 | Type elementType) const { | |||
| 358 | if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { | |||
| 359 | if (!shape) | |||
| 360 | return UnrankedMemRefType::get(elementType, getMemorySpace()); | |||
| 361 | MemRefType::Builder builder(*shape, elementType); | |||
| 362 | builder.setMemorySpace(getMemorySpace()); | |||
| 363 | return builder; | |||
| 364 | } | |||
| 365 | ||||
| 366 | MemRefType::Builder builder(cast<MemRefType>()); | |||
| 367 | if (shape) | |||
| 368 | builder.setShape(*shape); | |||
| 369 | builder.setElementType(elementType); | |||
| 370 | return builder; | |||
| 371 | } | |||
| 372 | ||||
| 373 | Attribute BaseMemRefType::getMemorySpace() const { | |||
| 374 | if (auto rankedMemRefTy = dyn_cast<MemRefType>()) | |||
| 375 | return rankedMemRefTy.getMemorySpace(); | |||
| 376 | return cast<UnrankedMemRefType>().getMemorySpace(); | |||
| 377 | } | |||
| 378 | ||||
| 379 | unsigned BaseMemRefType::getMemorySpaceAsInt() const { | |||
| 380 | if (auto rankedMemRefTy = dyn_cast<MemRefType>()) | |||
| 381 | return rankedMemRefTy.getMemorySpaceAsInt(); | |||
| 382 | return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); | |||
| 383 | } | |||
| 384 | ||||
| 385 | //===----------------------------------------------------------------------===// | |||
| 386 | // MemRefType | |||
| 387 | //===----------------------------------------------------------------------===// | |||
| 388 | ||||
| 389 | /// Given an `originalShape` and a `reducedShape` assumed to be a subset of | |||
| 390 | /// `originalShape` with some `1` entries erased, return the set of indices | |||
| 391 | /// that specifies which of the entries of `originalShape` are dropped to obtain | |||
| 392 | /// `reducedShape`. The returned mask can be applied as a projection to | |||
| 393 | /// `originalShape` to obtain the `reducedShape`. This mask is useful to track | |||
| 394 | /// which dimensions must be kept when e.g. compute MemRef strides under | |||
| 395 | /// rank-reducing operations. Return std::nullopt if reducedShape cannot be | |||
| 396 | /// obtained by dropping only `1` entries in `originalShape`. | |||
| 397 | std::optional<llvm::SmallDenseSet<unsigned>> | |||
| 398 | mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, | |||
| 399 | ArrayRef<int64_t> reducedShape) { | |||
| 400 | size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); | |||
| 401 | llvm::SmallDenseSet<unsigned> unusedDims; | |||
| 402 | unsigned reducedIdx = 0; | |||
| 403 | for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { | |||
| 404 | // Greedily insert `originalIdx` if match. | |||
| 405 | if (reducedIdx < reducedRank && | |||
| 406 | originalShape[originalIdx] == reducedShape[reducedIdx]) { | |||
| 407 | reducedIdx++; | |||
| 408 | continue; | |||
| 409 | } | |||
| 410 | ||||
| 411 | unusedDims.insert(originalIdx); | |||
| 412 | // If no match on `originalIdx`, the `originalShape` at this dimension | |||
| 413 | // must be 1, otherwise we bail. | |||
| 414 | if (originalShape[originalIdx] != 1) | |||
| 415 | return std::nullopt; | |||
| 416 | } | |||
| 417 | // The whole reducedShape must be scanned, otherwise we bail. | |||
| 418 | if (reducedIdx != reducedRank) | |||
| 419 | return std::nullopt; | |||
| 420 | return unusedDims; | |||
| 421 | } | |||
| 422 | ||||
| 423 | SliceVerificationResult | |||
| 424 | mlir::isRankReducedType(ShapedType originalType, | |||
| 425 | ShapedType candidateReducedType) { | |||
| 426 | if (originalType == candidateReducedType) | |||
| 427 | return SliceVerificationResult::Success; | |||
| 428 | ||||
| 429 | ShapedType originalShapedType = originalType.cast<ShapedType>(); | |||
| 430 | ShapedType candidateReducedShapedType = | |||
| 431 | candidateReducedType.cast<ShapedType>(); | |||
| 432 | ||||
| 433 | // Rank and size logic is valid for all ShapedTypes. | |||
| 434 | ArrayRef<int64_t> originalShape = originalShapedType.getShape(); | |||
| 435 | ArrayRef<int64_t> candidateReducedShape = | |||
| 436 | candidateReducedShapedType.getShape(); | |||
| 437 | unsigned originalRank = originalShape.size(), | |||
| 438 | candidateReducedRank = candidateReducedShape.size(); | |||
| 439 | if (candidateReducedRank > originalRank) | |||
| 440 | return SliceVerificationResult::RankTooLarge; | |||
| 441 | ||||
| 442 | auto optionalUnusedDimsMask = | |||
| 443 | computeRankReductionMask(originalShape, candidateReducedShape); | |||
| 444 | ||||
| 445 | // Sizes cannot be matched in case empty vector is returned. | |||
| 446 | if (!optionalUnusedDimsMask) | |||
| 447 | return SliceVerificationResult::SizeMismatch; | |||
| 448 | ||||
| 449 | if (originalShapedType.getElementType() != | |||
| 450 | candidateReducedShapedType.getElementType()) | |||
| 451 | return SliceVerificationResult::ElemTypeMismatch; | |||
| 452 | ||||
| 453 | return SliceVerificationResult::Success; | |||
| 454 | } | |||
| 455 | ||||
| 456 | bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { | |||
| 457 | // Empty attribute is allowed as default memory space. | |||
| 458 | if (!memorySpace) | |||
| 459 | return true; | |||
| 460 | ||||
| 461 | // Supported built-in attributes. | |||
| 462 | if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) | |||
| 463 | return true; | |||
| 464 | ||||
| 465 | // Allow custom dialect attributes. | |||
| 466 | if (!isa<BuiltinDialect>(memorySpace.getDialect())) | |||
| 467 | return true; | |||
| 468 | ||||
| 469 | return false; | |||
| 470 | } | |||
| 471 | ||||
| 472 | Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, | |||
| 473 | MLIRContext *ctx) { | |||
| 474 | if (memorySpace == 0) | |||
| 475 | return nullptr; | |||
| 476 | ||||
| 477 | return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); | |||
| 478 | } | |||
| 479 | ||||
| 480 | Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { | |||
| 481 | IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); | |||
| 482 | if (intMemorySpace && intMemorySpace.getValue() == 0) | |||
| 483 | return nullptr; | |||
| 484 | ||||
| 485 | return memorySpace; | |||
| 486 | } | |||
| 487 | ||||
| 488 | unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { | |||
| 489 | if (!memorySpace) | |||
| 490 | return 0; | |||
| 491 | ||||
| 492 | assert(memorySpace.isa<IntegerAttr>() &&(static_cast <bool> (memorySpace.isa<IntegerAttr> () && "Using `getMemorySpaceInteger` with non-Integer attribute" ) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\"" , "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__ )) | |||
| 493 | "Using `getMemorySpaceInteger` with non-Integer attribute")(static_cast <bool> (memorySpace.isa<IntegerAttr> () && "Using `getMemorySpaceInteger` with non-Integer attribute" ) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\"" , "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__ )); | |||
| 494 | ||||
| 495 | return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); | |||
| 496 | } | |||
| 497 | ||||
| 498 | unsigned MemRefType::getMemorySpaceAsInt() const { | |||
| 499 | return detail::getMemorySpaceAsInt(getMemorySpace()); | |||
| 500 | } | |||
| 501 | ||||
| 502 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
| 503 | MemRefLayoutAttrInterface layout, | |||
| 504 | Attribute memorySpace) { | |||
| 505 | // Use default layout for empty attribute. | |||
| 506 | if (!layout) | |||
| 507 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( | |||
| 508 | shape.size(), elementType.getContext())); | |||
| 509 | ||||
| 510 | // Drop default memory space value and replace it with empty attribute. | |||
| 511 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
| 512 | ||||
| 513 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
| 514 | memorySpace); | |||
| 515 | } | |||
| 516 | ||||
| 517 | MemRefType MemRefType::getChecked( | |||
| 518 | function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, | |||
| 519 | Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { | |||
| 520 | ||||
| 521 | // Use default layout for empty attribute. | |||
| 522 | if (!layout) | |||
| 523 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( | |||
| 524 | shape.size(), elementType.getContext())); | |||
| 525 | ||||
| 526 | // Drop default memory space value and replace it with empty attribute. | |||
| 527 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
| 528 | ||||
| 529 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
| 530 | elementType, layout, memorySpace); | |||
| 531 | } | |||
| 532 | ||||
| 533 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
| 534 | AffineMap map, Attribute memorySpace) { | |||
| 535 | ||||
| 536 | // Use default layout for empty map. | |||
| 537 | if (!map) | |||
| 538 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
| 539 | elementType.getContext()); | |||
| 540 | ||||
| 541 | // Wrap AffineMap into Attribute. | |||
| 542 | auto layout = AffineMapAttr::get(map); | |||
| 543 | ||||
| 544 | // Drop default memory space value and replace it with empty attribute. | |||
| 545 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
| 546 | ||||
| 547 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
| 548 | memorySpace); | |||
| 549 | } | |||
| 550 | ||||
| 551 | MemRefType | |||
| 552 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, | |||
| 553 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, | |||
| 554 | Attribute memorySpace) { | |||
| 555 | ||||
| 556 | // Use default layout for empty map. | |||
| 557 | if (!map) | |||
| 558 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
| 559 | elementType.getContext()); | |||
| 560 | ||||
| 561 | // Wrap AffineMap into Attribute. | |||
| 562 | auto layout = AffineMapAttr::get(map); | |||
| 563 | ||||
| 564 | // Drop default memory space value and replace it with empty attribute. | |||
| 565 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
| 566 | ||||
| 567 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
| 568 | elementType, layout, memorySpace); | |||
| 569 | } | |||
| 570 | ||||
| 571 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
| 572 | AffineMap map, unsigned memorySpaceInd) { | |||
| 573 | ||||
| 574 | // Use default layout for empty map. | |||
| 575 | if (!map) | |||
| 576 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
| 577 | elementType.getContext()); | |||
| 578 | ||||
| 579 | // Wrap AffineMap into Attribute. | |||
| 580 | auto layout = AffineMapAttr::get(map); | |||
| 581 | ||||
| 582 | // Convert deprecated integer-like memory space to Attribute. | |||
| 583 | Attribute memorySpace = | |||
| 584 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); | |||
| 585 | ||||
| 586 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
| 587 | memorySpace); | |||
| 588 | } | |||
| 589 | ||||
| 590 | MemRefType | |||
| 591 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, | |||
| 592 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, | |||
| 593 | unsigned memorySpaceInd) { | |||
| 594 | ||||
| 595 | // Use default layout for empty map. | |||
| 596 | if (!map) | |||
| 597 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
| 598 | elementType.getContext()); | |||
| 599 | ||||
| 600 | // Wrap AffineMap into Attribute. | |||
| 601 | auto layout = AffineMapAttr::get(map); | |||
| 602 | ||||
| 603 | // Convert deprecated integer-like memory space to Attribute. | |||
| 604 | Attribute memorySpace = | |||
| 605 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); | |||
| 606 | ||||
| 607 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
| 608 | elementType, layout, memorySpace); | |||
| 609 | } | |||
| 610 | ||||
| 611 | LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 612 | ArrayRef<int64_t> shape, Type elementType, | |||
| 613 | MemRefLayoutAttrInterface layout, | |||
| 614 | Attribute memorySpace) { | |||
| 615 | if (!BaseMemRefType::isValidElementType(elementType)) | |||
| 616 | return emitError() << "invalid memref element type"; | |||
| 617 | ||||
| 618 | // Negative sizes are not allowed except for `kDynamic`. | |||
| 619 | for (int64_t s : shape) | |||
| 620 | if (s < 0 && !ShapedType::isDynamic(s)) | |||
| 621 | return emitError() << "invalid memref size"; | |||
| 622 | ||||
| 623 | assert(layout && "missing layout specification")(static_cast <bool> (layout && "missing layout specification" ) ? void (0) : __assert_fail ("layout && \"missing layout specification\"" , "mlir/lib/IR/BuiltinTypes.cpp", 623, __extension__ __PRETTY_FUNCTION__ )); | |||
| 624 | if (failed(layout.verifyLayout(shape, emitError))) | |||
| 625 | return failure(); | |||
| 626 | ||||
| 627 | if (!isSupportedMemorySpace(memorySpace)) | |||
| 628 | return emitError() << "unsupported memory space Attribute"; | |||
| 629 | ||||
| 630 | return success(); | |||
| 631 | } | |||
| 632 | ||||
| 633 | //===----------------------------------------------------------------------===// | |||
| 634 | // UnrankedMemRefType | |||
| 635 | //===----------------------------------------------------------------------===// | |||
| 636 | ||||
| 637 | unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { | |||
| 638 | return detail::getMemorySpaceAsInt(getMemorySpace()); | |||
| 639 | } | |||
| 640 | ||||
| 641 | LogicalResult | |||
| 642 | UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
| 643 | Type elementType, Attribute memorySpace) { | |||
| 644 | if (!BaseMemRefType::isValidElementType(elementType)) | |||
| 645 | return emitError() << "invalid memref element type"; | |||
| 646 | ||||
| 647 | if (!isSupportedMemorySpace(memorySpace)) | |||
| 648 | return emitError() << "unsupported memory space Attribute"; | |||
| 649 | ||||
| 650 | return success(); | |||
| 651 | } | |||
| 652 | ||||
| 653 | // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( | |||
| 654 | // i.e. single term). Accumulate the AffineExpr into the existing one. | |||
| 655 | static void extractStridesFromTerm(AffineExpr e, | |||
| 656 | AffineExpr multiplicativeFactor, | |||
| 657 | MutableArrayRef<AffineExpr> strides, | |||
| 658 | AffineExpr &offset) { | |||
| 659 | if (auto dim = e.dyn_cast<AffineDimExpr>()) | |||
| 660 | strides[dim.getPosition()] = | |||
| 661 | strides[dim.getPosition()] + multiplicativeFactor; | |||
| 662 | else | |||
| 663 | offset = offset + e * multiplicativeFactor; | |||
| 664 | } | |||
| 665 | ||||
| 666 | /// Takes a single AffineExpr `e` and populates the `strides` array with the | |||
| 667 | /// strides expressions for each dim position. | |||
| 668 | /// The convention is that the strides for dimensions d0, .. dn appear in | |||
| 669 | /// order to make indexing intuitive into the result. | |||
| 670 | static LogicalResult extractStrides(AffineExpr e, | |||
| 671 | AffineExpr multiplicativeFactor, | |||
| 672 | MutableArrayRef<AffineExpr> strides, | |||
| 673 | AffineExpr &offset) { | |||
| 674 | auto bin = e.dyn_cast<AffineBinaryOpExpr>(); | |||
| 675 | if (!bin) { | |||
| 676 | extractStridesFromTerm(e, multiplicativeFactor, strides, offset); | |||
| 677 | return success(); | |||
| 678 | } | |||
| 679 | ||||
| 680 | if (bin.getKind() == AffineExprKind::CeilDiv || | |||
| 681 | bin.getKind() == AffineExprKind::FloorDiv || | |||
| 682 | bin.getKind() == AffineExprKind::Mod) | |||
| 683 | return failure(); | |||
| 684 | ||||
| 685 | if (bin.getKind() == AffineExprKind::Mul) { | |||
| 686 | auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); | |||
| 687 | if (dim) { | |||
| 688 | strides[dim.getPosition()] = | |||
| 689 | strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; | |||
| 690 | return success(); | |||
| 691 | } | |||
| 692 | // LHS and RHS may both contain complex expressions of dims. Try one path | |||
| 693 | // and if it fails try the other. This is guaranteed to succeed because | |||
| 694 | // only one path may have a `dim`, otherwise this is not an AffineExpr in | |||
| 695 | // the first place. | |||
| 696 | if (bin.getLHS().isSymbolicOrConstant()) | |||
| 697 | return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), | |||
| 698 | strides, offset); | |||
| 699 | return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), | |||
| 700 | strides, offset); | |||
| 701 | } | |||
| 702 | ||||
| 703 | if (bin.getKind() == AffineExprKind::Add) { | |||
| 704 | auto res1 = | |||
| 705 | extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); | |||
| 706 | auto res2 = | |||
| 707 | extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); | |||
| 708 | return success(succeeded(res1) && succeeded(res2)); | |||
| 709 | } | |||
| 710 | ||||
| 711 | llvm_unreachable("unexpected binary operation")::llvm::llvm_unreachable_internal("unexpected binary operation" , "mlir/lib/IR/BuiltinTypes.cpp", 711); | |||
| 712 | } | |||
| 713 | ||||
| 714 | /// A stride specification is a list of integer values that are either static | |||
| 715 | /// or dynamic (encoded with ShapedType::kDynamic). Strides encode | |||
| 716 | /// the distance in the number of elements between successive entries along a | |||
| 717 | /// particular dimension. | |||
| 718 | /// | |||
| 719 | /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a | |||
| 720 | /// non-contiguous memory region of `42` by `16` `f32` elements in which the | |||
| 721 | /// distance between two consecutive elements along the outer dimension is `1` | |||
| 722 | /// and the distance between two consecutive elements along the inner dimension | |||
| 723 | /// is `64`. | |||
| 724 | /// | |||
| 725 | /// The convention is that the strides for dimensions d0, .. dn appear in | |||
| 726 | /// order to make indexing intuitive into the result. | |||
| 727 | static LogicalResult getStridesAndOffset(MemRefType t, | |||
| 728 | SmallVectorImpl<AffineExpr> &strides, | |||
| 729 | AffineExpr &offset) { | |||
| 730 | AffineMap m = t.getLayout().getAffineMap(); | |||
| 731 | ||||
| 732 | if (m.getNumResults() != 1 && !m.isIdentity()) | |||
| 733 | return failure(); | |||
| 734 | ||||
| 735 | auto zero = getAffineConstantExpr(0, t.getContext()); | |||
| 736 | auto one = getAffineConstantExpr(1, t.getContext()); | |||
| 737 | offset = zero; | |||
| 738 | strides.assign(t.getRank(), zero); | |||
| 739 | ||||
| 740 | // Canonical case for empty map. | |||
| 741 | if (m.isIdentity()) { | |||
| 742 | // 0-D corner case, offset is already 0. | |||
| 743 | if (t.getRank() == 0) | |||
| 744 | return success(); | |||
| 745 | auto stridedExpr = | |||
| 746 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); | |||
| 747 | if (succeeded(extractStrides(stridedExpr, one, strides, offset))) | |||
| 748 | return success(); | |||
| 749 | assert(false && "unexpected failure: extract strides in canonical layout")(static_cast <bool> (false && "unexpected failure: extract strides in canonical layout" ) ? void (0) : __assert_fail ("false && \"unexpected failure: extract strides in canonical layout\"" , "mlir/lib/IR/BuiltinTypes.cpp", 749, __extension__ __PRETTY_FUNCTION__ )); | |||
| 750 | } | |||
| 751 | ||||
| 752 | // Non-canonical case requires more work. | |||
| 753 | auto stridedExpr = | |||
| 754 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); | |||
| 755 | if (failed(extractStrides(stridedExpr, one, strides, offset))) { | |||
| 756 | offset = AffineExpr(); | |||
| 757 | strides.clear(); | |||
| 758 | return failure(); | |||
| 759 | } | |||
| 760 | ||||
| 761 | // Simplify results to allow folding to constants and simple checks. | |||
| 762 | unsigned numDims = m.getNumDims(); | |||
| 763 | unsigned numSymbols = m.getNumSymbols(); | |||
| 764 | offset = simplifyAffineExpr(offset, numDims, numSymbols); | |||
| 765 | for (auto &stride : strides) | |||
| 766 | stride = simplifyAffineExpr(stride, numDims, numSymbols); | |||
| 767 | ||||
| 768 | // In practice, a strided memref must be internally non-aliasing. Test | |||
| 769 | // against 0 as a proxy. | |||
| 770 | // TODO: static cases can have more advanced checks. | |||
| 771 | // TODO: dynamic cases would require a way to compare symbolic | |||
| 772 | // expressions and would probably need an affine set context propagated | |||
| 773 | // everywhere. | |||
| 774 | if (llvm::any_of(strides, [](AffineExpr e) { | |||
| 775 | return e == getAffineConstantExpr(0, e.getContext()); | |||
| 776 | })) { | |||
| 777 | offset = AffineExpr(); | |||
| 778 | strides.clear(); | |||
| 779 | return failure(); | |||
| 780 | } | |||
| 781 | ||||
| 782 | return success(); | |||
| 783 | } | |||
| 784 | ||||
| 785 | LogicalResult mlir::getStridesAndOffset(MemRefType t, | |||
| 786 | SmallVectorImpl<int64_t> &strides, | |||
| 787 | int64_t &offset) { | |||
| 788 | // Happy path: the type uses the strided layout directly. | |||
| 789 | if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) { | |||
| 790 | llvm::append_range(strides, strided.getStrides()); | |||
| 791 | offset = strided.getOffset(); | |||
| 792 | return success(); | |||
| 793 | } | |||
| 794 | ||||
| 795 | // Otherwise, defer to the affine fallback as layouts are supposed to be | |||
| 796 | // convertible to affine maps. | |||
| 797 | AffineExpr offsetExpr; | |||
| 798 | SmallVector<AffineExpr, 4> strideExprs; | |||
| 799 | if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) | |||
| 800 | return failure(); | |||
| 801 | if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) | |||
| 802 | offset = cst.getValue(); | |||
| 803 | else | |||
| 804 | offset = ShapedType::kDynamic; | |||
| 805 | for (auto e : strideExprs) { | |||
| 806 | if (auto c = e.dyn_cast<AffineConstantExpr>()) | |||
| 807 | strides.push_back(c.getValue()); | |||
| 808 | else | |||
| 809 | strides.push_back(ShapedType::kDynamic); | |||
| 810 | } | |||
| 811 | return success(); | |||
| 812 | } | |||
| 813 | ||||
| 814 | std::pair<SmallVector<int64_t>, int64_t> | |||
| 815 | mlir::getStridesAndOffset(MemRefType t) { | |||
| 816 | SmallVector<int64_t> strides; | |||
| 817 | int64_t offset; | |||
| 818 | LogicalResult status = getStridesAndOffset(t, strides, offset); | |||
| 819 | (void)status; | |||
| 820 | assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset")(static_cast <bool> (succeeded(status) && "Invalid use of check-free getStridesAndOffset" ) ? void (0) : __assert_fail ("succeeded(status) && \"Invalid use of check-free getStridesAndOffset\"" , "mlir/lib/IR/BuiltinTypes.cpp", 820, __extension__ __PRETTY_FUNCTION__ )); | |||
| 821 | return {strides, offset}; | |||
| 822 | } | |||
| 823 | ||||
| 824 | //===----------------------------------------------------------------------===// | |||
| 825 | /// TupleType | |||
| 826 | //===----------------------------------------------------------------------===// | |||
| 827 | ||||
| 828 | /// Return the elements types for this tuple. | |||
| 829 | ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } | |||
| 830 | ||||
| 831 | /// Accumulate the types contained in this tuple and tuples nested within it. | |||
| 832 | /// Note that this only flattens nested tuples, not any other container type, | |||
| 833 | /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to | |||
| 834 | /// (i32, tensor<i32>, f32, i64) | |||
| 835 | void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { | |||
| 836 | for (Type type : getTypes()) { | |||
| 837 | if (auto nestedTuple = type.dyn_cast<TupleType>()) | |||
| 838 | nestedTuple.getFlattenedTypes(types); | |||
| 839 | else | |||
| 840 | types.push_back(type); | |||
| 841 | } | |||
| 842 | } | |||
| 843 | ||||
| 844 | /// Return the number of element types. | |||
| 845 | size_t TupleType::size() const { return getImpl()->size(); } | |||
| 846 | ||||
| 847 | //===----------------------------------------------------------------------===// | |||
| 848 | // Type Utilities | |||
| 849 | //===----------------------------------------------------------------------===// | |||
| 850 | ||||
| 851 | /// Return a version of `t` with identity layout if it can be determined | |||
| 852 | /// statically that the layout is the canonical contiguous strided layout. | |||
| 853 | /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of | |||
| 854 | /// `t` with simplified layout. | |||
| 855 | /// If `t` has multiple layout maps or a multi-result layout, just return `t`. | |||
| 856 | MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { | |||
| 857 | AffineMap m = t.getLayout().getAffineMap(); | |||
| 858 | ||||
| 859 | // Already in canonical form. | |||
| 860 | if (m.isIdentity()) | |||
| 861 | return t; | |||
| 862 | ||||
| 863 | // Can't reduce to canonical identity form, return in canonical form. | |||
| 864 | if (m.getNumResults() > 1) | |||
| 865 | return t; | |||
| 866 | ||||
| 867 | // Corner-case for 0-D affine maps. | |||
| 868 | if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { | |||
| 869 | if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) | |||
| 870 | if (cst.getValue() == 0) | |||
| 871 | return MemRefType::Builder(t).setLayout({}); | |||
| 872 | return t; | |||
| 873 | } | |||
| 874 | ||||
| 875 | // 0-D corner case for empty shape that still have an affine map. Example: | |||
| 876 | // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose | |||
| 877 | // offset needs to remain, just return t. | |||
| 878 | if (t.getShape().empty()) | |||
| 879 | return t; | |||
| 880 | ||||
| 881 | // If the canonical strided layout for the sizes of `t` is equal to the | |||
| 882 | // simplified layout of `t` we can just return an empty layout. Otherwise, | |||
| 883 | // just simplify the existing layout. | |||
| 884 | AffineExpr expr = | |||
| 885 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); | |||
| 886 | auto simplifiedLayoutExpr = | |||
| 887 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); | |||
| 888 | if (expr != simplifiedLayoutExpr) | |||
| 889 | return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( | |||
| 890 | m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); | |||
| 891 | return MemRefType::Builder(t).setLayout({}); | |||
| 892 | } | |||
| 893 | ||||
| 894 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, | |||
| 895 | ArrayRef<AffineExpr> exprs, | |||
| 896 | MLIRContext *context) { | |||
| 897 | // Size 0 corner case is useful for canonicalizations. | |||
| 898 | if (sizes.empty()) | |||
| 899 | return getAffineConstantExpr(0, context); | |||
| 900 | ||||
| 901 | assert(!exprs.empty() && "expected exprs")(static_cast <bool> (!exprs.empty() && "expected exprs" ) ? void (0) : __assert_fail ("!exprs.empty() && \"expected exprs\"" , "mlir/lib/IR/BuiltinTypes.cpp", 901, __extension__ __PRETTY_FUNCTION__ )); | |||
| 902 | auto maps = AffineMap::inferFromExprList(exprs); | |||
| 903 | assert(!maps.empty() && "Expected one non-empty map")(static_cast <bool> (!maps.empty() && "Expected one non-empty map" ) ? void (0) : __assert_fail ("!maps.empty() && \"Expected one non-empty map\"" , "mlir/lib/IR/BuiltinTypes.cpp", 903, __extension__ __PRETTY_FUNCTION__ )); | |||
| 904 | unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); | |||
| 905 | ||||
| 906 | AffineExpr expr; | |||
| 907 | bool dynamicPoisonBit = false; | |||
| 908 | int64_t runningSize = 1; | |||
| 909 | for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { | |||
| 910 | int64_t size = std::get<1>(en); | |||
| 911 | AffineExpr dimExpr = std::get<0>(en); | |||
| 912 | AffineExpr stride = dynamicPoisonBit | |||
| 913 | ? getAffineSymbolExpr(nSymbols++, context) | |||
| 914 | : getAffineConstantExpr(runningSize, context); | |||
| 915 | expr = expr ? expr + dimExpr * stride : dimExpr * stride; | |||
| 916 | if (size > 0) { | |||
| 917 | runningSize *= size; | |||
| 918 | assert(runningSize > 0 && "integer overflow in size computation")(static_cast <bool> (runningSize > 0 && "integer overflow in size computation" ) ? void (0) : __assert_fail ("runningSize > 0 && \"integer overflow in size computation\"" , "mlir/lib/IR/BuiltinTypes.cpp", 918, __extension__ __PRETTY_FUNCTION__ )); | |||
| 919 | } else { | |||
| 920 | dynamicPoisonBit = true; | |||
| 921 | } | |||
| 922 | } | |||
| 923 | return simplifyAffineExpr(expr, numDims, nSymbols); | |||
| 924 | } | |||
| 925 | ||||
| 926 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, | |||
| 927 | MLIRContext *context) { | |||
| 928 | SmallVector<AffineExpr, 4> exprs; | |||
| 929 | exprs.reserve(sizes.size()); | |||
| 930 | for (auto dim : llvm::seq<unsigned>(0, sizes.size())) | |||
| 931 | exprs.push_back(getAffineDimExpr(dim, context)); | |||
| 932 | return makeCanonicalStridedLayoutExpr(sizes, exprs, context); | |||
| 933 | } | |||
| 934 | ||||
| 935 | /// Return true if the layout for `t` is compatible with strided semantics. | |||
| 936 | bool mlir::isStrided(MemRefType t) { | |||
| 937 | int64_t offset; | |||
| 938 | SmallVector<int64_t, 4> strides; | |||
| 939 | auto res = getStridesAndOffset(t, strides, offset); | |||
| 940 | return succeeded(res); | |||
| 941 | } |
| 1 | //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// |
| 2 | // |
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | // See https://llvm.org/LICENSE.txt for license information. |
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | // |
| 7 | //===----------------------------------------------------------------------===// |
| 8 | // |
| 9 | // This file defines the 'dialect' abstraction. |
| 10 | // |
| 11 | //===----------------------------------------------------------------------===// |
| 12 | |
| 13 | #ifndef MLIR_IR_DIALECT_H |
| 14 | #define MLIR_IR_DIALECT_H |
| 15 | |
| 16 | #include "mlir/IR/DialectRegistry.h" |
| 17 | #include "mlir/IR/OperationSupport.h" |
| 18 | #include "mlir/Support/TypeID.h" |
| 19 | |
| 20 | #include <map> |
| 21 | #include <tuple> |
| 22 | |
| 23 | namespace mlir { |
| 24 | class DialectAsmParser; |
| 25 | class DialectAsmPrinter; |
| 26 | class DialectInterface; |
| 27 | class OpBuilder; |
| 28 | class Type; |
| 29 | |
| 30 | //===----------------------------------------------------------------------===// |
| 31 | // Dialect |
| 32 | //===----------------------------------------------------------------------===// |
| 33 | |
| 34 | /// Dialects are groups of MLIR operations, types and attributes, as well as |
| 35 | /// behavior associated with the entire group. For example, hooks into other |
| 36 | /// systems for constant folding, interfaces, default named types for asm |
| 37 | /// printing, etc. |
| 38 | /// |
| 39 | /// Instances of the dialect object are loaded in a specific MLIRContext. |
| 40 | /// |
| 41 | class Dialect { |
| 42 | public: |
| 43 | /// Type for a callback provided by the dialect to parse a custom operation. |
| 44 | /// This is used for the dialect to provide an alternative way to parse custom |
| 45 | /// operations, including unregistered ones. |
| 46 | using ParseOpHook = |
| 47 | function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>; |
| 48 | |
| 49 | virtual ~Dialect(); |
| 50 | |
| 51 | /// Utility function that returns if the given string is a valid dialect |
| 52 | /// namespace |
| 53 | static bool isValidNamespace(StringRef str); |
| 54 | |
| 55 | MLIRContext *getContext() const { return context; } |
| 56 | |
| 57 | StringRef getNamespace() const { return name; } |
| 58 | |
| 59 | /// Returns the unique identifier that corresponds to this dialect. |
| 60 | TypeID getTypeID() const { return dialectID; } |
| 61 | |
| 62 | /// Returns true if this dialect allows for unregistered operations, i.e. |
| 63 | /// operations prefixed with the dialect namespace but not registered with |
| 64 | /// addOperation. |
| 65 | bool allowsUnknownOperations() const { return unknownOpsAllowed; } |
| 66 | |
| 67 | /// Return true if this dialect allows for unregistered types, i.e., types |
| 68 | /// prefixed with the dialect namespace but not registered with addType. |
| 69 | /// These are represented with OpaqueType. |
| 70 | bool allowsUnknownTypes() const { return unknownTypesAllowed; } |
| 71 | |
| 72 | /// Register dialect-wide canonicalization patterns. This method should only |
| 73 | /// be used to register canonicalization patterns that do not conceptually |
| 74 | /// belong to any single operation in the dialect. (In that case, use the op's |
| 75 | /// canonicalizer.) E.g., canonicalization patterns for op interfaces should |
| 76 | /// be registered here. |
| 77 | virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {} |
| 78 | |
| 79 | /// Registered hook to materialize a single constant operation from a given |
| 80 | /// attribute value with the desired resultant type. This method should use |
| 81 | /// the provided builder to create the operation without changing the |
| 82 | /// insertion position. The generated operation is expected to be constant |
| 83 | /// like, i.e. single result, zero operands, non side-effecting, etc. On |
| 84 | /// success, this hook should return the value generated to represent the |
| 85 | /// constant value. Otherwise, it should return null on failure. |
| 86 | virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, |
| 87 | Type type, Location loc) { |
| 88 | return nullptr; |
| 89 | } |
| 90 | |
| 91 | //===--------------------------------------------------------------------===// |
| 92 | // Parsing Hooks |
| 93 | //===--------------------------------------------------------------------===// |
| 94 | |
| 95 | /// Parse an attribute registered to this dialect. If 'type' is nonnull, it |
| 96 | /// refers to the expected type of the attribute. |
| 97 | virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; |
| 98 | |
| 99 | /// Print an attribute registered to this dialect. Note: The type of the |
| 100 | /// attribute need not be printed by this method as it is always printed by |
| 101 | /// the caller. |
| 102 | virtual void printAttribute(Attribute, DialectAsmPrinter &) const { |
| 103 | llvm_unreachable("dialect has no registered attribute printing hook")::llvm::llvm_unreachable_internal("dialect has no registered attribute printing hook" , "mlir/include/mlir/IR/Dialect.h", 103); |
| 104 | } |
| 105 | |
| 106 | /// Parse a type registered to this dialect. |
| 107 | virtual Type parseType(DialectAsmParser &parser) const; |
| 108 | |
| 109 | /// Print a type registered to this dialect. |
| 110 | virtual void printType(Type, DialectAsmPrinter &) const { |
| 111 | llvm_unreachable("dialect has no registered type printing hook")::llvm::llvm_unreachable_internal("dialect has no registered type printing hook" , "mlir/include/mlir/IR/Dialect.h", 111); |
| 112 | } |
| 113 | |
| 114 | /// Return the hook to parse an operation registered to this dialect, if any. |
| 115 | /// By default this will lookup for registered operations and return the |
| 116 | /// `parse()` method registered on the RegisteredOperationName. Dialects can |
| 117 | /// override this behavior and handle unregistered operations as well. |
| 118 | virtual std::optional<ParseOpHook> |
| 119 | getParseOperationHook(StringRef opName) const; |
| 120 | |
| 121 | /// Print an operation registered to this dialect. |
| 122 | /// This hook is invoked for registered operation which don't override the |
| 123 | /// `print()` method to define their own custom assembly. |
| 124 | virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> |
| 125 | getOperationPrinter(Operation *op) const; |
| 126 | |
| 127 | //===--------------------------------------------------------------------===// |
| 128 | // Verification Hooks |
| 129 | //===--------------------------------------------------------------------===// |
| 130 | |
| 131 | /// Verify an attribute from this dialect on the argument at 'argIndex' for |
| 132 | /// the region at 'regionIndex' on the given operation. Returns failure if |
| 133 | /// the verification failed, success otherwise. This hook may optionally be |
| 134 | /// invoked from any operation containing a region. |
| 135 | virtual LogicalResult verifyRegionArgAttribute(Operation *, |
| 136 | unsigned regionIndex, |
| 137 | unsigned argIndex, |
| 138 | NamedAttribute); |
| 139 | |
| 140 | /// Verify an attribute from this dialect on the result at 'resultIndex' for |
| 141 | /// the region at 'regionIndex' on the given operation. Returns failure if |
| 142 | /// the verification failed, success otherwise. This hook may optionally be |
| 143 | /// invoked from any operation containing a region. |
| 144 | virtual LogicalResult verifyRegionResultAttribute(Operation *, |
| 145 | unsigned regionIndex, |
| 146 | unsigned resultIndex, |
| 147 | NamedAttribute); |
| 148 | |
| 149 | /// Verify an attribute from this dialect on the given operation. Returns |
| 150 | /// failure if the verification failed, success otherwise. |
| 151 | virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { |
| 152 | return success(); |
| 153 | } |
| 154 | |
| 155 | //===--------------------------------------------------------------------===// |
| 156 | // Interfaces |
| 157 | //===--------------------------------------------------------------------===// |
| 158 | |
| 159 | /// Lookup an interface for the given ID if one is registered, otherwise |
| 160 | /// nullptr. |
| 161 | DialectInterface *getRegisteredInterface(TypeID interfaceID) { |
| 162 | auto it = registeredInterfaces.find(interfaceID); |
| 163 | return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; |
| 164 | } |
| 165 | template <typename InterfaceT> |
| 166 | InterfaceT *getRegisteredInterface() { |
| 167 | return static_cast<InterfaceT *>( |
| 168 | getRegisteredInterface(InterfaceT::getInterfaceID())); |
| 169 | } |
| 170 | |
| 171 | /// Lookup an op interface for the given ID if one is registered, otherwise |
| 172 | /// nullptr. |
| 173 | virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, |
| 174 | OperationName opName) { |
| 175 | return nullptr; |
| 176 | } |
| 177 | template <typename InterfaceT> |
| 178 | typename InterfaceT::Concept * |
| 179 | getRegisteredInterfaceForOp(OperationName opName) { |
| 180 | return static_cast<typename InterfaceT::Concept *>( |
| 181 | getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); |
| 182 | } |
| 183 | |
| 184 | /// Register a dialect interface with this dialect instance. |
| 185 | void addInterface(std::unique_ptr<DialectInterface> interface); |
| 186 | |
| 187 | /// Register a set of dialect interfaces with this dialect instance. |
| 188 | template <typename... Args> |
| 189 | void addInterfaces() { |
| 190 | (addInterface(std::make_unique<Args>(this)), ...); |
| 191 | } |
| 192 | template <typename InterfaceT, typename... Args> |
| 193 | InterfaceT &addInterface(Args &&...args) { |
| 194 | InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...); |
| 195 | addInterface(std::unique_ptr<DialectInterface>(interface)); |
| 196 | return *interface; |
| 197 | } |
| 198 | |
| 199 | protected: |
| 200 | /// The constructor takes a unique namespace for this dialect as well as the |
| 201 | /// context to bind to. |
| 202 | /// Note: The namespace must not contain '.' characters. |
| 203 | /// Note: All operations belonging to this dialect must have names starting |
| 204 | /// with the namespace followed by '.'. |
| 205 | /// Example: |
| 206 | /// - "tf" for the TensorFlow ops like "tf.add". |
| 207 | Dialect(StringRef name, MLIRContext *context, TypeID id); |
| 208 | |
| 209 | /// This method is used by derived classes to add their operations to the set. |
| 210 | /// |
| 211 | template <typename... Args> |
| 212 | void addOperations() { |
| 213 | // This initializer_list argument pack expansion is essentially equal to |
| 214 | // using a fold expression with a comma operator. Clang however, refuses |
| 215 | // to compile a fold expression with a depth of more than 256 by default. |
| 216 | // There seem to be no such limitations for initializer_list. |
| 217 | (void)std::initializer_list<int>{ |
| 218 | 0, (RegisteredOperationName::insert<Args>(*this), 0)...}; |
| 219 | } |
| 220 | |
| 221 | /// Register a set of type classes with this dialect. |
| 222 | template <typename... Args> |
| 223 | void addTypes() { |
| 224 | (addType<Args>(), ...); |
| 225 | } |
| 226 | |
| 227 | /// Register a type instance with this dialect. |
| 228 | /// The use of this method is in general discouraged in favor of |
| 229 | /// 'addTypes<CustomType>()'. |
| 230 | void addType(TypeID typeID, AbstractType &&typeInfo); |
| 231 | |
| 232 | /// Register a set of attribute classes with this dialect. |
| 233 | template <typename... Args> |
| 234 | void addAttributes() { |
| 235 | (addAttribute<Args>(), ...); |
| 236 | } |
| 237 | |
| 238 | /// Register an attribute instance with this dialect. |
| 239 | /// The use of this method is in general discouraged in favor of |
| 240 | /// 'addAttributes<CustomAttr>()'. |
| 241 | void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); |
| 242 | |
| 243 | /// Enable support for unregistered operations. |
| 244 | void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } |
| 245 | |
| 246 | /// Enable support for unregistered types. |
| 247 | void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } |
| 248 | |
| 249 | private: |
| 250 | Dialect(const Dialect &) = delete; |
| 251 | void operator=(Dialect &) = delete; |
| 252 | |
| 253 | /// Register an attribute instance with this dialect. |
| 254 | template <typename T> |
| 255 | void addAttribute() { |
| 256 | // Add this attribute to the dialect and register it with the uniquer. |
| 257 | addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this)); |
| 258 | detail::AttributeUniquer::registerAttribute<T>(context); |
| 259 | } |
| 260 | |
| 261 | /// Register a type instance with this dialect. |
| 262 | template <typename T> |
| 263 | void addType() { |
| 264 | // Add this type to the dialect and register it with the uniquer. |
| 265 | addType(T::getTypeID(), AbstractType::get<T>(*this)); |
| 266 | detail::TypeUniquer::registerType<T>(context); |
| 267 | } |
| 268 | |
| 269 | /// The namespace of this dialect. |
| 270 | StringRef name; |
| 271 | |
| 272 | /// The unique identifier of the derived Op class, this is used in the context |
| 273 | /// to allow registering multiple times the same dialect. |
| 274 | TypeID dialectID; |
| 275 | |
| 276 | /// This is the context that owns this Dialect object. |
| 277 | MLIRContext *context; |
| 278 | |
| 279 | /// Flag that specifies whether this dialect supports unregistered operations, |
| 280 | /// i.e. operations prefixed with the dialect namespace but not registered |
| 281 | /// with addOperation. |
| 282 | bool unknownOpsAllowed = false; |
| 283 | |
| 284 | /// Flag that specifies whether this dialect allows unregistered types, i.e. |
| 285 | /// types prefixed with the dialect namespace but not registered with addType. |
| 286 | /// These types are represented with OpaqueType. |
| 287 | bool unknownTypesAllowed = false; |
| 288 | |
| 289 | /// A collection of registered dialect interfaces. |
| 290 | DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; |
| 291 | |
| 292 | friend class DialectRegistry; |
| 293 | friend void registerDialect(); |
| 294 | friend class MLIRContext; |
| 295 | }; |
| 296 | |
| 297 | } // namespace mlir |
| 298 | |
| 299 | namespace llvm { |
| 300 | /// Provide isa functionality for Dialects. |
| 301 | template <typename T> |
| 302 | struct isa_impl<T, ::mlir::Dialect, |
| 303 | std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> { |
| 304 | static inline bool doit(const ::mlir::Dialect &dialect) { |
| 305 | return mlir::TypeID::get<T>() == dialect.getTypeID(); |
| 306 | } |
| 307 | }; |
| 308 | template <typename T> |
| 309 | struct isa_impl< |
| 310 | T, ::mlir::Dialect, |
| 311 | std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> { |
| 312 | static inline bool doit(const ::mlir::Dialect &dialect) { |
| 313 | return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>(); |
| 314 | } |
| 315 | }; |
| 316 | template <typename T> |
| 317 | struct cast_retty_impl<T, ::mlir::Dialect *> { |
| 318 | using ret_type = T *; |
| 319 | }; |
| 320 | template <typename T> |
| 321 | struct cast_retty_impl<T, ::mlir::Dialect> { |
| 322 | using ret_type = T &; |
| 323 | }; |
| 324 | |
| 325 | template <typename T> |
| 326 | struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> { |
| 327 | template <typename To> |
| 328 | static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &> |
| 329 | doitImpl(::mlir::Dialect &dialect) { |
| 330 | return static_cast<To &>(dialect); |
| 331 | } |
| 332 | template <typename To> |
| 333 | static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value, |
| 334 | To &> |
| 335 | doitImpl(::mlir::Dialect &dialect) { |
| 336 | return *dialect.getRegisteredInterface<To>(); |
| 337 | } |
| 338 | |
| 339 | static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); } |
| 340 | }; |
| 341 | template <class T> |
| 342 | struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> { |
| 343 | static auto doit(::mlir::Dialect *dialect) { |
| 344 | return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit( |
| 345 | *dialect); |
| 346 | } |
| 347 | }; |
| 348 | |
| 349 | } // namespace llvm |
| 350 | |
| 351 | #endif |
| 1 | //===- TypeSupport.h --------------------------------------------*- C++ -*-===// | |||
| 2 | // | |||
| 3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
| 4 | // See https://llvm.org/LICENSE.txt for license information. | |||
| 5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
| 6 | // | |||
| 7 | //===----------------------------------------------------------------------===// | |||
| 8 | // | |||
| 9 | // This file defines support types for registering dialect extended types. | |||
| 10 | // | |||
| 11 | //===----------------------------------------------------------------------===// | |||
| 12 | ||||
| 13 | #ifndef MLIR_IR_TYPESUPPORT_H | |||
| 14 | #define MLIR_IR_TYPESUPPORT_H | |||
| 15 | ||||
| 16 | #include "mlir/IR/MLIRContext.h" | |||
| 17 | #include "mlir/IR/StorageUniquerSupport.h" | |||
| 18 | #include "llvm/ADT/Twine.h" | |||
| 19 | ||||
| 20 | namespace mlir { | |||
| 21 | class Dialect; | |||
| 22 | class MLIRContext; | |||
| 23 | ||||
| 24 | //===----------------------------------------------------------------------===// | |||
| 25 | // AbstractType | |||
| 26 | //===----------------------------------------------------------------------===// | |||
| 27 | ||||
| 28 | /// This class contains all of the static information common to all instances of | |||
| 29 | /// a registered Type. | |||
| 30 | class AbstractType { | |||
| 31 | public: | |||
| 32 | using HasTraitFn = llvm::unique_function<bool(TypeID) const>; | |||
| 33 | using WalkImmediateSubElementsFn = function_ref<void( | |||
| 34 | Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>; | |||
| 35 | using ReplaceImmediateSubElementsFn = | |||
| 36 | function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>; | |||
| 37 | ||||
| 38 | /// Look up the specified abstract type in the MLIRContext and return a | |||
| 39 | /// reference to it. | |||
| 40 | static const AbstractType &lookup(TypeID typeID, MLIRContext *context); | |||
| 41 | ||||
| 42 | /// This method is used by Dialect objects when they register the list of | |||
| 43 | /// types they contain. | |||
| 44 | template <typename T> | |||
| 45 | static AbstractType get(Dialect &dialect) { | |||
| 46 | return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), | |||
| ||||
| 47 | T::getWalkImmediateSubElementsFn(), | |||
| 48 | T::getReplaceImmediateSubElementsFn(), T::getTypeID()); | |||
| 49 | } | |||
| 50 | ||||
| 51 | /// This method is used by Dialect objects to register types with | |||
| 52 | /// custom TypeIDs. | |||
| 53 | /// The use of this method is in general discouraged in favor of | |||
| 54 | /// 'get<CustomType>(dialect)'; | |||
| 55 | static AbstractType | |||
| 56 | get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, | |||
| 57 | HasTraitFn &&hasTrait, | |||
| 58 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, | |||
| 59 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, | |||
| 60 | TypeID typeID) { | |||
| 61 | return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), | |||
| 62 | walkImmediateSubElementsFn, | |||
| 63 | replaceImmediateSubElementsFn, typeID); | |||
| 64 | } | |||
| 65 | ||||
| 66 | /// Return the dialect this type was registered to. | |||
| 67 | Dialect &getDialect() const { return const_cast<Dialect &>(dialect); } | |||
| 68 | ||||
| 69 | /// Returns an instance of the concept object for the given interface if it | |||
| 70 | /// was registered to this type, null otherwise. This should not be used | |||
| 71 | /// directly. | |||
| 72 | template <typename T> | |||
| 73 | typename T::Concept *getInterface() const { | |||
| 74 | return interfaceMap.lookup<T>(); | |||
| 75 | } | |||
| 76 | ||||
| 77 | /// Returns true if the type has the interface with the given ID. | |||
| 78 | bool hasInterface(TypeID interfaceID) const { | |||
| 79 | return interfaceMap.contains(interfaceID); | |||
| 80 | } | |||
| 81 | ||||
| 82 | /// Returns true if the type has a particular trait. | |||
| 83 | template <template <typename T> class Trait> | |||
| 84 | bool hasTrait() const { | |||
| 85 | return hasTraitFn(TypeID::get<Trait>()); | |||
| 86 | } | |||
| 87 | ||||
| 88 | /// Returns true if the type has a particular trait. | |||
| 89 | bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } | |||
| 90 | ||||
| 91 | /// Walk the immediate sub-elements of the given type. | |||
| 92 | void walkImmediateSubElements(Type type, | |||
| 93 | function_ref<void(Attribute)> walkAttrsFn, | |||
| 94 | function_ref<void(Type)> walkTypesFn) const; | |||
| 95 | ||||
| 96 | /// Replace the immediate sub-elements of the given type. | |||
| 97 | Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs, | |||
| 98 | ArrayRef<Type> replTypes) const; | |||
| 99 | ||||
| 100 | /// Return the unique identifier representing the concrete type class. | |||
| 101 | TypeID getTypeID() const { return typeID; } | |||
| 102 | ||||
| 103 | private: | |||
| 104 | AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, | |||
| 105 | HasTraitFn &&hasTrait, | |||
| 106 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, | |||
| 107 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, | |||
| 108 | TypeID typeID) | |||
| 109 | : dialect(dialect), interfaceMap(std::move(interfaceMap)), | |||
| 110 | hasTraitFn(std::move(hasTrait)), | |||
| 111 | walkImmediateSubElementsFn(walkImmediateSubElementsFn), | |||
| 112 | replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), | |||
| 113 | typeID(typeID) {} | |||
| 114 | ||||
| 115 | /// Give StorageUserBase access to the mutable lookup. | |||
| 116 | template <typename ConcreteT, typename BaseT, typename StorageT, | |||
| 117 | typename UniquerT, template <typename T> class... Traits> | |||
| 118 | friend class detail::StorageUserBase; | |||
| 119 | ||||
| 120 | /// Look up the specified abstract type in the MLIRContext and return a | |||
| 121 | /// (mutable) pointer to it. Return a null pointer if the type could not | |||
| 122 | /// be found in the context. | |||
| 123 | static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); | |||
| 124 | ||||
| 125 | /// This is the dialect that this type was registered to. | |||
| 126 | const Dialect &dialect; | |||
| 127 | ||||
| 128 | /// This is a collection of the interfaces registered to this type. | |||
| 129 | detail::InterfaceMap interfaceMap; | |||
| 130 | ||||
| 131 | /// Function to check if the type has a particular trait. | |||
| 132 | HasTraitFn hasTraitFn; | |||
| 133 | ||||
| 134 | /// Function to walk the immediate sub-elements of this type. | |||
| 135 | WalkImmediateSubElementsFn walkImmediateSubElementsFn; | |||
| 136 | ||||
| 137 | /// Function to replace the immediate sub-elements of this type. | |||
| 138 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn; | |||
| 139 | ||||
| 140 | /// The unique identifier of the derived Type class. | |||
| 141 | const TypeID typeID; | |||
| 142 | }; | |||
| 143 | ||||
| 144 | //===----------------------------------------------------------------------===// | |||
| 145 | // TypeStorage | |||
| 146 | //===----------------------------------------------------------------------===// | |||
| 147 | ||||
| 148 | namespace detail { | |||
| 149 | struct TypeUniquer; | |||
| 150 | } // namespace detail | |||
| 151 | ||||
| 152 | /// Base storage class appearing in a Type. | |||
| 153 | class TypeStorage : public StorageUniquer::BaseStorage { | |||
| 154 | friend detail::TypeUniquer; | |||
| 155 | friend StorageUniquer; | |||
| 156 | ||||
| 157 | public: | |||
| 158 | /// Return the abstract type descriptor for this type. | |||
| 159 | const AbstractType &getAbstractType() { | |||
| 160 | assert(abstractType && "Malformed type storage object.")(static_cast <bool> (abstractType && "Malformed type storage object." ) ? void (0) : __assert_fail ("abstractType && \"Malformed type storage object.\"" , "mlir/include/mlir/IR/TypeSupport.h", 160, __extension__ __PRETTY_FUNCTION__ )); | |||
| 161 | return *abstractType; | |||
| 162 | } | |||
| 163 | ||||
| 164 | protected: | |||
| 165 | /// This constructor is used by derived classes as part of the TypeUniquer. | |||
| 166 | TypeStorage() {} | |||
| 167 | ||||
| 168 | private: | |||
| 169 | /// Set the abstract type for this storage instance. This is used by the | |||
| 170 | /// TypeUniquer when initializing a newly constructed type storage object. | |||
| 171 | void initialize(const AbstractType &abstractTy) { | |||
| 172 | abstractType = const_cast<AbstractType *>(&abstractTy); | |||
| 173 | } | |||
| 174 | ||||
| 175 | /// The abstract description for this type. | |||
| 176 | AbstractType *abstractType{nullptr}; | |||
| 177 | }; | |||
| 178 | ||||
| 179 | /// Default storage type for types that require no additional initialization or | |||
| 180 | /// storage. | |||
| 181 | using DefaultTypeStorage = TypeStorage; | |||
| 182 | ||||
| 183 | //===----------------------------------------------------------------------===// | |||
| 184 | // TypeStorageAllocator | |||
| 185 | //===----------------------------------------------------------------------===// | |||
| 186 | ||||
| 187 | /// This is a utility allocator used to allocate memory for instances of derived | |||
| 188 | /// Types. | |||
| 189 | using TypeStorageAllocator = StorageUniquer::StorageAllocator; | |||
| 190 | ||||
| 191 | //===----------------------------------------------------------------------===// | |||
| 192 | // TypeUniquer | |||
| 193 | //===----------------------------------------------------------------------===// | |||
| 194 | namespace detail { | |||
| 195 | /// A utility class to get, or create, unique instances of types within an | |||
| 196 | /// MLIRContext. This class manages all creation and uniquing of types. | |||
| 197 | struct TypeUniquer { | |||
| 198 | /// Get an uniqued instance of a type T. | |||
| 199 | template <typename T, typename... Args> | |||
| 200 | static T get(MLIRContext *ctx, Args &&...args) { | |||
| 201 | return getWithTypeID<T, Args...>(ctx, T::getTypeID(), | |||
| 202 | std::forward<Args>(args)...); | |||
| 203 | } | |||
| 204 | ||||
| 205 | /// Get an uniqued instance of a parametric type T. | |||
| 206 | /// The use of this method is in general discouraged in favor of | |||
| 207 | /// 'get<T, Args>(ctx, args)'. | |||
| 208 | template <typename T, typename... Args> | |||
| 209 | static std::enable_if_t< | |||
| 210 | !std::is_same<typename T::ImplType, TypeStorage>::value, T> | |||
| 211 | getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { | |||
| 212 | #ifndef NDEBUG | |||
| 213 | if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) | |||
| 214 | llvm::report_fatal_error( | |||
| 215 | llvm::Twine("can't create type '") + llvm::getTypeName<T>() + | |||
| 216 | "' because storage uniquer isn't initialized: the dialect was likely " | |||
| 217 | "not loaded, or the type wasn't added with addTypes<...>() " | |||
| 218 | "in the Dialect::initialize() method."); | |||
| 219 | #endif | |||
| 220 | return ctx->getTypeUniquer().get<typename T::ImplType>( | |||
| 221 | [&, typeID](TypeStorage *storage) { | |||
| 222 | storage->initialize(AbstractType::lookup(typeID, ctx)); | |||
| 223 | }, | |||
| 224 | typeID, std::forward<Args>(args)...); | |||
| 225 | } | |||
| 226 | /// Get an uniqued instance of a singleton type T. | |||
| 227 | /// The use of this method is in general discouraged in favor of | |||
| 228 | /// 'get<T, Args>(ctx, args)'. | |||
| 229 | template <typename T> | |||
| 230 | static std::enable_if_t< | |||
| 231 | std::is_same<typename T::ImplType, TypeStorage>::value, T> | |||
| 232 | getWithTypeID(MLIRContext *ctx, TypeID typeID) { | |||
| 233 | #ifndef NDEBUG | |||
| 234 | if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) | |||
| 235 | llvm::report_fatal_error( | |||
| 236 | llvm::Twine("can't create type '") + llvm::getTypeName<T>() + | |||
| 237 | "' because storage uniquer isn't initialized: the dialect was likely " | |||
| 238 | "not loaded, or the type wasn't added with addTypes<...>() " | |||
| 239 | "in the Dialect::initialize() method."); | |||
| 240 | #endif | |||
| 241 | return ctx->getTypeUniquer().get<typename T::ImplType>(typeID); | |||
| 242 | } | |||
| 243 | ||||
| 244 | /// Change the mutable component of the given type instance in the provided | |||
| 245 | /// context. | |||
| 246 | template <typename T, typename... Args> | |||
| 247 | static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, | |||
| 248 | Args &&...args) { | |||
| 249 | assert(impl && "cannot mutate null type")(static_cast <bool> (impl && "cannot mutate null type" ) ? void (0) : __assert_fail ("impl && \"cannot mutate null type\"" , "mlir/include/mlir/IR/TypeSupport.h", 249, __extension__ __PRETTY_FUNCTION__ )); | |||
| 250 | return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, | |||
| 251 | std::forward<Args>(args)...); | |||
| 252 | } | |||
| 253 | ||||
| 254 | /// Register a type instance T with the uniquer. | |||
| 255 | template <typename T> | |||
| 256 | static void registerType(MLIRContext *ctx) { | |||
| 257 | registerType<T>(ctx, T::getTypeID()); | |||
| 258 | } | |||
| 259 | ||||
| 260 | /// Register a parametric type instance T with the uniquer. | |||
| 261 | /// The use of this method is in general discouraged in favor of | |||
| 262 | /// 'registerType<T>(ctx)'. | |||
| 263 | template <typename T> | |||
| 264 | static std::enable_if_t< | |||
| 265 | !std::is_same<typename T::ImplType, TypeStorage>::value> | |||
| 266 | registerType(MLIRContext *ctx, TypeID typeID) { | |||
| 267 | ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>( | |||
| 268 | typeID); | |||
| 269 | } | |||
| 270 | /// Register a singleton type instance T with the uniquer. | |||
| 271 | /// The use of this method is in general discouraged in favor of | |||
| 272 | /// 'registerType<T>(ctx)'. | |||
| 273 | template <typename T> | |||
| 274 | static std::enable_if_t< | |||
| 275 | std::is_same<typename T::ImplType, TypeStorage>::value> | |||
| 276 | registerType(MLIRContext *ctx, TypeID typeID) { | |||
| 277 | ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>( | |||
| 278 | typeID, [&ctx, typeID](TypeStorage *storage) { | |||
| 279 | storage->initialize(AbstractType::lookup(typeID, ctx)); | |||
| 280 | }); | |||
| 281 | } | |||
| 282 | }; | |||
| 283 | } // namespace detail | |||
| 284 | ||||
| 285 | } // namespace mlir | |||
| 286 | ||||
| 287 | #endif |