File: | build/source/mlir/include/mlir/IR/TypeSupport.h |
Warning: | line 46, column 5 Address of stack memory associated with temporary object of type '(lambda at /build/source/mlir/include/mlir/IR/StorageUniquerSupport.h:133:12)' is still referred to by a temporary object on the stack upon returning to the caller. This will be a dangling reference |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BuiltinTypes.cpp - MLIR Builtin Type Classes -----------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/BuiltinTypes.h" | |||
10 | #include "TypeDetail.h" | |||
11 | #include "mlir/IR/AffineExpr.h" | |||
12 | #include "mlir/IR/AffineMap.h" | |||
13 | #include "mlir/IR/BuiltinAttributes.h" | |||
14 | #include "mlir/IR/BuiltinDialect.h" | |||
15 | #include "mlir/IR/Diagnostics.h" | |||
16 | #include "mlir/IR/Dialect.h" | |||
17 | #include "mlir/IR/FunctionInterfaces.h" | |||
18 | #include "mlir/IR/OpImplementation.h" | |||
19 | #include "mlir/IR/TensorEncoding.h" | |||
20 | #include "llvm/ADT/APFloat.h" | |||
21 | #include "llvm/ADT/BitVector.h" | |||
22 | #include "llvm/ADT/Sequence.h" | |||
23 | #include "llvm/ADT/Twine.h" | |||
24 | #include "llvm/ADT/TypeSwitch.h" | |||
25 | ||||
26 | using namespace mlir; | |||
27 | using namespace mlir::detail; | |||
28 | ||||
29 | //===----------------------------------------------------------------------===// | |||
30 | /// Tablegen Type Definitions | |||
31 | //===----------------------------------------------------------------------===// | |||
32 | ||||
33 | #define GET_TYPEDEF_CLASSES | |||
34 | #include "mlir/IR/BuiltinTypes.cpp.inc" | |||
35 | ||||
36 | //===----------------------------------------------------------------------===// | |||
37 | // BuiltinDialect | |||
38 | //===----------------------------------------------------------------------===// | |||
39 | ||||
/// Registers with this dialect every type listed in the GET_TYPEDEF_LIST
/// section of the TableGen-generated "mlir/IR/BuiltinTypes.cpp.inc"; the
/// #include below expands into the comma-separated template argument list.
void BuiltinDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/IR/BuiltinTypes.cpp.inc"
      >();
}
46 | ||||
47 | //===----------------------------------------------------------------------===// | |||
48 | /// ComplexType | |||
49 | //===----------------------------------------------------------------------===// | |||
50 | ||||
51 | /// Verify the construction of an integer type. | |||
52 | LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
53 | Type elementType) { | |||
54 | if (!elementType.isIntOrFloat()) | |||
55 | return emitError() << "invalid element type for complex"; | |||
56 | return success(); | |||
57 | } | |||
58 | ||||
59 | //===----------------------------------------------------------------------===// | |||
60 | // Integer Type | |||
61 | //===----------------------------------------------------------------------===// | |||
62 | ||||
63 | /// Verify the construction of an integer type. | |||
64 | LogicalResult IntegerType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
65 | unsigned width, | |||
66 | SignednessSemantics signedness) { | |||
67 | if (width > IntegerType::kMaxWidth) { | |||
68 | return emitError() << "integer bitwidth is limited to " | |||
69 | << IntegerType::kMaxWidth << " bits"; | |||
70 | } | |||
71 | return success(); | |||
72 | } | |||
73 | ||||
74 | unsigned IntegerType::getWidth() const { return getImpl()->width; } | |||
75 | ||||
76 | IntegerType::SignednessSemantics IntegerType::getSignedness() const { | |||
77 | return getImpl()->signedness; | |||
78 | } | |||
79 | ||||
80 | IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { | |||
81 | if (!scale) | |||
82 | return IntegerType(); | |||
83 | return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); | |||
84 | } | |||
85 | ||||
86 | //===----------------------------------------------------------------------===// | |||
87 | // Float Type | |||
88 | //===----------------------------------------------------------------------===// | |||
89 | ||||
90 | unsigned FloatType::getWidth() { | |||
91 | if (isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType, | |||
92 | Float8E4M3FNUZType, Float8E4M3B11FNUZType>()) | |||
93 | return 8; | |||
94 | if (isa<Float16Type, BFloat16Type>()) | |||
95 | return 16; | |||
96 | if (isa<Float32Type>()) | |||
97 | return 32; | |||
98 | if (isa<Float64Type>()) | |||
99 | return 64; | |||
100 | if (isa<Float80Type>()) | |||
101 | return 80; | |||
102 | if (isa<Float128Type>()) | |||
103 | return 128; | |||
104 | llvm_unreachable("unexpected float type")::llvm::llvm_unreachable_internal("unexpected float type", "mlir/lib/IR/BuiltinTypes.cpp" , 104); | |||
105 | } | |||
106 | ||||
107 | /// Returns the floating semantics for the given type. | |||
108 | const llvm::fltSemantics &FloatType::getFloatSemantics() { | |||
109 | if (isa<Float8E5M2Type>()) | |||
110 | return APFloat::Float8E5M2(); | |||
111 | if (isa<Float8E4M3FNType>()) | |||
112 | return APFloat::Float8E4M3FN(); | |||
113 | if (isa<Float8E5M2FNUZType>()) | |||
114 | return APFloat::Float8E5M2FNUZ(); | |||
115 | if (isa<Float8E4M3FNUZType>()) | |||
116 | return APFloat::Float8E4M3FNUZ(); | |||
117 | if (isa<Float8E4M3B11FNUZType>()) | |||
118 | return APFloat::Float8E4M3B11FNUZ(); | |||
119 | if (isa<BFloat16Type>()) | |||
120 | return APFloat::BFloat(); | |||
121 | if (isa<Float16Type>()) | |||
122 | return APFloat::IEEEhalf(); | |||
123 | if (isa<Float32Type>()) | |||
124 | return APFloat::IEEEsingle(); | |||
125 | if (isa<Float64Type>()) | |||
126 | return APFloat::IEEEdouble(); | |||
127 | if (isa<Float80Type>()) | |||
128 | return APFloat::x87DoubleExtended(); | |||
129 | if (isa<Float128Type>()) | |||
130 | return APFloat::IEEEquad(); | |||
131 | llvm_unreachable("non-floating point type used")::llvm::llvm_unreachable_internal("non-floating point type used" , "mlir/lib/IR/BuiltinTypes.cpp", 131); | |||
132 | } | |||
133 | ||||
134 | FloatType FloatType::scaleElementBitwidth(unsigned scale) { | |||
135 | if (!scale) | |||
136 | return FloatType(); | |||
137 | MLIRContext *ctx = getContext(); | |||
138 | if (isF16() || isBF16()) { | |||
139 | if (scale == 2) | |||
140 | return FloatType::getF32(ctx); | |||
141 | if (scale == 4) | |||
142 | return FloatType::getF64(ctx); | |||
143 | } | |||
144 | if (isF32()) | |||
145 | if (scale == 2) | |||
146 | return FloatType::getF64(ctx); | |||
147 | return FloatType(); | |||
148 | } | |||
149 | ||||
150 | unsigned FloatType::getFPMantissaWidth() { | |||
151 | return APFloat::semanticsPrecision(getFloatSemantics()); | |||
152 | } | |||
153 | ||||
154 | //===----------------------------------------------------------------------===// | |||
155 | // FunctionType | |||
156 | //===----------------------------------------------------------------------===// | |||
157 | ||||
158 | unsigned FunctionType::getNumInputs() const { return getImpl()->numInputs; } | |||
159 | ||||
160 | ArrayRef<Type> FunctionType::getInputs() const { | |||
161 | return getImpl()->getInputs(); | |||
162 | } | |||
163 | ||||
164 | unsigned FunctionType::getNumResults() const { return getImpl()->numResults; } | |||
165 | ||||
166 | ArrayRef<Type> FunctionType::getResults() const { | |||
167 | return getImpl()->getResults(); | |||
168 | } | |||
169 | ||||
170 | FunctionType FunctionType::clone(TypeRange inputs, TypeRange results) const { | |||
171 | return get(getContext(), inputs, results); | |||
172 | } | |||
173 | ||||
174 | /// Returns a new function type with the specified arguments and results | |||
175 | /// inserted. | |||
176 | FunctionType FunctionType::getWithArgsAndResults( | |||
177 | ArrayRef<unsigned> argIndices, TypeRange argTypes, | |||
178 | ArrayRef<unsigned> resultIndices, TypeRange resultTypes) { | |||
179 | SmallVector<Type> argStorage, resultStorage; | |||
180 | TypeRange newArgTypes = function_interface_impl::insertTypesInto( | |||
181 | getInputs(), argIndices, argTypes, argStorage); | |||
182 | TypeRange newResultTypes = function_interface_impl::insertTypesInto( | |||
183 | getResults(), resultIndices, resultTypes, resultStorage); | |||
184 | return clone(newArgTypes, newResultTypes); | |||
185 | } | |||
186 | ||||
187 | /// Returns a new function type without the specified arguments and results. | |||
188 | FunctionType | |||
189 | FunctionType::getWithoutArgsAndResults(const BitVector &argIndices, | |||
190 | const BitVector &resultIndices) { | |||
191 | SmallVector<Type> argStorage, resultStorage; | |||
192 | TypeRange newArgTypes = function_interface_impl::filterTypesOut( | |||
193 | getInputs(), argIndices, argStorage); | |||
194 | TypeRange newResultTypes = function_interface_impl::filterTypesOut( | |||
195 | getResults(), resultIndices, resultStorage); | |||
196 | return clone(newArgTypes, newResultTypes); | |||
197 | } | |||
198 | ||||
199 | //===----------------------------------------------------------------------===// | |||
200 | // OpaqueType | |||
201 | //===----------------------------------------------------------------------===// | |||
202 | ||||
203 | /// Verify the construction of an opaque type. | |||
204 | LogicalResult OpaqueType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
205 | StringAttr dialect, StringRef typeData) { | |||
206 | if (!Dialect::isValidNamespace(dialect.strref())) | |||
207 | return emitError() << "invalid dialect namespace '" << dialect << "'"; | |||
208 | ||||
209 | // Check that the dialect is actually registered. | |||
210 | MLIRContext *context = dialect.getContext(); | |||
211 | if (!context->allowsUnregisteredDialects() && | |||
212 | !context->getLoadedDialect(dialect.strref())) { | |||
213 | return emitError() | |||
214 | << "`!" << dialect << "<\"" << typeData << "\">" | |||
215 | << "` type created with unregistered dialect. If this is " | |||
216 | "intended, please call allowUnregisteredDialects() on the " | |||
217 | "MLIRContext, or use -allow-unregistered-dialect with " | |||
218 | "the MLIR opt tool used"; | |||
219 | } | |||
220 | ||||
221 | return success(); | |||
222 | } | |||
223 | ||||
224 | //===----------------------------------------------------------------------===// | |||
225 | // VectorType | |||
226 | //===----------------------------------------------------------------------===// | |||
227 | ||||
228 | LogicalResult VectorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
229 | ArrayRef<int64_t> shape, Type elementType, | |||
230 | unsigned numScalableDims) { | |||
231 | if (!isValidElementType(elementType)) | |||
232 | return emitError() | |||
233 | << "vector elements must be int/index/float type but got " | |||
234 | << elementType; | |||
235 | ||||
236 | if (any_of(shape, [](int64_t i) { return i <= 0; })) | |||
237 | return emitError() | |||
238 | << "vector types must have positive constant sizes but got " | |||
239 | << shape; | |||
240 | ||||
241 | return success(); | |||
242 | } | |||
243 | ||||
244 | VectorType VectorType::scaleElementBitwidth(unsigned scale) { | |||
245 | if (!scale) | |||
246 | return VectorType(); | |||
247 | if (auto et = getElementType().dyn_cast<IntegerType>()) | |||
248 | if (auto scaledEt = et.scaleElementBitwidth(scale)) | |||
249 | return VectorType::get(getShape(), scaledEt, getNumScalableDims()); | |||
250 | if (auto et = getElementType().dyn_cast<FloatType>()) | |||
251 | if (auto scaledEt = et.scaleElementBitwidth(scale)) | |||
252 | return VectorType::get(getShape(), scaledEt, getNumScalableDims()); | |||
253 | return VectorType(); | |||
254 | } | |||
255 | ||||
256 | VectorType VectorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
257 | Type elementType) const { | |||
258 | return VectorType::get(shape.value_or(getShape()), elementType, | |||
259 | getNumScalableDims()); | |||
260 | } | |||
261 | ||||
262 | //===----------------------------------------------------------------------===// | |||
263 | // TensorType | |||
264 | //===----------------------------------------------------------------------===// | |||
265 | ||||
266 | Type TensorType::getElementType() const { | |||
267 | return llvm::TypeSwitch<TensorType, Type>(*this) | |||
268 | .Case<RankedTensorType, UnrankedTensorType>( | |||
269 | [](auto type) { return type.getElementType(); }); | |||
270 | } | |||
271 | ||||
272 | bool TensorType::hasRank() const { return !isa<UnrankedTensorType>(); } | |||
273 | ||||
274 | ArrayRef<int64_t> TensorType::getShape() const { | |||
275 | return cast<RankedTensorType>().getShape(); | |||
276 | } | |||
277 | ||||
278 | TensorType TensorType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
279 | Type elementType) const { | |||
280 | if (auto unrankedTy = dyn_cast<UnrankedTensorType>()) { | |||
281 | if (shape) | |||
282 | return RankedTensorType::get(*shape, elementType); | |||
283 | return UnrankedTensorType::get(elementType); | |||
284 | } | |||
285 | ||||
286 | auto rankedTy = cast<RankedTensorType>(); | |||
287 | if (!shape) | |||
288 | return RankedTensorType::get(rankedTy.getShape(), elementType, | |||
289 | rankedTy.getEncoding()); | |||
290 | return RankedTensorType::get(shape.value_or(rankedTy.getShape()), elementType, | |||
291 | rankedTy.getEncoding()); | |||
292 | } | |||
293 | ||||
294 | // Check if "elementType" can be an element type of a tensor. | |||
295 | static LogicalResult | |||
296 | checkTensorElementType(function_ref<InFlightDiagnostic()> emitError, | |||
297 | Type elementType) { | |||
298 | if (!TensorType::isValidElementType(elementType)) | |||
299 | return emitError() << "invalid tensor element type: " << elementType; | |||
300 | return success(); | |||
301 | } | |||
302 | ||||
303 | /// Return true if the specified element type is ok in a tensor. | |||
304 | bool TensorType::isValidElementType(Type type) { | |||
305 | // Note: Non standard/builtin types are allowed to exist within tensor | |||
306 | // types. Dialects are expected to verify that tensor types have a valid | |||
307 | // element type within that dialect. | |||
308 | return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType, | |||
309 | IndexType>() || | |||
310 | !llvm::isa<BuiltinDialect>(type.getDialect()); | |||
311 | } | |||
312 | ||||
313 | //===----------------------------------------------------------------------===// | |||
314 | // RankedTensorType | |||
315 | //===----------------------------------------------------------------------===// | |||
316 | ||||
317 | LogicalResult | |||
318 | RankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
319 | ArrayRef<int64_t> shape, Type elementType, | |||
320 | Attribute encoding) { | |||
321 | for (int64_t s : shape) | |||
322 | if (s < 0 && !ShapedType::isDynamic(s)) | |||
323 | return emitError() << "invalid tensor dimension size"; | |||
324 | if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) | |||
325 | if (failed(v.verifyEncoding(shape, elementType, emitError))) | |||
326 | return failure(); | |||
327 | return checkTensorElementType(emitError, elementType); | |||
328 | } | |||
329 | ||||
330 | //===----------------------------------------------------------------------===// | |||
331 | // UnrankedTensorType | |||
332 | //===----------------------------------------------------------------------===// | |||
333 | ||||
334 | LogicalResult | |||
335 | UnrankedTensorType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
336 | Type elementType) { | |||
337 | return checkTensorElementType(emitError, elementType); | |||
338 | } | |||
339 | ||||
340 | //===----------------------------------------------------------------------===// | |||
341 | // BaseMemRefType | |||
342 | //===----------------------------------------------------------------------===// | |||
343 | ||||
344 | Type BaseMemRefType::getElementType() const { | |||
345 | return llvm::TypeSwitch<BaseMemRefType, Type>(*this) | |||
346 | .Case<MemRefType, UnrankedMemRefType>( | |||
347 | [](auto type) { return type.getElementType(); }); | |||
348 | } | |||
349 | ||||
350 | bool BaseMemRefType::hasRank() const { return !isa<UnrankedMemRefType>(); } | |||
351 | ||||
352 | ArrayRef<int64_t> BaseMemRefType::getShape() const { | |||
353 | return cast<MemRefType>().getShape(); | |||
354 | } | |||
355 | ||||
356 | BaseMemRefType BaseMemRefType::cloneWith(std::optional<ArrayRef<int64_t>> shape, | |||
357 | Type elementType) const { | |||
358 | if (auto unrankedTy = dyn_cast<UnrankedMemRefType>()) { | |||
359 | if (!shape) | |||
360 | return UnrankedMemRefType::get(elementType, getMemorySpace()); | |||
361 | MemRefType::Builder builder(*shape, elementType); | |||
362 | builder.setMemorySpace(getMemorySpace()); | |||
363 | return builder; | |||
364 | } | |||
365 | ||||
366 | MemRefType::Builder builder(cast<MemRefType>()); | |||
367 | if (shape) | |||
368 | builder.setShape(*shape); | |||
369 | builder.setElementType(elementType); | |||
370 | return builder; | |||
371 | } | |||
372 | ||||
373 | Attribute BaseMemRefType::getMemorySpace() const { | |||
374 | if (auto rankedMemRefTy = dyn_cast<MemRefType>()) | |||
375 | return rankedMemRefTy.getMemorySpace(); | |||
376 | return cast<UnrankedMemRefType>().getMemorySpace(); | |||
377 | } | |||
378 | ||||
379 | unsigned BaseMemRefType::getMemorySpaceAsInt() const { | |||
380 | if (auto rankedMemRefTy = dyn_cast<MemRefType>()) | |||
381 | return rankedMemRefTy.getMemorySpaceAsInt(); | |||
382 | return cast<UnrankedMemRefType>().getMemorySpaceAsInt(); | |||
383 | } | |||
384 | ||||
385 | //===----------------------------------------------------------------------===// | |||
386 | // MemRefType | |||
387 | //===----------------------------------------------------------------------===// | |||
388 | ||||
389 | /// Given an `originalShape` and a `reducedShape` assumed to be a subset of | |||
390 | /// `originalShape` with some `1` entries erased, return the set of indices | |||
391 | /// that specifies which of the entries of `originalShape` are dropped to obtain | |||
392 | /// `reducedShape`. The returned mask can be applied as a projection to | |||
393 | /// `originalShape` to obtain the `reducedShape`. This mask is useful to track | |||
394 | /// which dimensions must be kept when e.g. compute MemRef strides under | |||
395 | /// rank-reducing operations. Return std::nullopt if reducedShape cannot be | |||
396 | /// obtained by dropping only `1` entries in `originalShape`. | |||
397 | std::optional<llvm::SmallDenseSet<unsigned>> | |||
398 | mlir::computeRankReductionMask(ArrayRef<int64_t> originalShape, | |||
399 | ArrayRef<int64_t> reducedShape) { | |||
400 | size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); | |||
401 | llvm::SmallDenseSet<unsigned> unusedDims; | |||
402 | unsigned reducedIdx = 0; | |||
403 | for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { | |||
404 | // Greedily insert `originalIdx` if match. | |||
405 | if (reducedIdx < reducedRank && | |||
406 | originalShape[originalIdx] == reducedShape[reducedIdx]) { | |||
407 | reducedIdx++; | |||
408 | continue; | |||
409 | } | |||
410 | ||||
411 | unusedDims.insert(originalIdx); | |||
412 | // If no match on `originalIdx`, the `originalShape` at this dimension | |||
413 | // must be 1, otherwise we bail. | |||
414 | if (originalShape[originalIdx] != 1) | |||
415 | return std::nullopt; | |||
416 | } | |||
417 | // The whole reducedShape must be scanned, otherwise we bail. | |||
418 | if (reducedIdx != reducedRank) | |||
419 | return std::nullopt; | |||
420 | return unusedDims; | |||
421 | } | |||
422 | ||||
423 | SliceVerificationResult | |||
424 | mlir::isRankReducedType(ShapedType originalType, | |||
425 | ShapedType candidateReducedType) { | |||
426 | if (originalType == candidateReducedType) | |||
427 | return SliceVerificationResult::Success; | |||
428 | ||||
429 | ShapedType originalShapedType = originalType.cast<ShapedType>(); | |||
430 | ShapedType candidateReducedShapedType = | |||
431 | candidateReducedType.cast<ShapedType>(); | |||
432 | ||||
433 | // Rank and size logic is valid for all ShapedTypes. | |||
434 | ArrayRef<int64_t> originalShape = originalShapedType.getShape(); | |||
435 | ArrayRef<int64_t> candidateReducedShape = | |||
436 | candidateReducedShapedType.getShape(); | |||
437 | unsigned originalRank = originalShape.size(), | |||
438 | candidateReducedRank = candidateReducedShape.size(); | |||
439 | if (candidateReducedRank > originalRank) | |||
440 | return SliceVerificationResult::RankTooLarge; | |||
441 | ||||
442 | auto optionalUnusedDimsMask = | |||
443 | computeRankReductionMask(originalShape, candidateReducedShape); | |||
444 | ||||
445 | // Sizes cannot be matched in case empty vector is returned. | |||
446 | if (!optionalUnusedDimsMask) | |||
447 | return SliceVerificationResult::SizeMismatch; | |||
448 | ||||
449 | if (originalShapedType.getElementType() != | |||
450 | candidateReducedShapedType.getElementType()) | |||
451 | return SliceVerificationResult::ElemTypeMismatch; | |||
452 | ||||
453 | return SliceVerificationResult::Success; | |||
454 | } | |||
455 | ||||
456 | bool mlir::detail::isSupportedMemorySpace(Attribute memorySpace) { | |||
457 | // Empty attribute is allowed as default memory space. | |||
458 | if (!memorySpace) | |||
459 | return true; | |||
460 | ||||
461 | // Supported built-in attributes. | |||
462 | if (memorySpace.isa<IntegerAttr, StringAttr, DictionaryAttr>()) | |||
463 | return true; | |||
464 | ||||
465 | // Allow custom dialect attributes. | |||
466 | if (!isa<BuiltinDialect>(memorySpace.getDialect())) | |||
467 | return true; | |||
468 | ||||
469 | return false; | |||
470 | } | |||
471 | ||||
472 | Attribute mlir::detail::wrapIntegerMemorySpace(unsigned memorySpace, | |||
473 | MLIRContext *ctx) { | |||
474 | if (memorySpace == 0) | |||
475 | return nullptr; | |||
476 | ||||
477 | return IntegerAttr::get(IntegerType::get(ctx, 64), memorySpace); | |||
478 | } | |||
479 | ||||
480 | Attribute mlir::detail::skipDefaultMemorySpace(Attribute memorySpace) { | |||
481 | IntegerAttr intMemorySpace = memorySpace.dyn_cast_or_null<IntegerAttr>(); | |||
482 | if (intMemorySpace && intMemorySpace.getValue() == 0) | |||
483 | return nullptr; | |||
484 | ||||
485 | return memorySpace; | |||
486 | } | |||
487 | ||||
488 | unsigned mlir::detail::getMemorySpaceAsInt(Attribute memorySpace) { | |||
489 | if (!memorySpace) | |||
490 | return 0; | |||
491 | ||||
492 | assert(memorySpace.isa<IntegerAttr>() &&(static_cast <bool> (memorySpace.isa<IntegerAttr> () && "Using `getMemorySpaceInteger` with non-Integer attribute" ) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\"" , "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__ )) | |||
493 | "Using `getMemorySpaceInteger` with non-Integer attribute")(static_cast <bool> (memorySpace.isa<IntegerAttr> () && "Using `getMemorySpaceInteger` with non-Integer attribute" ) ? void (0) : __assert_fail ("memorySpace.isa<IntegerAttr>() && \"Using `getMemorySpaceInteger` with non-Integer attribute\"" , "mlir/lib/IR/BuiltinTypes.cpp", 493, __extension__ __PRETTY_FUNCTION__ )); | |||
494 | ||||
495 | return static_cast<unsigned>(memorySpace.cast<IntegerAttr>().getInt()); | |||
496 | } | |||
497 | ||||
498 | unsigned MemRefType::getMemorySpaceAsInt() const { | |||
499 | return detail::getMemorySpaceAsInt(getMemorySpace()); | |||
500 | } | |||
501 | ||||
502 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
503 | MemRefLayoutAttrInterface layout, | |||
504 | Attribute memorySpace) { | |||
505 | // Use default layout for empty attribute. | |||
506 | if (!layout) | |||
507 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( | |||
508 | shape.size(), elementType.getContext())); | |||
509 | ||||
510 | // Drop default memory space value and replace it with empty attribute. | |||
511 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
512 | ||||
513 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
514 | memorySpace); | |||
515 | } | |||
516 | ||||
517 | MemRefType MemRefType::getChecked( | |||
518 | function_ref<InFlightDiagnostic()> emitErrorFn, ArrayRef<int64_t> shape, | |||
519 | Type elementType, MemRefLayoutAttrInterface layout, Attribute memorySpace) { | |||
520 | ||||
521 | // Use default layout for empty attribute. | |||
522 | if (!layout) | |||
523 | layout = AffineMapAttr::get(AffineMap::getMultiDimIdentityMap( | |||
524 | shape.size(), elementType.getContext())); | |||
525 | ||||
526 | // Drop default memory space value and replace it with empty attribute. | |||
527 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
528 | ||||
529 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
530 | elementType, layout, memorySpace); | |||
531 | } | |||
532 | ||||
533 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
534 | AffineMap map, Attribute memorySpace) { | |||
535 | ||||
536 | // Use default layout for empty map. | |||
537 | if (!map) | |||
538 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
539 | elementType.getContext()); | |||
540 | ||||
541 | // Wrap AffineMap into Attribute. | |||
542 | Attribute layout = AffineMapAttr::get(map); | |||
543 | ||||
544 | // Drop default memory space value and replace it with empty attribute. | |||
545 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
546 | ||||
547 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
548 | memorySpace); | |||
549 | } | |||
550 | ||||
551 | MemRefType | |||
552 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, | |||
553 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, | |||
554 | Attribute memorySpace) { | |||
555 | ||||
556 | // Use default layout for empty map. | |||
557 | if (!map) | |||
558 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
559 | elementType.getContext()); | |||
560 | ||||
561 | // Wrap AffineMap into Attribute. | |||
562 | Attribute layout = AffineMapAttr::get(map); | |||
563 | ||||
564 | // Drop default memory space value and replace it with empty attribute. | |||
565 | memorySpace = skipDefaultMemorySpace(memorySpace); | |||
566 | ||||
567 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
568 | elementType, layout, memorySpace); | |||
569 | } | |||
570 | ||||
571 | MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType, | |||
572 | AffineMap map, unsigned memorySpaceInd) { | |||
573 | ||||
574 | // Use default layout for empty map. | |||
575 | if (!map) | |||
576 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
577 | elementType.getContext()); | |||
578 | ||||
579 | // Wrap AffineMap into Attribute. | |||
580 | Attribute layout = AffineMapAttr::get(map); | |||
581 | ||||
582 | // Convert deprecated integer-like memory space to Attribute. | |||
583 | Attribute memorySpace = | |||
584 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); | |||
585 | ||||
586 | return Base::get(elementType.getContext(), shape, elementType, layout, | |||
587 | memorySpace); | |||
588 | } | |||
589 | ||||
590 | MemRefType | |||
591 | MemRefType::getChecked(function_ref<InFlightDiagnostic()> emitErrorFn, | |||
592 | ArrayRef<int64_t> shape, Type elementType, AffineMap map, | |||
593 | unsigned memorySpaceInd) { | |||
594 | ||||
595 | // Use default layout for empty map. | |||
596 | if (!map) | |||
597 | map = AffineMap::getMultiDimIdentityMap(shape.size(), | |||
598 | elementType.getContext()); | |||
599 | ||||
600 | // Wrap AffineMap into Attribute. | |||
601 | Attribute layout = AffineMapAttr::get(map); | |||
602 | ||||
603 | // Convert deprecated integer-like memory space to Attribute. | |||
604 | Attribute memorySpace = | |||
605 | wrapIntegerMemorySpace(memorySpaceInd, elementType.getContext()); | |||
606 | ||||
607 | return Base::getChecked(emitErrorFn, elementType.getContext(), shape, | |||
608 | elementType, layout, memorySpace); | |||
609 | } | |||
610 | ||||
611 | LogicalResult MemRefType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
612 | ArrayRef<int64_t> shape, Type elementType, | |||
613 | MemRefLayoutAttrInterface layout, | |||
614 | Attribute memorySpace) { | |||
615 | if (!BaseMemRefType::isValidElementType(elementType)) | |||
616 | return emitError() << "invalid memref element type"; | |||
617 | ||||
618 | // Negative sizes are not allowed except for `kDynamic`. | |||
619 | for (int64_t s : shape) | |||
620 | if (s < 0 && !ShapedType::isDynamic(s)) | |||
621 | return emitError() << "invalid memref size"; | |||
622 | ||||
623 | assert(layout && "missing layout specification")(static_cast <bool> (layout && "missing layout specification" ) ? void (0) : __assert_fail ("layout && \"missing layout specification\"" , "mlir/lib/IR/BuiltinTypes.cpp", 623, __extension__ __PRETTY_FUNCTION__ )); | |||
624 | if (failed(layout.verifyLayout(shape, emitError))) | |||
625 | return failure(); | |||
626 | ||||
627 | if (!isSupportedMemorySpace(memorySpace)) | |||
628 | return emitError() << "unsupported memory space Attribute"; | |||
629 | ||||
630 | return success(); | |||
631 | } | |||
632 | ||||
633 | //===----------------------------------------------------------------------===// | |||
634 | // UnrankedMemRefType | |||
635 | //===----------------------------------------------------------------------===// | |||
636 | ||||
637 | unsigned UnrankedMemRefType::getMemorySpaceAsInt() const { | |||
638 | return detail::getMemorySpaceAsInt(getMemorySpace()); | |||
639 | } | |||
640 | ||||
641 | LogicalResult | |||
642 | UnrankedMemRefType::verify(function_ref<InFlightDiagnostic()> emitError, | |||
643 | Type elementType, Attribute memorySpace) { | |||
644 | if (!BaseMemRefType::isValidElementType(elementType)) | |||
645 | return emitError() << "invalid memref element type"; | |||
646 | ||||
647 | if (!isSupportedMemorySpace(memorySpace)) | |||
648 | return emitError() << "unsupported memory space Attribute"; | |||
649 | ||||
650 | return success(); | |||
651 | } | |||
652 | ||||
653 | // Fallback cases for terminal dim/sym/cst that are not part of a binary op ( | |||
654 | // i.e. single term). Accumulate the AffineExpr into the existing one. | |||
655 | static void extractStridesFromTerm(AffineExpr e, | |||
656 | AffineExpr multiplicativeFactor, | |||
657 | MutableArrayRef<AffineExpr> strides, | |||
658 | AffineExpr &offset) { | |||
659 | if (auto dim = e.dyn_cast<AffineDimExpr>()) | |||
660 | strides[dim.getPosition()] = | |||
661 | strides[dim.getPosition()] + multiplicativeFactor; | |||
662 | else | |||
663 | offset = offset + e * multiplicativeFactor; | |||
664 | } | |||
665 | ||||
666 | /// Takes a single AffineExpr `e` and populates the `strides` array with the | |||
667 | /// strides expressions for each dim position. | |||
668 | /// The convention is that the strides for dimensions d0, .. dn appear in | |||
669 | /// order to make indexing intuitive into the result. | |||
670 | static LogicalResult extractStrides(AffineExpr e, | |||
671 | AffineExpr multiplicativeFactor, | |||
672 | MutableArrayRef<AffineExpr> strides, | |||
673 | AffineExpr &offset) { | |||
674 | auto bin = e.dyn_cast<AffineBinaryOpExpr>(); | |||
675 | if (!bin) { | |||
676 | extractStridesFromTerm(e, multiplicativeFactor, strides, offset); | |||
677 | return success(); | |||
678 | } | |||
679 | ||||
680 | if (bin.getKind() == AffineExprKind::CeilDiv || | |||
681 | bin.getKind() == AffineExprKind::FloorDiv || | |||
682 | bin.getKind() == AffineExprKind::Mod) | |||
683 | return failure(); | |||
684 | ||||
685 | if (bin.getKind() == AffineExprKind::Mul) { | |||
686 | auto dim = bin.getLHS().dyn_cast<AffineDimExpr>(); | |||
687 | if (dim) { | |||
688 | strides[dim.getPosition()] = | |||
689 | strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor; | |||
690 | return success(); | |||
691 | } | |||
692 | // LHS and RHS may both contain complex expressions of dims. Try one path | |||
693 | // and if it fails try the other. This is guaranteed to succeed because | |||
694 | // only one path may have a `dim`, otherwise this is not an AffineExpr in | |||
695 | // the first place. | |||
696 | if (bin.getLHS().isSymbolicOrConstant()) | |||
697 | return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(), | |||
698 | strides, offset); | |||
699 | return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(), | |||
700 | strides, offset); | |||
701 | } | |||
702 | ||||
703 | if (bin.getKind() == AffineExprKind::Add) { | |||
704 | auto res1 = | |||
705 | extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset); | |||
706 | auto res2 = | |||
707 | extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset); | |||
708 | return success(succeeded(res1) && succeeded(res2)); | |||
709 | } | |||
710 | ||||
711 | llvm_unreachable("unexpected binary operation")::llvm::llvm_unreachable_internal("unexpected binary operation" , "mlir/lib/IR/BuiltinTypes.cpp", 711); | |||
712 | } | |||
713 | ||||
714 | /// A stride specification is a list of integer values that are either static | |||
715 | /// or dynamic (encoded with ShapedType::kDynamic). Strides encode | |||
716 | /// the distance in the number of elements between successive entries along a | |||
717 | /// particular dimension. | |||
718 | /// | |||
719 | /// For example, `memref<42x16xf32, (64 * d0 + d1)>` specifies a view into a | |||
720 | /// non-contiguous memory region of `42` by `16` `f32` elements in which the | |||
721 | /// distance between two consecutive elements along the outer dimension is `1` | |||
722 | /// and the distance between two consecutive elements along the inner dimension | |||
723 | /// is `64`. | |||
724 | /// | |||
725 | /// The convention is that the strides for dimensions d0, .. dn appear in | |||
726 | /// order to make indexing intuitive into the result. | |||
727 | static LogicalResult getStridesAndOffset(MemRefType t, | |||
728 | SmallVectorImpl<AffineExpr> &strides, | |||
729 | AffineExpr &offset) { | |||
730 | AffineMap m = t.getLayout().getAffineMap(); | |||
731 | ||||
732 | if (m.getNumResults() != 1 && !m.isIdentity()) | |||
733 | return failure(); | |||
734 | ||||
735 | auto zero = getAffineConstantExpr(0, t.getContext()); | |||
736 | auto one = getAffineConstantExpr(1, t.getContext()); | |||
737 | offset = zero; | |||
738 | strides.assign(t.getRank(), zero); | |||
739 | ||||
740 | // Canonical case for empty map. | |||
741 | if (m.isIdentity()) { | |||
742 | // 0-D corner case, offset is already 0. | |||
743 | if (t.getRank() == 0) | |||
744 | return success(); | |||
745 | auto stridedExpr = | |||
746 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); | |||
747 | if (succeeded(extractStrides(stridedExpr, one, strides, offset))) | |||
748 | return success(); | |||
749 | assert(false && "unexpected failure: extract strides in canonical layout")(static_cast <bool> (false && "unexpected failure: extract strides in canonical layout" ) ? void (0) : __assert_fail ("false && \"unexpected failure: extract strides in canonical layout\"" , "mlir/lib/IR/BuiltinTypes.cpp", 749, __extension__ __PRETTY_FUNCTION__ )); | |||
750 | } | |||
751 | ||||
752 | // Non-canonical case requires more work. | |||
753 | auto stridedExpr = | |||
754 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); | |||
755 | if (failed(extractStrides(stridedExpr, one, strides, offset))) { | |||
756 | offset = AffineExpr(); | |||
757 | strides.clear(); | |||
758 | return failure(); | |||
759 | } | |||
760 | ||||
761 | // Simplify results to allow folding to constants and simple checks. | |||
762 | unsigned numDims = m.getNumDims(); | |||
763 | unsigned numSymbols = m.getNumSymbols(); | |||
764 | offset = simplifyAffineExpr(offset, numDims, numSymbols); | |||
765 | for (auto &stride : strides) | |||
766 | stride = simplifyAffineExpr(stride, numDims, numSymbols); | |||
767 | ||||
768 | // In practice, a strided memref must be internally non-aliasing. Test | |||
769 | // against 0 as a proxy. | |||
770 | // TODO: static cases can have more advanced checks. | |||
771 | // TODO: dynamic cases would require a way to compare symbolic | |||
772 | // expressions and would probably need an affine set context propagated | |||
773 | // everywhere. | |||
774 | if (llvm::any_of(strides, [](AffineExpr e) { | |||
775 | return e == getAffineConstantExpr(0, e.getContext()); | |||
776 | })) { | |||
777 | offset = AffineExpr(); | |||
778 | strides.clear(); | |||
779 | return failure(); | |||
780 | } | |||
781 | ||||
782 | return success(); | |||
783 | } | |||
784 | ||||
785 | LogicalResult mlir::getStridesAndOffset(MemRefType t, | |||
786 | SmallVectorImpl<int64_t> &strides, | |||
787 | int64_t &offset) { | |||
788 | // Happy path: the type uses the strided layout directly. | |||
789 | if (auto strided = t.getLayout().dyn_cast<StridedLayoutAttr>()) { | |||
790 | llvm::append_range(strides, strided.getStrides()); | |||
791 | offset = strided.getOffset(); | |||
792 | return success(); | |||
793 | } | |||
794 | ||||
795 | // Otherwise, defer to the affine fallback as layouts are supposed to be | |||
796 | // convertible to affine maps. | |||
797 | AffineExpr offsetExpr; | |||
798 | SmallVector<AffineExpr, 4> strideExprs; | |||
799 | if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr))) | |||
800 | return failure(); | |||
801 | if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>()) | |||
802 | offset = cst.getValue(); | |||
803 | else | |||
804 | offset = ShapedType::kDynamic; | |||
805 | for (auto e : strideExprs) { | |||
806 | if (auto c = e.dyn_cast<AffineConstantExpr>()) | |||
807 | strides.push_back(c.getValue()); | |||
808 | else | |||
809 | strides.push_back(ShapedType::kDynamic); | |||
810 | } | |||
811 | return success(); | |||
812 | } | |||
813 | ||||
814 | std::pair<SmallVector<int64_t>, int64_t> | |||
815 | mlir::getStridesAndOffset(MemRefType t) { | |||
816 | SmallVector<int64_t> strides; | |||
817 | int64_t offset; | |||
818 | LogicalResult status = getStridesAndOffset(t, strides, offset); | |||
819 | (void)status; | |||
820 | assert(succeeded(status) && "Invalid use of check-free getStridesAndOffset")(static_cast <bool> (succeeded(status) && "Invalid use of check-free getStridesAndOffset" ) ? void (0) : __assert_fail ("succeeded(status) && \"Invalid use of check-free getStridesAndOffset\"" , "mlir/lib/IR/BuiltinTypes.cpp", 820, __extension__ __PRETTY_FUNCTION__ )); | |||
821 | return {strides, offset}; | |||
822 | } | |||
823 | ||||
824 | //===----------------------------------------------------------------------===// | |||
825 | /// TupleType | |||
826 | //===----------------------------------------------------------------------===// | |||
827 | ||||
828 | /// Return the elements types for this tuple. | |||
829 | ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); } | |||
830 | ||||
831 | /// Accumulate the types contained in this tuple and tuples nested within it. | |||
832 | /// Note that this only flattens nested tuples, not any other container type, | |||
833 | /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to | |||
834 | /// (i32, tensor<i32>, f32, i64) | |||
835 | void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) { | |||
836 | for (Type type : getTypes()) { | |||
837 | if (auto nestedTuple = type.dyn_cast<TupleType>()) | |||
838 | nestedTuple.getFlattenedTypes(types); | |||
839 | else | |||
840 | types.push_back(type); | |||
841 | } | |||
842 | } | |||
843 | ||||
844 | /// Return the number of element types. | |||
845 | size_t TupleType::size() const { return getImpl()->size(); } | |||
846 | ||||
847 | //===----------------------------------------------------------------------===// | |||
848 | // Type Utilities | |||
849 | //===----------------------------------------------------------------------===// | |||
850 | ||||
851 | /// Return a version of `t` with identity layout if it can be determined | |||
852 | /// statically that the layout is the canonical contiguous strided layout. | |||
853 | /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of | |||
854 | /// `t` with simplified layout. | |||
855 | /// If `t` has multiple layout maps or a multi-result layout, just return `t`. | |||
856 | MemRefType mlir::canonicalizeStridedLayout(MemRefType t) { | |||
857 | AffineMap m = t.getLayout().getAffineMap(); | |||
858 | ||||
859 | // Already in canonical form. | |||
860 | if (m.isIdentity()) | |||
861 | return t; | |||
862 | ||||
863 | // Can't reduce to canonical identity form, return in canonical form. | |||
864 | if (m.getNumResults() > 1) | |||
865 | return t; | |||
866 | ||||
867 | // Corner-case for 0-D affine maps. | |||
868 | if (m.getNumDims() == 0 && m.getNumSymbols() == 0) { | |||
869 | if (auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>()) | |||
870 | if (cst.getValue() == 0) | |||
871 | return MemRefType::Builder(t).setLayout({}); | |||
872 | return t; | |||
873 | } | |||
874 | ||||
875 | // 0-D corner case for empty shape that still have an affine map. Example: | |||
876 | // `memref<f32, affine_map<()[s0] -> (s0)>>`. This is a 1 element memref whose | |||
877 | // offset needs to remain, just return t. | |||
878 | if (t.getShape().empty()) | |||
879 | return t; | |||
880 | ||||
881 | // If the canonical strided layout for the sizes of `t` is equal to the | |||
882 | // simplified layout of `t` we can just return an empty layout. Otherwise, | |||
883 | // just simplify the existing layout. | |||
884 | AffineExpr expr = | |||
885 | makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext()); | |||
886 | auto simplifiedLayoutExpr = | |||
887 | simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols()); | |||
888 | if (expr != simplifiedLayoutExpr) | |||
889 | return MemRefType::Builder(t).setLayout(AffineMapAttr::get(AffineMap::get( | |||
890 | m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr))); | |||
891 | return MemRefType::Builder(t).setLayout({}); | |||
892 | } | |||
893 | ||||
894 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, | |||
895 | ArrayRef<AffineExpr> exprs, | |||
896 | MLIRContext *context) { | |||
897 | // Size 0 corner case is useful for canonicalizations. | |||
898 | if (sizes.empty()) | |||
899 | return getAffineConstantExpr(0, context); | |||
900 | ||||
901 | assert(!exprs.empty() && "expected exprs")(static_cast <bool> (!exprs.empty() && "expected exprs" ) ? void (0) : __assert_fail ("!exprs.empty() && \"expected exprs\"" , "mlir/lib/IR/BuiltinTypes.cpp", 901, __extension__ __PRETTY_FUNCTION__ )); | |||
902 | auto maps = AffineMap::inferFromExprList(exprs); | |||
903 | assert(!maps.empty() && "Expected one non-empty map")(static_cast <bool> (!maps.empty() && "Expected one non-empty map" ) ? void (0) : __assert_fail ("!maps.empty() && \"Expected one non-empty map\"" , "mlir/lib/IR/BuiltinTypes.cpp", 903, __extension__ __PRETTY_FUNCTION__ )); | |||
904 | unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols(); | |||
905 | ||||
906 | AffineExpr expr; | |||
907 | bool dynamicPoisonBit = false; | |||
908 | int64_t runningSize = 1; | |||
909 | for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) { | |||
910 | int64_t size = std::get<1>(en); | |||
911 | AffineExpr dimExpr = std::get<0>(en); | |||
912 | AffineExpr stride = dynamicPoisonBit | |||
913 | ? getAffineSymbolExpr(nSymbols++, context) | |||
914 | : getAffineConstantExpr(runningSize, context); | |||
915 | expr = expr ? expr + dimExpr * stride : dimExpr * stride; | |||
916 | if (size > 0) { | |||
917 | runningSize *= size; | |||
918 | assert(runningSize > 0 && "integer overflow in size computation")(static_cast <bool> (runningSize > 0 && "integer overflow in size computation" ) ? void (0) : __assert_fail ("runningSize > 0 && \"integer overflow in size computation\"" , "mlir/lib/IR/BuiltinTypes.cpp", 918, __extension__ __PRETTY_FUNCTION__ )); | |||
919 | } else { | |||
920 | dynamicPoisonBit = true; | |||
921 | } | |||
922 | } | |||
923 | return simplifyAffineExpr(expr, numDims, nSymbols); | |||
924 | } | |||
925 | ||||
926 | AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes, | |||
927 | MLIRContext *context) { | |||
928 | SmallVector<AffineExpr, 4> exprs; | |||
929 | exprs.reserve(sizes.size()); | |||
930 | for (auto dim : llvm::seq<unsigned>(0, sizes.size())) | |||
931 | exprs.push_back(getAffineDimExpr(dim, context)); | |||
932 | return makeCanonicalStridedLayoutExpr(sizes, exprs, context); | |||
933 | } | |||
934 | ||||
935 | /// Return true if the layout for `t` is compatible with strided semantics. | |||
936 | bool mlir::isStrided(MemRefType t) { | |||
937 | int64_t offset; | |||
938 | SmallVector<int64_t, 4> strides; | |||
939 | auto res = getStridesAndOffset(t, strides, offset); | |||
940 | return succeeded(res); | |||
941 | } |
1 | //===- Dialect.h - IR Dialect Description -----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the 'dialect' abstraction. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_DIALECT_H |
14 | #define MLIR_IR_DIALECT_H |
15 | |
16 | #include "mlir/IR/DialectRegistry.h" |
17 | #include "mlir/IR/OperationSupport.h" |
18 | #include "mlir/Support/TypeID.h" |
19 | |
20 | #include <map> |
21 | #include <tuple> |
22 | |
23 | namespace mlir { |
24 | class DialectAsmParser; |
25 | class DialectAsmPrinter; |
26 | class DialectInterface; |
27 | class OpBuilder; |
28 | class Type; |
29 | |
30 | //===----------------------------------------------------------------------===// |
31 | // Dialect |
32 | //===----------------------------------------------------------------------===// |
33 | |
34 | /// Dialects are groups of MLIR operations, types and attributes, as well as |
35 | /// behavior associated with the entire group. For example, hooks into other |
36 | /// systems for constant folding, interfaces, default named types for asm |
37 | /// printing, etc. |
38 | /// |
39 | /// Instances of the dialect object are loaded in a specific MLIRContext. |
40 | /// |
41 | class Dialect { |
42 | public: |
43 | /// Type for a callback provided by the dialect to parse a custom operation. |
44 | /// This is used for the dialect to provide an alternative way to parse custom |
45 | /// operations, including unregistered ones. |
46 | using ParseOpHook = |
47 | function_ref<ParseResult(OpAsmParser &parser, OperationState &result)>; |
48 | |
49 | virtual ~Dialect(); |
50 | |
51 | /// Utility function that returns if the given string is a valid dialect |
52 | /// namespace |
53 | static bool isValidNamespace(StringRef str); |
54 | |
55 | MLIRContext *getContext() const { return context; } |
56 | |
57 | StringRef getNamespace() const { return name; } |
58 | |
59 | /// Returns the unique identifier that corresponds to this dialect. |
60 | TypeID getTypeID() const { return dialectID; } |
61 | |
62 | /// Returns true if this dialect allows for unregistered operations, i.e. |
63 | /// operations prefixed with the dialect namespace but not registered with |
64 | /// addOperation. |
65 | bool allowsUnknownOperations() const { return unknownOpsAllowed; } |
66 | |
67 | /// Return true if this dialect allows for unregistered types, i.e., types |
68 | /// prefixed with the dialect namespace but not registered with addType. |
69 | /// These are represented with OpaqueType. |
70 | bool allowsUnknownTypes() const { return unknownTypesAllowed; } |
71 | |
72 | /// Register dialect-wide canonicalization patterns. This method should only |
73 | /// be used to register canonicalization patterns that do not conceptually |
74 | /// belong to any single operation in the dialect. (In that case, use the op's |
75 | /// canonicalizer.) E.g., canonicalization patterns for op interfaces should |
76 | /// be registered here. |
77 | virtual void getCanonicalizationPatterns(RewritePatternSet &results) const {} |
78 | |
79 | /// Registered hook to materialize a single constant operation from a given |
80 | /// attribute value with the desired resultant type. This method should use |
81 | /// the provided builder to create the operation without changing the |
82 | /// insertion position. The generated operation is expected to be constant |
83 | /// like, i.e. single result, zero operands, non side-effecting, etc. On |
84 | /// success, this hook should return the value generated to represent the |
85 | /// constant value. Otherwise, it should return null on failure. |
86 | virtual Operation *materializeConstant(OpBuilder &builder, Attribute value, |
87 | Type type, Location loc) { |
88 | return nullptr; |
89 | } |
90 | |
91 | //===--------------------------------------------------------------------===// |
92 | // Parsing Hooks |
93 | //===--------------------------------------------------------------------===// |
94 | |
95 | /// Parse an attribute registered to this dialect. If 'type' is nonnull, it |
96 | /// refers to the expected type of the attribute. |
97 | virtual Attribute parseAttribute(DialectAsmParser &parser, Type type) const; |
98 | |
99 | /// Print an attribute registered to this dialect. Note: The type of the |
100 | /// attribute need not be printed by this method as it is always printed by |
101 | /// the caller. |
102 | virtual void printAttribute(Attribute, DialectAsmPrinter &) const { |
103 | llvm_unreachable("dialect has no registered attribute printing hook")::llvm::llvm_unreachable_internal("dialect has no registered attribute printing hook" , "mlir/include/mlir/IR/Dialect.h", 103); |
104 | } |
105 | |
106 | /// Parse a type registered to this dialect. |
107 | virtual Type parseType(DialectAsmParser &parser) const; |
108 | |
109 | /// Print a type registered to this dialect. |
110 | virtual void printType(Type, DialectAsmPrinter &) const { |
111 | llvm_unreachable("dialect has no registered type printing hook")::llvm::llvm_unreachable_internal("dialect has no registered type printing hook" , "mlir/include/mlir/IR/Dialect.h", 111); |
112 | } |
113 | |
114 | /// Return the hook to parse an operation registered to this dialect, if any. |
115 | /// By default this will lookup for registered operations and return the |
116 | /// `parse()` method registered on the RegisteredOperationName. Dialects can |
117 | /// override this behavior and handle unregistered operations as well. |
118 | virtual std::optional<ParseOpHook> |
119 | getParseOperationHook(StringRef opName) const; |
120 | |
121 | /// Print an operation registered to this dialect. |
122 | /// This hook is invoked for registered operation which don't override the |
123 | /// `print()` method to define their own custom assembly. |
124 | virtual llvm::unique_function<void(Operation *, OpAsmPrinter &printer)> |
125 | getOperationPrinter(Operation *op) const; |
126 | |
127 | //===--------------------------------------------------------------------===// |
128 | // Verification Hooks |
129 | //===--------------------------------------------------------------------===// |
130 | |
131 | /// Verify an attribute from this dialect on the argument at 'argIndex' for |
132 | /// the region at 'regionIndex' on the given operation. Returns failure if |
133 | /// the verification failed, success otherwise. This hook may optionally be |
134 | /// invoked from any operation containing a region. |
135 | virtual LogicalResult verifyRegionArgAttribute(Operation *, |
136 | unsigned regionIndex, |
137 | unsigned argIndex, |
138 | NamedAttribute); |
139 | |
140 | /// Verify an attribute from this dialect on the result at 'resultIndex' for |
141 | /// the region at 'regionIndex' on the given operation. Returns failure if |
142 | /// the verification failed, success otherwise. This hook may optionally be |
143 | /// invoked from any operation containing a region. |
144 | virtual LogicalResult verifyRegionResultAttribute(Operation *, |
145 | unsigned regionIndex, |
146 | unsigned resultIndex, |
147 | NamedAttribute); |
148 | |
149 | /// Verify an attribute from this dialect on the given operation. Returns |
150 | /// failure if the verification failed, success otherwise. |
151 | virtual LogicalResult verifyOperationAttribute(Operation *, NamedAttribute) { |
152 | return success(); |
153 | } |
154 | |
155 | //===--------------------------------------------------------------------===// |
156 | // Interfaces |
157 | //===--------------------------------------------------------------------===// |
158 | |
159 | /// Lookup an interface for the given ID if one is registered, otherwise |
160 | /// nullptr. |
161 | DialectInterface *getRegisteredInterface(TypeID interfaceID) { |
162 | auto it = registeredInterfaces.find(interfaceID); |
163 | return it != registeredInterfaces.end() ? it->getSecond().get() : nullptr; |
164 | } |
165 | template <typename InterfaceT> |
166 | InterfaceT *getRegisteredInterface() { |
167 | return static_cast<InterfaceT *>( |
168 | getRegisteredInterface(InterfaceT::getInterfaceID())); |
169 | } |
170 | |
171 | /// Lookup an op interface for the given ID if one is registered, otherwise |
172 | /// nullptr. |
173 | virtual void *getRegisteredInterfaceForOp(TypeID interfaceID, |
174 | OperationName opName) { |
175 | return nullptr; |
176 | } |
177 | template <typename InterfaceT> |
178 | typename InterfaceT::Concept * |
179 | getRegisteredInterfaceForOp(OperationName opName) { |
180 | return static_cast<typename InterfaceT::Concept *>( |
181 | getRegisteredInterfaceForOp(InterfaceT::getInterfaceID(), opName)); |
182 | } |
183 | |
184 | /// Register a dialect interface with this dialect instance. |
185 | void addInterface(std::unique_ptr<DialectInterface> interface); |
186 | |
187 | /// Register a set of dialect interfaces with this dialect instance. |
188 | template <typename... Args> |
189 | void addInterfaces() { |
190 | (addInterface(std::make_unique<Args>(this)), ...); |
191 | } |
192 | template <typename InterfaceT, typename... Args> |
193 | InterfaceT &addInterface(Args &&...args) { |
194 | InterfaceT *interface = new InterfaceT(this, std::forward<Args>(args)...); |
195 | addInterface(std::unique_ptr<DialectInterface>(interface)); |
196 | return *interface; |
197 | } |
198 | |
199 | protected: |
200 | /// The constructor takes a unique namespace for this dialect as well as the |
201 | /// context to bind to. |
202 | /// Note: The namespace must not contain '.' characters. |
203 | /// Note: All operations belonging to this dialect must have names starting |
204 | /// with the namespace followed by '.'. |
205 | /// Example: |
206 | /// - "tf" for the TensorFlow ops like "tf.add". |
207 | Dialect(StringRef name, MLIRContext *context, TypeID id); |
208 | |
209 | /// This method is used by derived classes to add their operations to the set. |
210 | /// |
211 | template <typename... Args> |
212 | void addOperations() { |
213 | // This initializer_list argument pack expansion is essentially equal to |
214 | // using a fold expression with a comma operator. Clang however, refuses |
215 | // to compile a fold expression with a depth of more than 256 by default. |
216 | // There seem to be no such limitations for initializer_list. |
217 | (void)std::initializer_list<int>{ |
218 | 0, (RegisteredOperationName::insert<Args>(*this), 0)...}; |
219 | } |
220 | |
221 | /// Register a set of type classes with this dialect. |
222 | template <typename... Args> |
223 | void addTypes() { |
224 | (addType<Args>(), ...); |
225 | } |
226 | |
227 | /// Register a type instance with this dialect. |
228 | /// The use of this method is in general discouraged in favor of |
229 | /// 'addTypes<CustomType>()'. |
230 | void addType(TypeID typeID, AbstractType &&typeInfo); |
231 | |
232 | /// Register a set of attribute classes with this dialect. |
233 | template <typename... Args> |
234 | void addAttributes() { |
235 | (addAttribute<Args>(), ...); |
236 | } |
237 | |
238 | /// Register an attribute instance with this dialect. |
239 | /// The use of this method is in general discouraged in favor of |
240 | /// 'addAttributes<CustomAttr>()'. |
241 | void addAttribute(TypeID typeID, AbstractAttribute &&attrInfo); |
242 | |
243 | /// Enable support for unregistered operations. |
244 | void allowUnknownOperations(bool allow = true) { unknownOpsAllowed = allow; } |
245 | |
246 | /// Enable support for unregistered types. |
247 | void allowUnknownTypes(bool allow = true) { unknownTypesAllowed = allow; } |
248 | |
249 | private: |
250 | Dialect(const Dialect &) = delete; |
251 | void operator=(Dialect &) = delete; |
252 | |
253 | /// Register an attribute instance with this dialect. |
254 | template <typename T> |
255 | void addAttribute() { |
256 | // Add this attribute to the dialect and register it with the uniquer. |
257 | addAttribute(T::getTypeID(), AbstractAttribute::get<T>(*this)); |
258 | detail::AttributeUniquer::registerAttribute<T>(context); |
259 | } |
260 | |
261 | /// Register a type instance with this dialect. |
262 | template <typename T> |
263 | void addType() { |
264 | // Add this type to the dialect and register it with the uniquer. |
265 | addType(T::getTypeID(), AbstractType::get<T>(*this)); |
266 | detail::TypeUniquer::registerType<T>(context); |
267 | } |
268 | |
269 | /// The namespace of this dialect. |
270 | StringRef name; |
271 | |
272 | /// The unique identifier of the derived Op class, this is used in the context |
273 | /// to allow registering multiple times the same dialect. |
274 | TypeID dialectID; |
275 | |
276 | /// This is the context that owns this Dialect object. |
277 | MLIRContext *context; |
278 | |
279 | /// Flag that specifies whether this dialect supports unregistered operations, |
280 | /// i.e. operations prefixed with the dialect namespace but not registered |
281 | /// with addOperation. |
282 | bool unknownOpsAllowed = false; |
283 | |
284 | /// Flag that specifies whether this dialect allows unregistered types, i.e. |
285 | /// types prefixed with the dialect namespace but not registered with addType. |
286 | /// These types are represented with OpaqueType. |
287 | bool unknownTypesAllowed = false; |
288 | |
289 | /// A collection of registered dialect interfaces. |
290 | DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces; |
291 | |
292 | friend class DialectRegistry; |
293 | friend void registerDialect(); |
294 | friend class MLIRContext; |
295 | }; |
296 | |
297 | } // namespace mlir |
298 | |
namespace llvm {
/// Provide isa functionality for Dialects.
/// isa<T>(dialect) for a concrete dialect class T: true when the dialect's
/// TypeID is exactly T's TypeID.
template <typename T>
struct isa_impl<T, ::mlir::Dialect,
                std::enable_if_t<std::is_base_of<::mlir::Dialect, T>::value>> {
  static inline bool doit(const ::mlir::Dialect &dialect) {
    return mlir::TypeID::get<T>() == dialect.getTypeID();
  }
};
/// isa<T>(dialect) for a dialect interface T: true when the dialect has an
/// interface of type T registered (non-null lookup result).
template <typename T>
struct isa_impl<
    T, ::mlir::Dialect,
    std::enable_if_t<std::is_base_of<::mlir::DialectInterface, T>::value>> {
  static inline bool doit(const ::mlir::Dialect &dialect) {
    // getRegisteredInterface is non-const, hence the const_cast.
    return const_cast<::mlir::Dialect &>(dialect).getRegisteredInterface<T>();
  }
};
/// cast<T>(Dialect *) yields a T *.
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect *> {
  using ret_type = T *;
};
/// cast<T>(Dialect &) yields a T &.
template <typename T>
struct cast_retty_impl<T, ::mlir::Dialect> {
  using ret_type = T &;
};

/// Performs the actual conversion for cast<T> on a Dialect reference, picking
/// one of two doitImpl overloads via enable_if depending on T.
template <typename T>
struct cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect> {
  /// T derives from Dialect: a plain static_cast to the derived dialect.
  template <typename To>
  static std::enable_if_t<std::is_base_of<::mlir::Dialect, To>::value, To &>
  doitImpl(::mlir::Dialect &dialect) {
    return static_cast<To &>(dialect);
  }
  /// T derives from DialectInterface: return the registered interface. The
  /// lookup result is dereferenced without a null check; presumably llvm::cast
  /// has already asserted isa<T> — confirm against llvm/Support/Casting.h.
  template <typename To>
  static std::enable_if_t<std::is_base_of<::mlir::DialectInterface, To>::value,
                          To &>
  doitImpl(::mlir::Dialect &dialect) {
    return *dialect.getRegisteredInterface<To>();
  }

  static auto &doit(::mlir::Dialect &dialect) { return doitImpl<T>(dialect); }
};
/// Pointer form of the conversion: forwards to the reference form above and
/// returns the address of the result.
template <class T>
struct cast_convert_val<T, ::mlir::Dialect *, ::mlir::Dialect *> {
  static auto doit(::mlir::Dialect *dialect) {
    return &cast_convert_val<T, ::mlir::Dialect, ::mlir::Dialect>::doit(
        *dialect);
  }
};

} // namespace llvm
350 | |
351 | #endif |
1 | //===- TypeSupport.h --------------------------------------------*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file defines support types for registering dialect extended types. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
13 | #ifndef MLIR_IR_TYPESUPPORT_H | |||
14 | #define MLIR_IR_TYPESUPPORT_H | |||
15 | ||||
16 | #include "mlir/IR/MLIRContext.h" | |||
17 | #include "mlir/IR/StorageUniquerSupport.h" | |||
18 | #include "llvm/ADT/Twine.h" | |||
19 | ||||
20 | namespace mlir { | |||
21 | class Dialect; | |||
22 | class MLIRContext; | |||
23 | ||||
24 | //===----------------------------------------------------------------------===// | |||
25 | // AbstractType | |||
26 | //===----------------------------------------------------------------------===// | |||
27 | ||||
28 | /// This class contains all of the static information common to all instances of | |||
29 | /// a registered Type. | |||
30 | class AbstractType { | |||
31 | public: | |||
32 | using HasTraitFn = llvm::unique_function<bool(TypeID) const>; | |||
33 | using WalkImmediateSubElementsFn = function_ref<void( | |||
34 | Type, function_ref<void(Attribute)>, function_ref<void(Type)>)>; | |||
35 | using ReplaceImmediateSubElementsFn = | |||
36 | function_ref<Type(Type, ArrayRef<Attribute>, ArrayRef<Type>)>; | |||
37 | ||||
38 | /// Look up the specified abstract type in the MLIRContext and return a | |||
39 | /// reference to it. | |||
40 | static const AbstractType &lookup(TypeID typeID, MLIRContext *context); | |||
41 | ||||
42 | /// This method is used by Dialect objects when they register the list of | |||
43 | /// types they contain. | |||
44 | template <typename T> | |||
45 | static AbstractType get(Dialect &dialect) { | |||
46 | return AbstractType(dialect, T::getInterfaceMap(), T::getHasTraitFn(), | |||
| ||||
47 | T::getWalkImmediateSubElementsFn(), | |||
48 | T::getReplaceImmediateSubElementsFn(), T::getTypeID()); | |||
49 | } | |||
50 | ||||
51 | /// This method is used by Dialect objects to register types with | |||
52 | /// custom TypeIDs. | |||
53 | /// The use of this method is in general discouraged in favor of | |||
54 | /// 'get<CustomType>(dialect)'; | |||
55 | static AbstractType | |||
56 | get(Dialect &dialect, detail::InterfaceMap &&interfaceMap, | |||
57 | HasTraitFn &&hasTrait, | |||
58 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, | |||
59 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, | |||
60 | TypeID typeID) { | |||
61 | return AbstractType(dialect, std::move(interfaceMap), std::move(hasTrait), | |||
62 | walkImmediateSubElementsFn, | |||
63 | replaceImmediateSubElementsFn, typeID); | |||
64 | } | |||
65 | ||||
66 | /// Return the dialect this type was registered to. | |||
67 | Dialect &getDialect() const { return const_cast<Dialect &>(dialect); } | |||
68 | ||||
69 | /// Returns an instance of the concept object for the given interface if it | |||
70 | /// was registered to this type, null otherwise. This should not be used | |||
71 | /// directly. | |||
72 | template <typename T> | |||
73 | typename T::Concept *getInterface() const { | |||
74 | return interfaceMap.lookup<T>(); | |||
75 | } | |||
76 | ||||
77 | /// Returns true if the type has the interface with the given ID. | |||
78 | bool hasInterface(TypeID interfaceID) const { | |||
79 | return interfaceMap.contains(interfaceID); | |||
80 | } | |||
81 | ||||
82 | /// Returns true if the type has a particular trait. | |||
83 | template <template <typename T> class Trait> | |||
84 | bool hasTrait() const { | |||
85 | return hasTraitFn(TypeID::get<Trait>()); | |||
86 | } | |||
87 | ||||
88 | /// Returns true if the type has a particular trait. | |||
89 | bool hasTrait(TypeID traitID) const { return hasTraitFn(traitID); } | |||
90 | ||||
91 | /// Walk the immediate sub-elements of the given type. | |||
92 | void walkImmediateSubElements(Type type, | |||
93 | function_ref<void(Attribute)> walkAttrsFn, | |||
94 | function_ref<void(Type)> walkTypesFn) const; | |||
95 | ||||
96 | /// Replace the immediate sub-elements of the given type. | |||
97 | Type replaceImmediateSubElements(Type type, ArrayRef<Attribute> replAttrs, | |||
98 | ArrayRef<Type> replTypes) const; | |||
99 | ||||
100 | /// Return the unique identifier representing the concrete type class. | |||
101 | TypeID getTypeID() const { return typeID; } | |||
102 | ||||
103 | private: | |||
104 | AbstractType(Dialect &dialect, detail::InterfaceMap &&interfaceMap, | |||
105 | HasTraitFn &&hasTrait, | |||
106 | WalkImmediateSubElementsFn walkImmediateSubElementsFn, | |||
107 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn, | |||
108 | TypeID typeID) | |||
109 | : dialect(dialect), interfaceMap(std::move(interfaceMap)), | |||
110 | hasTraitFn(std::move(hasTrait)), | |||
111 | walkImmediateSubElementsFn(walkImmediateSubElementsFn), | |||
112 | replaceImmediateSubElementsFn(replaceImmediateSubElementsFn), | |||
113 | typeID(typeID) {} | |||
114 | ||||
115 | /// Give StorageUserBase access to the mutable lookup. | |||
116 | template <typename ConcreteT, typename BaseT, typename StorageT, | |||
117 | typename UniquerT, template <typename T> class... Traits> | |||
118 | friend class detail::StorageUserBase; | |||
119 | ||||
120 | /// Look up the specified abstract type in the MLIRContext and return a | |||
121 | /// (mutable) pointer to it. Return a null pointer if the type could not | |||
122 | /// be found in the context. | |||
123 | static AbstractType *lookupMutable(TypeID typeID, MLIRContext *context); | |||
124 | ||||
125 | /// This is the dialect that this type was registered to. | |||
126 | const Dialect &dialect; | |||
127 | ||||
128 | /// This is a collection of the interfaces registered to this type. | |||
129 | detail::InterfaceMap interfaceMap; | |||
130 | ||||
131 | /// Function to check if the type has a particular trait. | |||
132 | HasTraitFn hasTraitFn; | |||
133 | ||||
134 | /// Function to walk the immediate sub-elements of this type. | |||
135 | WalkImmediateSubElementsFn walkImmediateSubElementsFn; | |||
136 | ||||
137 | /// Function to replace the immediate sub-elements of this type. | |||
138 | ReplaceImmediateSubElementsFn replaceImmediateSubElementsFn; | |||
139 | ||||
140 | /// The unique identifier of the derived Type class. | |||
141 | const TypeID typeID; | |||
142 | }; | |||
143 | ||||
144 | //===----------------------------------------------------------------------===// | |||
145 | // TypeStorage | |||
146 | //===----------------------------------------------------------------------===// | |||
147 | ||||
148 | namespace detail { | |||
149 | struct TypeUniquer; | |||
150 | } // namespace detail | |||
151 | ||||
152 | /// Base storage class appearing in a Type. | |||
153 | class TypeStorage : public StorageUniquer::BaseStorage { | |||
154 | friend detail::TypeUniquer; | |||
155 | friend StorageUniquer; | |||
156 | ||||
157 | public: | |||
158 | /// Return the abstract type descriptor for this type. | |||
159 | const AbstractType &getAbstractType() { | |||
160 | assert(abstractType && "Malformed type storage object.")(static_cast <bool> (abstractType && "Malformed type storage object." ) ? void (0) : __assert_fail ("abstractType && \"Malformed type storage object.\"" , "mlir/include/mlir/IR/TypeSupport.h", 160, __extension__ __PRETTY_FUNCTION__ )); | |||
161 | return *abstractType; | |||
162 | } | |||
163 | ||||
164 | protected: | |||
165 | /// This constructor is used by derived classes as part of the TypeUniquer. | |||
166 | TypeStorage() {} | |||
167 | ||||
168 | private: | |||
169 | /// Set the abstract type for this storage instance. This is used by the | |||
170 | /// TypeUniquer when initializing a newly constructed type storage object. | |||
171 | void initialize(const AbstractType &abstractTy) { | |||
172 | abstractType = const_cast<AbstractType *>(&abstractTy); | |||
173 | } | |||
174 | ||||
175 | /// The abstract description for this type. | |||
176 | AbstractType *abstractType{nullptr}; | |||
177 | }; | |||
178 | ||||
179 | /// Default storage type for types that require no additional initialization or | |||
180 | /// storage. | |||
181 | using DefaultTypeStorage = TypeStorage; | |||
182 | ||||
183 | //===----------------------------------------------------------------------===// | |||
184 | // TypeStorageAllocator | |||
185 | //===----------------------------------------------------------------------===// | |||
186 | ||||
187 | /// This is a utility allocator used to allocate memory for instances of derived | |||
188 | /// Types. | |||
189 | using TypeStorageAllocator = StorageUniquer::StorageAllocator; | |||
190 | ||||
191 | //===----------------------------------------------------------------------===// | |||
192 | // TypeUniquer | |||
193 | //===----------------------------------------------------------------------===// | |||
194 | namespace detail { | |||
195 | /// A utility class to get, or create, unique instances of types within an | |||
196 | /// MLIRContext. This class manages all creation and uniquing of types. | |||
197 | struct TypeUniquer { | |||
198 | /// Get an uniqued instance of a type T. | |||
199 | template <typename T, typename... Args> | |||
200 | static T get(MLIRContext *ctx, Args &&...args) { | |||
201 | return getWithTypeID<T, Args...>(ctx, T::getTypeID(), | |||
202 | std::forward<Args>(args)...); | |||
203 | } | |||
204 | ||||
205 | /// Get an uniqued instance of a parametric type T. | |||
206 | /// The use of this method is in general discouraged in favor of | |||
207 | /// 'get<T, Args>(ctx, args)'. | |||
208 | template <typename T, typename... Args> | |||
209 | static std::enable_if_t< | |||
210 | !std::is_same<typename T::ImplType, TypeStorage>::value, T> | |||
211 | getWithTypeID(MLIRContext *ctx, TypeID typeID, Args &&...args) { | |||
212 | #ifndef NDEBUG | |||
213 | if (!ctx->getTypeUniquer().isParametricStorageInitialized(typeID)) | |||
214 | llvm::report_fatal_error( | |||
215 | llvm::Twine("can't create type '") + llvm::getTypeName<T>() + | |||
216 | "' because storage uniquer isn't initialized: the dialect was likely " | |||
217 | "not loaded, or the type wasn't added with addTypes<...>() " | |||
218 | "in the Dialect::initialize() method."); | |||
219 | #endif | |||
220 | return ctx->getTypeUniquer().get<typename T::ImplType>( | |||
221 | [&, typeID](TypeStorage *storage) { | |||
222 | storage->initialize(AbstractType::lookup(typeID, ctx)); | |||
223 | }, | |||
224 | typeID, std::forward<Args>(args)...); | |||
225 | } | |||
226 | /// Get an uniqued instance of a singleton type T. | |||
227 | /// The use of this method is in general discouraged in favor of | |||
228 | /// 'get<T, Args>(ctx, args)'. | |||
229 | template <typename T> | |||
230 | static std::enable_if_t< | |||
231 | std::is_same<typename T::ImplType, TypeStorage>::value, T> | |||
232 | getWithTypeID(MLIRContext *ctx, TypeID typeID) { | |||
233 | #ifndef NDEBUG | |||
234 | if (!ctx->getTypeUniquer().isSingletonStorageInitialized(typeID)) | |||
235 | llvm::report_fatal_error( | |||
236 | llvm::Twine("can't create type '") + llvm::getTypeName<T>() + | |||
237 | "' because storage uniquer isn't initialized: the dialect was likely " | |||
238 | "not loaded, or the type wasn't added with addTypes<...>() " | |||
239 | "in the Dialect::initialize() method."); | |||
240 | #endif | |||
241 | return ctx->getTypeUniquer().get<typename T::ImplType>(typeID); | |||
242 | } | |||
243 | ||||
244 | /// Change the mutable component of the given type instance in the provided | |||
245 | /// context. | |||
246 | template <typename T, typename... Args> | |||
247 | static LogicalResult mutate(MLIRContext *ctx, typename T::ImplType *impl, | |||
248 | Args &&...args) { | |||
249 | assert(impl && "cannot mutate null type")(static_cast <bool> (impl && "cannot mutate null type" ) ? void (0) : __assert_fail ("impl && \"cannot mutate null type\"" , "mlir/include/mlir/IR/TypeSupport.h", 249, __extension__ __PRETTY_FUNCTION__ )); | |||
250 | return ctx->getTypeUniquer().mutate(T::getTypeID(), impl, | |||
251 | std::forward<Args>(args)...); | |||
252 | } | |||
253 | ||||
254 | /// Register a type instance T with the uniquer. | |||
255 | template <typename T> | |||
256 | static void registerType(MLIRContext *ctx) { | |||
257 | registerType<T>(ctx, T::getTypeID()); | |||
258 | } | |||
259 | ||||
260 | /// Register a parametric type instance T with the uniquer. | |||
261 | /// The use of this method is in general discouraged in favor of | |||
262 | /// 'registerType<T>(ctx)'. | |||
263 | template <typename T> | |||
264 | static std::enable_if_t< | |||
265 | !std::is_same<typename T::ImplType, TypeStorage>::value> | |||
266 | registerType(MLIRContext *ctx, TypeID typeID) { | |||
267 | ctx->getTypeUniquer().registerParametricStorageType<typename T::ImplType>( | |||
268 | typeID); | |||
269 | } | |||
270 | /// Register a singleton type instance T with the uniquer. | |||
271 | /// The use of this method is in general discouraged in favor of | |||
272 | /// 'registerType<T>(ctx)'. | |||
273 | template <typename T> | |||
274 | static std::enable_if_t< | |||
275 | std::is_same<typename T::ImplType, TypeStorage>::value> | |||
276 | registerType(MLIRContext *ctx, TypeID typeID) { | |||
277 | ctx->getTypeUniquer().registerSingletonStorageType<TypeStorage>( | |||
278 | typeID, [&ctx, typeID](TypeStorage *storage) { | |||
279 | storage->initialize(AbstractType::lookup(typeID, ctx)); | |||
280 | }); | |||
281 | } | |||
282 | }; | |||
283 | } // namespace detail | |||
284 | ||||
285 | } // namespace mlir | |||
286 | ||||
287 | #endif |