File: | build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/IR/BuiltinAttributes.cpp |
Warning: | line 845, column 9 1st function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/BuiltinAttributes.h" | |||
10 | #include "AttributeDetail.h" | |||
11 | #include "mlir/IR/AffineMap.h" | |||
12 | #include "mlir/IR/BuiltinDialect.h" | |||
13 | #include "mlir/IR/Dialect.h" | |||
14 | #include "mlir/IR/DialectResourceBlobManager.h" | |||
15 | #include "mlir/IR/IntegerSet.h" | |||
16 | #include "mlir/IR/OpImplementation.h" | |||
17 | #include "mlir/IR/Operation.h" | |||
18 | #include "mlir/IR/SymbolTable.h" | |||
19 | #include "mlir/IR/Types.h" | |||
20 | #include "llvm/ADT/APSInt.h" | |||
21 | #include "llvm/ADT/Sequence.h" | |||
22 | #include "llvm/ADT/TypeSwitch.h" | |||
23 | #include "llvm/Support/Endian.h" | |||
24 | ||||
25 | using namespace mlir; | |||
26 | using namespace mlir::detail; | |||
27 | ||||
28 | //===----------------------------------------------------------------------===// | |||
29 | /// Tablegen Attribute Definitions | |||
30 | //===----------------------------------------------------------------------===// | |||
31 | ||||
32 | #define GET_ATTRDEF_CLASSES | |||
33 | #include "mlir/IR/BuiltinAttributes.cpp.inc" | |||
34 | ||||
35 | //===----------------------------------------------------------------------===// | |||
36 | // BuiltinDialect | |||
37 | //===----------------------------------------------------------------------===// | |||
38 | ||||
39 | void BuiltinDialect::registerAttributes() { | |||
40 | addAttributes< | |||
41 | #define GET_ATTRDEF_LIST | |||
42 | #include "mlir/IR/BuiltinAttributes.cpp.inc" | |||
43 | >(); | |||
44 | } | |||
45 | ||||
46 | //===----------------------------------------------------------------------===// | |||
47 | // ArrayAttr | |||
48 | //===----------------------------------------------------------------------===// | |||
49 | ||||
50 | void ArrayAttr::walkImmediateSubElements( | |||
51 | function_ref<void(Attribute)> walkAttrsFn, | |||
52 | function_ref<void(Type)> walkTypesFn) const { | |||
53 | for (Attribute attr : getValue()) | |||
54 | walkAttrsFn(attr); | |||
55 | } | |||
56 | ||||
57 | Attribute | |||
58 | ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, | |||
59 | ArrayRef<Type> replTypes) const { | |||
60 | return get(getContext(), replAttrs); | |||
61 | } | |||
62 | ||||
63 | //===----------------------------------------------------------------------===// | |||
64 | // DictionaryAttr | |||
65 | //===----------------------------------------------------------------------===// | |||
66 | ||||
67 | /// Helper function that does either an in place sort or sorts from source array | |||
68 | /// into destination. If inPlace then storage is both the source and the | |||
69 | /// destination, else value is the source and storage destination. Returns | |||
70 | /// whether source was sorted. | |||
71 | template <bool inPlace> | |||
72 | static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, | |||
73 | SmallVectorImpl<NamedAttribute> &storage) { | |||
74 | // Specialize for the common case. | |||
75 | switch (value.size()) { | |||
76 | case 0: | |||
77 | // Zero already sorted. | |||
78 | if (!inPlace) | |||
79 | storage.clear(); | |||
80 | break; | |||
81 | case 1: | |||
82 | // One already sorted but may need to be copied. | |||
83 | if (!inPlace) | |||
84 | storage.assign({value[0]}); | |||
85 | break; | |||
86 | case 2: { | |||
87 | bool isSorted = value[0] < value[1]; | |||
88 | if (inPlace) { | |||
89 | if (!isSorted) | |||
90 | std::swap(storage[0], storage[1]); | |||
91 | } else if (isSorted) { | |||
92 | storage.assign({value[0], value[1]}); | |||
93 | } else { | |||
94 | storage.assign({value[1], value[0]}); | |||
95 | } | |||
96 | return !isSorted; | |||
97 | } | |||
98 | default: | |||
99 | if (!inPlace) | |||
100 | storage.assign(value.begin(), value.end()); | |||
101 | // Check to see they are sorted already. | |||
102 | bool isSorted = llvm::is_sorted(value); | |||
103 | // If not, do a general sort. | |||
104 | if (!isSorted) | |||
105 | llvm::array_pod_sort(storage.begin(), storage.end()); | |||
106 | return !isSorted; | |||
107 | } | |||
108 | return false; | |||
109 | } | |||
110 | ||||
111 | /// Returns an entry with a duplicate name from the given sorted array of named | |||
112 | /// attributes. Returns llvm::None if all elements have unique names. | |||
113 | static Optional<NamedAttribute> | |||
114 | findDuplicateElement(ArrayRef<NamedAttribute> value) { | |||
115 | const Optional<NamedAttribute> none{llvm::None}; | |||
116 | if (value.size() < 2) | |||
117 | return none; | |||
118 | ||||
119 | if (value.size() == 2) | |||
120 | return value[0].getName() == value[1].getName() ? value[0] : none; | |||
121 | ||||
122 | const auto *it = std::adjacent_find(value.begin(), value.end(), | |||
123 | [](NamedAttribute l, NamedAttribute r) { | |||
124 | return l.getName() == r.getName(); | |||
125 | }); | |||
126 | return it != value.end() ? *it : none; | |||
127 | } | |||
128 | ||||
129 | bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, | |||
130 | SmallVectorImpl<NamedAttribute> &storage) { | |||
131 | bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); | |||
132 | assert(!findDuplicateElement(storage) &&(static_cast <bool> (!findDuplicateElement(storage) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__ )) | |||
133 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(storage) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__ )); | |||
134 | return isSorted; | |||
135 | } | |||
136 | ||||
137 | bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { | |||
138 | bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); | |||
139 | assert(!findDuplicateElement(array) &&(static_cast <bool> (!findDuplicateElement(array) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__ )) | |||
140 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(array) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__ )); | |||
141 | return isSorted; | |||
142 | } | |||
143 | ||||
144 | Optional<NamedAttribute> | |||
145 | DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, | |||
146 | bool isSorted) { | |||
147 | if (!isSorted) | |||
148 | dictionaryAttrSort</*inPlace=*/true>(array, array); | |||
149 | return findDuplicateElement(array); | |||
150 | } | |||
151 | ||||
152 | DictionaryAttr DictionaryAttr::get(MLIRContext *context, | |||
153 | ArrayRef<NamedAttribute> value) { | |||
154 | if (value.empty()) | |||
155 | return DictionaryAttr::getEmpty(context); | |||
156 | ||||
157 | // We need to sort the element list to canonicalize it. | |||
158 | SmallVector<NamedAttribute, 8> storage; | |||
159 | if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) | |||
160 | value = storage; | |||
161 | assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__ )) | |||
162 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__ )); | |||
163 | return Base::get(context, value); | |||
164 | } | |||
165 | /// Construct a dictionary with an array of values that is known to already be | |||
166 | /// sorted by name and uniqued. | |||
167 | DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, | |||
168 | ArrayRef<NamedAttribute> value) { | |||
169 | if (value.empty()) | |||
170 | return DictionaryAttr::getEmpty(context); | |||
171 | // Ensure that the attribute elements are unique and sorted. | |||
172 | assert(llvm::is_sorted((static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__ )) | |||
173 | value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__ )) | |||
174 | "expected attribute values to be sorted")(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__ )); | |||
175 | assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__ )) | |||
176 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__ )); | |||
177 | return Base::get(context, value); | |||
178 | } | |||
179 | ||||
180 | /// Return the specified attribute if present, null otherwise. | |||
181 | Attribute DictionaryAttr::get(StringRef name) const { | |||
182 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
183 | return it.second ? it.first->getValue() : Attribute(); | |||
184 | } | |||
185 | Attribute DictionaryAttr::get(StringAttr name) const { | |||
186 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
187 | return it.second ? it.first->getValue() : Attribute(); | |||
188 | } | |||
189 | ||||
190 | /// Return the specified named attribute if present, None otherwise. | |||
191 | Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { | |||
192 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
193 | return it.second ? *it.first : Optional<NamedAttribute>(); | |||
194 | } | |||
195 | Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { | |||
196 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
197 | return it.second ? *it.first : Optional<NamedAttribute>(); | |||
198 | } | |||
199 | ||||
200 | /// Return whether the specified attribute is present. | |||
201 | bool DictionaryAttr::contains(StringRef name) const { | |||
202 | return impl::findAttrSorted(begin(), end(), name).second; | |||
203 | } | |||
204 | bool DictionaryAttr::contains(StringAttr name) const { | |||
205 | return impl::findAttrSorted(begin(), end(), name).second; | |||
206 | } | |||
207 | ||||
208 | DictionaryAttr::iterator DictionaryAttr::begin() const { | |||
209 | return getValue().begin(); | |||
210 | } | |||
211 | DictionaryAttr::iterator DictionaryAttr::end() const { | |||
212 | return getValue().end(); | |||
213 | } | |||
214 | size_t DictionaryAttr::size() const { return getValue().size(); } | |||
215 | ||||
216 | DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { | |||
217 | return Base::get(context, ArrayRef<NamedAttribute>()); | |||
218 | } | |||
219 | ||||
220 | void DictionaryAttr::walkImmediateSubElements( | |||
221 | function_ref<void(Attribute)> walkAttrsFn, | |||
222 | function_ref<void(Type)> walkTypesFn) const { | |||
223 | for (const NamedAttribute &attr : getValue()) | |||
224 | walkAttrsFn(attr.getValue()); | |||
225 | } | |||
226 | ||||
227 | Attribute | |||
228 | DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, | |||
229 | ArrayRef<Type> replTypes) const { | |||
230 | std::vector<NamedAttribute> vec = getValue().vec(); | |||
231 | for (auto &it : llvm::enumerate(replAttrs)) | |||
232 | vec[it.index()].setValue(it.value()); | |||
233 | ||||
234 | // The above only modifies the mapped value, but not the key, and therefore | |||
235 | // not the order of the elements. It remains sorted | |||
236 | return getWithSorted(getContext(), vec); | |||
237 | } | |||
238 | ||||
239 | //===----------------------------------------------------------------------===// | |||
240 | // StringAttr | |||
241 | //===----------------------------------------------------------------------===// | |||
242 | ||||
243 | StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { | |||
244 | return Base::get(context, "", NoneType::get(context)); | |||
245 | } | |||
246 | ||||
247 | /// Twine support for StringAttr. | |||
248 | StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { | |||
249 | // Fast-path empty twine. | |||
250 | if (twine.isTriviallyEmpty()) | |||
251 | return get(context); | |||
252 | SmallVector<char, 32> tempStr; | |||
253 | return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); | |||
254 | } | |||
255 | ||||
256 | /// Twine support for StringAttr. | |||
257 | StringAttr StringAttr::get(const Twine &twine, Type type) { | |||
258 | SmallVector<char, 32> tempStr; | |||
259 | return Base::get(type.getContext(), twine.toStringRef(tempStr), type); | |||
260 | } | |||
261 | ||||
262 | StringRef StringAttr::getValue() const { return getImpl()->value; } | |||
263 | ||||
264 | Type StringAttr::getType() const { return getImpl()->type; } | |||
265 | ||||
266 | Dialect *StringAttr::getReferencedDialect() const { | |||
267 | return getImpl()->referencedDialect; | |||
268 | } | |||
269 | ||||
270 | //===----------------------------------------------------------------------===// | |||
271 | // FloatAttr | |||
272 | //===----------------------------------------------------------------------===// | |||
273 | ||||
274 | double FloatAttr::getValueAsDouble() const { | |||
275 | return getValueAsDouble(getValue()); | |||
276 | } | |||
277 | double FloatAttr::getValueAsDouble(APFloat value) { | |||
278 | if (&value.getSemantics() != &APFloat::IEEEdouble()) { | |||
279 | bool losesInfo = false; | |||
280 | value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, | |||
281 | &losesInfo); | |||
282 | } | |||
283 | return value.convertToDouble(); | |||
284 | } | |||
285 | ||||
286 | LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
287 | Type type, APFloat value) { | |||
288 | // Verify that the type is correct. | |||
289 | if (!type.isa<FloatType>()) | |||
290 | return emitError() << "expected floating point type"; | |||
291 | ||||
292 | // Verify that the type semantics match that of the value. | |||
293 | if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { | |||
294 | return emitError() | |||
295 | << "FloatAttr type doesn't match the type implied by its value"; | |||
296 | } | |||
297 | return success(); | |||
298 | } | |||
299 | ||||
300 | //===----------------------------------------------------------------------===// | |||
301 | // SymbolRefAttr | |||
302 | //===----------------------------------------------------------------------===// | |||
303 | ||||
304 | SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, | |||
305 | ArrayRef<FlatSymbolRefAttr> nestedRefs) { | |||
306 | return get(StringAttr::get(ctx, value), nestedRefs); | |||
307 | } | |||
308 | ||||
309 | FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { | |||
310 | return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); | |||
311 | } | |||
312 | ||||
313 | FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { | |||
314 | return get(value, {}).cast<FlatSymbolRefAttr>(); | |||
315 | } | |||
316 | ||||
317 | FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { | |||
318 | auto symName = | |||
319 | symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); | |||
320 | assert(symName && "value does not have a valid symbol name")(static_cast <bool> (symName && "value does not have a valid symbol name" ) ? void (0) : __assert_fail ("symName && \"value does not have a valid symbol name\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 320, __extension__ __PRETTY_FUNCTION__ )); | |||
321 | return SymbolRefAttr::get(symName); | |||
322 | } | |||
323 | ||||
324 | StringAttr SymbolRefAttr::getLeafReference() const { | |||
325 | ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); | |||
326 | return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); | |||
327 | } | |||
328 | ||||
329 | void SymbolRefAttr::walkImmediateSubElements( | |||
330 | function_ref<void(Attribute)> walkAttrsFn, | |||
331 | function_ref<void(Type)> walkTypesFn) const { | |||
332 | walkAttrsFn(getRootReference()); | |||
333 | for (FlatSymbolRefAttr ref : getNestedReferences()) | |||
334 | walkAttrsFn(ref); | |||
335 | } | |||
336 | ||||
337 | Attribute | |||
338 | SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, | |||
339 | ArrayRef<Type> replTypes) const { | |||
340 | ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front(); | |||
341 | ArrayRef<FlatSymbolRefAttr> nestedRefs( | |||
342 | static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()), | |||
343 | rawNestedRefs.size()); | |||
344 | return get(replAttrs[0].cast<StringAttr>(), nestedRefs); | |||
345 | } | |||
346 | ||||
347 | //===----------------------------------------------------------------------===// | |||
348 | // IntegerAttr | |||
349 | //===----------------------------------------------------------------------===// | |||
350 | ||||
351 | int64_t IntegerAttr::getInt() const { | |||
352 | assert((getType().isIndex() || getType().isSignlessInteger()) &&(static_cast <bool> ((getType().isIndex() || getType(). isSignlessInteger()) && "must be signless integer") ? void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__ )) | |||
353 | "must be signless integer")(static_cast <bool> ((getType().isIndex() || getType(). isSignlessInteger()) && "must be signless integer") ? void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__ )); | |||
354 | return getValue().getSExtValue(); | |||
355 | } | |||
356 | ||||
357 | int64_t IntegerAttr::getSInt() const { | |||
358 | assert(getType().isSignedInteger() && "must be signed integer")(static_cast <bool> (getType().isSignedInteger() && "must be signed integer") ? void (0) : __assert_fail ("getType().isSignedInteger() && \"must be signed integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 358, __extension__ __PRETTY_FUNCTION__ )); | |||
359 | return getValue().getSExtValue(); | |||
360 | } | |||
361 | ||||
362 | uint64_t IntegerAttr::getUInt() const { | |||
363 | assert(getType().isUnsignedInteger() && "must be unsigned integer")(static_cast <bool> (getType().isUnsignedInteger() && "must be unsigned integer") ? void (0) : __assert_fail ("getType().isUnsignedInteger() && \"must be unsigned integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 363, __extension__ __PRETTY_FUNCTION__ )); | |||
364 | return getValue().getZExtValue(); | |||
365 | } | |||
366 | ||||
367 | /// Return the value as an APSInt which carries the signed from the type of | |||
368 | /// the attribute. This traps on signless integers types! | |||
369 | APSInt IntegerAttr::getAPSInt() const { | |||
370 | assert(!getType().isSignlessInteger() &&(static_cast <bool> (!getType().isSignlessInteger() && "Signless integers don't carry a sign for APSInt") ? void (0 ) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 371, __extension__ __PRETTY_FUNCTION__ )) | |||
371 | "Signless integers don't carry a sign for APSInt")(static_cast <bool> (!getType().isSignlessInteger() && "Signless integers don't carry a sign for APSInt") ? void (0 ) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 371, __extension__ __PRETTY_FUNCTION__ )); | |||
372 | return APSInt(getValue(), getType().isUnsignedInteger()); | |||
373 | } | |||
374 | ||||
375 | LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
376 | Type type, APInt value) { | |||
377 | if (IntegerType integerType = type.dyn_cast<IntegerType>()) { | |||
378 | if (integerType.getWidth() != value.getBitWidth()) | |||
379 | return emitError() << "integer type bit width (" << integerType.getWidth() | |||
380 | << ") doesn't match value bit width (" | |||
381 | << value.getBitWidth() << ")"; | |||
382 | return success(); | |||
383 | } | |||
384 | if (type.isa<IndexType>()) | |||
385 | return success(); | |||
386 | return emitError() << "expected integer or index type"; | |||
387 | } | |||
388 | ||||
389 | BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { | |||
390 | auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); | |||
391 | return attr.cast<BoolAttr>(); | |||
392 | } | |||
393 | ||||
394 | //===----------------------------------------------------------------------===// | |||
395 | // BoolAttr | |||
396 | //===----------------------------------------------------------------------===// | |||
397 | ||||
398 | bool BoolAttr::getValue() const { | |||
399 | auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); | |||
400 | return storage->value.getBoolValue(); | |||
401 | } | |||
402 | ||||
403 | bool BoolAttr::classof(Attribute attr) { | |||
404 | IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); | |||
405 | return intAttr && intAttr.getType().isSignlessInteger(1); | |||
406 | } | |||
407 | ||||
408 | //===----------------------------------------------------------------------===// | |||
409 | // OpaqueAttr | |||
410 | //===----------------------------------------------------------------------===// | |||
411 | ||||
412 | LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
413 | StringAttr dialect, StringRef attrData, | |||
414 | Type type) { | |||
415 | if (!Dialect::isValidNamespace(dialect.strref())) | |||
416 | return emitError() << "invalid dialect namespace '" << dialect << "'"; | |||
417 | ||||
418 | // Check that the dialect is actually registered. | |||
419 | MLIRContext *context = dialect.getContext(); | |||
420 | if (!context->allowsUnregisteredDialects() && | |||
421 | !context->getLoadedDialect(dialect.strref())) { | |||
422 | return emitError() | |||
423 | << "#" << dialect << "<\"" << attrData << "\"> : " << type | |||
424 | << " attribute created with unregistered dialect. If this is " | |||
425 | "intended, please call allowUnregisteredDialects() on the " | |||
426 | "MLIRContext, or use -allow-unregistered-dialect with " | |||
427 | "the MLIR opt tool used"; | |||
428 | } | |||
429 | ||||
430 | return success(); | |||
431 | } | |||
432 | ||||
433 | //===----------------------------------------------------------------------===// | |||
434 | // DenseElementsAttr Utilities | |||
435 | //===----------------------------------------------------------------------===// | |||
436 | ||||
/// Returns how many bits one element of width `origWidth` occupies inside a
/// DenseElementsAttr buffer. 1-bit elements stay bit-packed; every other width
/// is rounded up to the next multiple of 8.
static size_t getDenseElementStorageWidth(size_t origWidth) {
  if (origWidth == 1)
    return origWidth;
  // Round up to a multiple of 8 (same result as llvm::alignTo<8>).
  return (origWidth + 7) / 8 * 8;
}
442 | static size_t getDenseElementStorageWidth(Type elementType) { | |||
443 | return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); | |||
444 | } | |||
445 | ||||
/// Sets (when `value` is true) or clears bit `bitPos` of `rawData`. Bit `i`
/// lives in byte `i / CHAR_BIT`, at position `i % CHAR_BIT` counted from that
/// byte's least significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}

/// Returns whether bit `bitPos` of `rawData` is set, using the same bit
/// numbering as setBit.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
458 | ||||
459 | /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for | |||
460 | /// BE format. | |||
461 | static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, | |||
462 | char *result) { | |||
463 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 464, __extension__ __PRETTY_FUNCTION__ )) | |||
464 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 464, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
465 | assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes) ? void (0) : __assert_fail ("value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes" , "mlir/lib/IR/BuiltinAttributes.cpp", 465, __extension__ __PRETTY_FUNCTION__ )); | |||
466 | ||||
467 | // Copy the words filled with data. | |||
468 | // For example, when `value` has 2 words, the first word is filled with data. | |||
469 | // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| | |||
470 | size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; | |||
471 | std::copy_n(reinterpret_cast<const char *>(value.getRawData()), | |||
472 | numFilledWords, result); | |||
473 | // Convert last word of APInt to LE format and store it in char | |||
474 | // array(`valueLE`). | |||
475 | // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| | |||
476 | size_t lastWordPos = numFilledWords; | |||
477 | SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); | |||
478 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
479 | reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, | |||
480 | valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); | |||
481 | // Extract actual APInt data from `valueLE`, convert endianness to BE format, | |||
482 | // and store it in `result`. | |||
483 | // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| | |||
484 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
485 | valueLE.begin(), result + lastWordPos, | |||
486 | (numBytes - lastWordPos) * CHAR_BIT8, 1); | |||
487 | } | |||
488 | ||||
489 | /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE | |||
490 | /// format. | |||
491 | static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, | |||
492 | APInt &result) { | |||
493 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 494, __extension__ __PRETTY_FUNCTION__ )) | |||
494 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 494, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
495 | assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes) ? void (0) : __assert_fail ("result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes" , "mlir/lib/IR/BuiltinAttributes.cpp", 495, __extension__ __PRETTY_FUNCTION__ )); | |||
496 | ||||
497 | // Copy the data that fills the word of `result` from `inArray`. | |||
498 | // For example, when `result` has 2 words, the first word will be filled with | |||
499 | // data. So, the first 8 bytes are copied from `inArray` here. | |||
500 | // `inArray` (10 bytes, BE): |abcdefgh|ij| | |||
501 | // ==> `result` (2 words, BE): |abcdefgh|--------| | |||
502 | size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; | |||
503 | std::copy_n( | |||
504 | inArray, numFilledWords, | |||
505 | const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); | |||
506 | ||||
507 | // Convert array data which will be last word of `result` to LE format, and | |||
508 | // store it in char array(`inArrayLE`). | |||
509 | // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| | |||
510 | size_t lastWordPos = numFilledWords; | |||
511 | SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); | |||
512 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
513 | inArray + lastWordPos, inArrayLE.begin(), | |||
514 | (numBytes - lastWordPos) * CHAR_BIT8, 1); | |||
515 | ||||
516 | // Convert `inArrayLE` to BE format, and store it in last word of `result`. | |||
517 | // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| | |||
518 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
519 | inArrayLE.begin(), | |||
520 | const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + | |||
521 | lastWordPos, | |||
522 | APInt::APINT_BITS_PER_WORD, 1); | |||
523 | } | |||
524 | ||||
525 | /// Writes value to the bit position `bitPos` in array `rawData`. | |||
526 | static void writeBits(char *rawData, size_t bitPos, APInt value) { | |||
527 | size_t bitWidth = value.getBitWidth(); | |||
528 | ||||
529 | // If the bitwidth is 1 we just toggle the specific bit. | |||
530 | if (bitWidth == 1) | |||
531 | return setBit(rawData, bitPos, value.isOneValue()); | |||
532 | ||||
533 | // Otherwise, the bit position is guaranteed to be byte aligned. | |||
534 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned" ) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 534, __extension__ __PRETTY_FUNCTION__ )); | |||
535 | if (llvm::support::endian::system_endianness() == | |||
536 | llvm::support::endianness::big) { | |||
537 | // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. | |||
538 | // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't | |||
539 | // work correctly in BE format. | |||
540 | // ex. `value` (2 words including 10 bytes) | |||
541 | // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| | |||
542 | copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
543 | rawData + (bitPos / CHAR_BIT8)); | |||
544 | } else { | |||
545 | std::copy_n(reinterpret_cast<const char *>(value.getRawData()), | |||
546 | llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
547 | rawData + (bitPos / CHAR_BIT8)); | |||
548 | } | |||
549 | } | |||
550 | ||||
551 | /// Reads the next `bitWidth` bits from the bit position `bitPos` in array | |||
552 | /// `rawData`. | |||
553 | static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { | |||
554 | // Handle a boolean bit position. | |||
555 | if (bitWidth == 1) | |||
556 | return APInt(1, getBit(rawData, bitPos) ? 1 : 0); | |||
557 | ||||
558 | // Otherwise, the bit position must be 8-bit aligned. | |||
559 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned" ) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 559, __extension__ __PRETTY_FUNCTION__ )); | |||
560 | APInt result(bitWidth, 0); | |||
561 | if (llvm::support::endian::system_endianness() == | |||
562 | llvm::support::endianness::big) { | |||
563 | // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. | |||
564 | // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't | |||
565 | // work correctly in BE format. | |||
566 | // ex. `result` (2 words including 10 bytes) | |||
567 | // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function | |||
568 | copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT8), | |||
569 | llvm::divideCeil(bitWidth, CHAR_BIT8), result); | |||
570 | } else { | |||
571 | std::copy_n(rawData + (bitPos / CHAR_BIT8), | |||
572 | llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
573 | const_cast<char *>( | |||
574 | reinterpret_cast<const char *>(result.getRawData()))); | |||
575 | } | |||
576 | return result; | |||
577 | } | |||
578 | ||||
579 | /// Returns true if 'values' corresponds to a splat, i.e. one element, or has | |||
580 | /// the same element count as 'type'. | |||
581 | template <typename Values> | |||
582 | static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { | |||
583 | return (values.size() == 1) || | |||
584 | (type.getNumElements() == static_cast<int64_t>(values.size())); | |||
585 | } | |||
586 | ||||
587 | //===----------------------------------------------------------------------===// | |||
588 | // DenseElementsAttr Iterators | |||
589 | //===----------------------------------------------------------------------===// | |||
590 | ||||
591 | //===----------------------------------------------------------------------===// | |||
592 | // AttributeElementIterator | |||
593 | ||||
594 | DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( | |||
595 | DenseElementsAttr attr, size_t index) | |||
596 | : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, | |||
597 | Attribute, Attribute, Attribute>( | |||
598 | attr.getAsOpaquePointer(), index) {} | |||
599 | ||||
600 | Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { | |||
601 | auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); | |||
602 | Type eltTy = owner.getElementType(); | |||
603 | if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) | |||
604 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); | |||
605 | if (eltTy.isa<IndexType>()) | |||
606 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); | |||
607 | if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { | |||
608 | IntElementIterator intIt(owner, index); | |||
609 | FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); | |||
610 | return FloatAttr::get(eltTy, *floatIt); | |||
611 | } | |||
612 | if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { | |||
613 | auto complexEltTy = complexTy.getElementType(); | |||
614 | ComplexIntElementIterator complexIntIt(owner, index); | |||
615 | if (complexEltTy.isa<IntegerType>()) { | |||
616 | auto value = *complexIntIt; | |||
617 | auto real = IntegerAttr::get(complexEltTy, value.real()); | |||
618 | auto imag = IntegerAttr::get(complexEltTy, value.imag()); | |||
619 | return ArrayAttr::get(complexTy.getContext(), | |||
620 | ArrayRef<Attribute>{real, imag}); | |||
621 | } | |||
622 | ||||
623 | ComplexFloatElementIterator complexFloatIt( | |||
624 | complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); | |||
625 | auto value = *complexFloatIt; | |||
626 | auto real = FloatAttr::get(complexEltTy, value.real()); | |||
627 | auto imag = FloatAttr::get(complexEltTy, value.imag()); | |||
628 | return ArrayAttr::get(complexTy.getContext(), | |||
629 | ArrayRef<Attribute>{real, imag}); | |||
630 | } | |||
631 | if (owner.isa<DenseStringElementsAttr>()) { | |||
632 | ArrayRef<StringRef> vals = owner.getRawStringData(); | |||
633 | return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); | |||
634 | } | |||
635 | llvm_unreachable("unexpected element type")::llvm::llvm_unreachable_internal("unexpected element type", "mlir/lib/IR/BuiltinAttributes.cpp" , 635); | |||
636 | } | |||
637 | ||||
638 | //===----------------------------------------------------------------------===// | |||
639 | // BoolElementIterator | |||
640 | ||||
641 | DenseElementsAttr::BoolElementIterator::BoolElementIterator( | |||
642 | DenseElementsAttr attr, size_t dataIndex) | |||
643 | : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( | |||
644 | attr.getRawData().data(), attr.isSplat(), dataIndex) {} | |||
645 | ||||
646 | bool DenseElementsAttr::BoolElementIterator::operator*() const { | |||
647 | return getBit(getData(), getDataIndex()); | |||
648 | } | |||
649 | ||||
650 | //===----------------------------------------------------------------------===// | |||
651 | // IntElementIterator | |||
652 | ||||
653 | DenseElementsAttr::IntElementIterator::IntElementIterator( | |||
654 | DenseElementsAttr attr, size_t dataIndex) | |||
655 | : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( | |||
656 | attr.getRawData().data(), attr.isSplat(), dataIndex), | |||
657 | bitWidth(getDenseElementBitWidth(attr.getElementType())) {} | |||
658 | ||||
659 | APInt DenseElementsAttr::IntElementIterator::operator*() const { | |||
660 | return readBits(getData(), | |||
661 | getDataIndex() * getDenseElementStorageWidth(bitWidth), | |||
662 | bitWidth); | |||
663 | } | |||
664 | ||||
665 | //===----------------------------------------------------------------------===// | |||
666 | // ComplexIntElementIterator | |||
667 | ||||
668 | DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( | |||
669 | DenseElementsAttr attr, size_t dataIndex) | |||
670 | : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, | |||
671 | std::complex<APInt>, std::complex<APInt>, | |||
672 | std::complex<APInt>>( | |||
673 | attr.getRawData().data(), attr.isSplat(), dataIndex) { | |||
674 | auto complexType = attr.getElementType().cast<ComplexType>(); | |||
675 | bitWidth = getDenseElementBitWidth(complexType.getElementType()); | |||
676 | } | |||
677 | ||||
678 | std::complex<APInt> | |||
679 | DenseElementsAttr::ComplexIntElementIterator::operator*() const { | |||
680 | size_t storageWidth = getDenseElementStorageWidth(bitWidth); | |||
681 | size_t offset = getDataIndex() * storageWidth * 2; | |||
682 | return {readBits(getData(), offset, bitWidth), | |||
683 | readBits(getData(), offset + storageWidth, bitWidth)}; | |||
684 | } | |||
685 | ||||
686 | //===----------------------------------------------------------------------===// | |||
687 | // DenseArrayAttr | |||
688 | //===----------------------------------------------------------------------===// | |||
689 | ||||
690 | const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const { | |||
691 | return cast<DenseBoolArrayAttr>().asArrayRef().begin(); | |||
692 | } | |||
693 | const int8_t * | |||
694 | DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const { | |||
695 | return cast<DenseI8ArrayAttr>().asArrayRef().begin(); | |||
696 | } | |||
697 | const int16_t * | |||
698 | DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const { | |||
699 | return cast<DenseI16ArrayAttr>().asArrayRef().begin(); | |||
700 | } | |||
701 | const int32_t * | |||
702 | DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const { | |||
703 | return cast<DenseI32ArrayAttr>().asArrayRef().begin(); | |||
704 | } | |||
705 | const int64_t * | |||
706 | DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const { | |||
707 | return cast<DenseI64ArrayAttr>().asArrayRef().begin(); | |||
708 | } | |||
709 | const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const { | |||
710 | return cast<DenseF32ArrayAttr>().asArrayRef().begin(); | |||
711 | } | |||
712 | const double * | |||
713 | DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const { | |||
714 | return cast<DenseF64ArrayAttr>().asArrayRef().begin(); | |||
715 | } | |||
716 | ||||
717 | void DenseArrayBaseAttr::print(AsmPrinter &printer) const { | |||
718 | print(printer.getStream()); | |||
719 | } | |||
720 | ||||
721 | void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const { | |||
722 | llvm::TypeSwitch<DenseArrayBaseAttr>(*this) | |||
723 | .Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr, | |||
724 | DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr, | |||
725 | DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); }); | |||
726 | } | |||
727 | ||||
728 | void DenseArrayBaseAttr::print(raw_ostream &os) const { | |||
729 | os << "["; | |||
730 | printWithoutBraces(os); | |||
731 | os << "]"; | |||
732 | } | |||
733 | ||||
734 | namespace { | |||
735 | /// Instantiations of this class provide utilities for interacting with native | |||
736 | /// data types in the context of DenseArrayAttr. | |||
737 | template <size_t width, | |||
738 | IntegerType::SignednessSemantics signedness = IntegerType::Signless> | |||
739 | struct DenseArrayAttrIntUtil { | |||
740 | static bool checkElementType(Type eltType) { | |||
741 | auto type = eltType.dyn_cast<IntegerType>(); | |||
742 | if (!type || type.getWidth() != width) | |||
743 | return false; | |||
744 | return type.getSignedness() == signedness; | |||
745 | } | |||
746 | ||||
747 | static Type getElementType(MLIRContext *ctx) { | |||
748 | return IntegerType::get(ctx, width, signedness); | |||
749 | } | |||
750 | ||||
751 | template <typename T> | |||
752 | static void printElement(raw_ostream &os, T value) { | |||
753 | os << value; | |||
754 | } | |||
755 | ||||
756 | template <typename T> | |||
757 | static ParseResult parseElement(AsmParser &parser, T &value) { | |||
758 | return parser.parseInteger(value); | |||
759 | } | |||
760 | }; | |||
761 | template <typename T> | |||
762 | struct DenseArrayAttrUtil; | |||
763 | ||||
764 | /// Specialization for boolean elements to print 'true' and 'false' literals for | |||
765 | /// elements. | |||
766 | template <> | |||
767 | struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> { | |||
768 | static void printElement(raw_ostream &os, bool value) { | |||
769 | os << (value ? "true" : "false"); | |||
770 | } | |||
771 | }; | |||
772 | ||||
773 | /// Specialization for 8-bit integers to ensure values are printed as integers | |||
774 | /// and not characters. | |||
775 | template <> | |||
776 | struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> { | |||
777 | static void printElement(raw_ostream &os, int8_t value) { | |||
778 | os << static_cast<int>(value); | |||
779 | } | |||
780 | }; | |||
781 | template <> | |||
782 | struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {}; | |||
783 | template <> | |||
784 | struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {}; | |||
785 | template <> | |||
786 | struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {}; | |||
787 | ||||
788 | /// Specialization for 32-bit floats. | |||
789 | template <> | |||
790 | struct DenseArrayAttrUtil<float> { | |||
791 | static bool checkElementType(Type eltType) { return eltType.isF32(); } | |||
792 | static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); } | |||
793 | static void printElement(raw_ostream &os, float value) { os << value; } | |||
794 | ||||
795 | /// Parse a double and cast it to a float. | |||
796 | static ParseResult parseElement(AsmParser &parser, float &value) { | |||
797 | double doubleVal; | |||
798 | if (parser.parseFloat(doubleVal)) | |||
799 | return failure(); | |||
800 | value = doubleVal; | |||
801 | return success(); | |||
802 | } | |||
803 | }; | |||
804 | ||||
805 | /// Specialization for 64-bit floats. | |||
806 | template <> | |||
807 | struct DenseArrayAttrUtil<double> { | |||
808 | static bool checkElementType(Type eltType) { return eltType.isF64(); } | |||
809 | static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); } | |||
810 | static void printElement(raw_ostream &os, float value) { os << value; } | |||
811 | static ParseResult parseElement(AsmParser &parser, double &value) { | |||
812 | return parser.parseFloat(value); | |||
813 | } | |||
814 | }; | |||
815 | } // namespace | |||
816 | ||||
817 | template <typename T> | |||
818 | void DenseArrayAttr<T>::print(AsmPrinter &printer) const { | |||
819 | print(printer.getStream()); | |||
820 | } | |||
821 | ||||
822 | template <typename T> | |||
823 | void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const { | |||
824 | llvm::interleaveComma(asArrayRef(), os, [&](T value) { | |||
825 | DenseArrayAttrUtil<T>::printElement(os, value); | |||
826 | }); | |||
827 | } | |||
828 | ||||
829 | template <typename T> | |||
830 | void DenseArrayAttr<T>::print(raw_ostream &os) const { | |||
831 | os << "["; | |||
832 | printWithoutBraces(os); | |||
833 | os << "]"; | |||
834 | } | |||
835 | ||||
836 | /// Parse a DenseArrayAttr without the braces: `1, 2, 3` | |||
837 | template <typename T> | |||
838 | Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser, | |||
839 | Type odsType) { | |||
840 | SmallVector<T> data; | |||
841 | if (failed(parser.parseCommaSeparatedList([&]() { | |||
842 | T value; | |||
| ||||
843 | if (DenseArrayAttrUtil<T>::parseElement(parser, value)) | |||
844 | return failure(); | |||
845 | data.push_back(value); | |||
| ||||
846 | return success(); | |||
847 | }))) | |||
848 | return {}; | |||
849 | return get(parser.getContext(), data); | |||
850 | } | |||
851 | ||||
852 | /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` | |||
853 | template <typename T> | |||
854 | Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) { | |||
855 | if (parser.parseLSquare()) | |||
856 | return {}; | |||
857 | // Handle empty list case. | |||
858 | if (succeeded(parser.parseOptionalRSquare())) | |||
859 | return get(parser.getContext(), {}); | |||
860 | Attribute result = parseWithoutBraces(parser, odsType); | |||
861 | if (parser.parseRSquare()) | |||
862 | return {}; | |||
863 | return result; | |||
864 | } | |||
865 | ||||
866 | /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. | |||
867 | template <typename T> | |||
868 | DenseArrayAttr<T>::operator ArrayRef<T>() const { | |||
869 | ArrayRef<char> raw = getRawData(); | |||
870 | assert((raw.size() % sizeof(T)) == 0)(static_cast <bool> ((raw.size() % sizeof(T)) == 0) ? void (0) : __assert_fail ("(raw.size() % sizeof(T)) == 0", "mlir/lib/IR/BuiltinAttributes.cpp" , 870, __extension__ __PRETTY_FUNCTION__)); | |||
871 | return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), | |||
872 | raw.size() / sizeof(T)); | |||
873 | } | |||
874 | ||||
875 | /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. | |||
876 | template <typename T> | |||
877 | DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context, | |||
878 | ArrayRef<T> content) { | |||
879 | auto shapedType = RankedTensorType::get( | |||
880 | content.size(), DenseArrayAttrUtil<T>::getElementType(context)); | |||
881 | auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), | |||
882 | content.size() * sizeof(T)); | |||
883 | return Base::get(context, shapedType, rawArray) | |||
884 | .template cast<DenseArrayAttr<T>>(); | |||
885 | } | |||
886 | ||||
887 | template <typename T> | |||
888 | bool DenseArrayAttr<T>::classof(Attribute attr) { | |||
889 | if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>()) | |||
890 | return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType()); | |||
891 | return false; | |||
892 | } | |||
893 | ||||
894 | namespace mlir { | |||
895 | namespace detail { | |||
896 | // Explicit instantiation for all the supported DenseArrayAttr. | |||
897 | template class DenseArrayAttr<bool>; | |||
898 | template class DenseArrayAttr<int8_t>; | |||
899 | template class DenseArrayAttr<int16_t>; | |||
900 | template class DenseArrayAttr<int32_t>; | |||
901 | template class DenseArrayAttr<int64_t>; | |||
902 | template class DenseArrayAttr<float>; | |||
903 | template class DenseArrayAttr<double>; | |||
904 | } // namespace detail | |||
905 | } // namespace mlir | |||
906 | ||||
907 | //===----------------------------------------------------------------------===// | |||
908 | // DenseElementsAttr | |||
909 | //===----------------------------------------------------------------------===// | |||
910 | ||||
911 | /// Method for support type inquiry through isa, cast and dyn_cast. | |||
912 | bool DenseElementsAttr::classof(Attribute attr) { | |||
913 | return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); | |||
914 | } | |||
915 | ||||
916 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
917 | ArrayRef<Attribute> values) { | |||
918 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__ )); | |||
919 | ||||
920 | // If the element type is not based on int/float/index, assume it is a string | |||
921 | // type. | |||
922 | Type eltType = type.getElementType(); | |||
923 | if (!eltType.isIntOrIndexOrFloat()) { | |||
924 | SmallVector<StringRef, 8> stringValues; | |||
925 | stringValues.reserve(values.size()); | |||
926 | for (Attribute attr : values) { | |||
927 | assert(attr.isa<StringAttr>() &&(static_cast <bool> (attr.isa<StringAttr>() && "expected string value for non integer/index/float element") ? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__ )) | |||
928 | "expected string value for non integer/index/float element")(static_cast <bool> (attr.isa<StringAttr>() && "expected string value for non integer/index/float element") ? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__ )); | |||
929 | stringValues.push_back(attr.cast<StringAttr>().getValue()); | |||
930 | } | |||
931 | return get(type, stringValues); | |||
932 | } | |||
933 | ||||
934 | // Otherwise, get the raw storage width to use for the allocation. | |||
935 | size_t bitWidth = getDenseElementBitWidth(eltType); | |||
936 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); | |||
937 | ||||
938 | // Compress the attribute values into a character buffer. | |||
939 | SmallVector<char, 8> data( | |||
940 | llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT8)); | |||
941 | APInt intVal; | |||
942 | for (unsigned i = 0, e = values.size(); i < e; ++i) { | |||
943 | if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) { | |||
944 | assert(floatAttr.getType() == eltType &&(static_cast <bool> (floatAttr.getType() == eltType && "expected float attribute type to equal element type") ? void (0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 945, __extension__ __PRETTY_FUNCTION__ )) | |||
945 | "expected float attribute type to equal element type")(static_cast <bool> (floatAttr.getType() == eltType && "expected float attribute type to equal element type") ? void (0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 945, __extension__ __PRETTY_FUNCTION__ )); | |||
946 | intVal = floatAttr.getValue().bitcastToAPInt(); | |||
947 | } else { | |||
948 | auto intAttr = values[i].cast<IntegerAttr>(); | |||
949 | assert(intAttr.getType() == eltType &&(static_cast <bool> (intAttr.getType() == eltType && "expected integer attribute type to equal element type") ? void (0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 950, __extension__ __PRETTY_FUNCTION__ )) | |||
950 | "expected integer attribute type to equal element type")(static_cast <bool> (intAttr.getType() == eltType && "expected integer attribute type to equal element type") ? void (0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 950, __extension__ __PRETTY_FUNCTION__ )); | |||
951 | intVal = intAttr.getValue(); | |||
952 | } | |||
953 | ||||
954 | assert(intVal.getBitWidth() == bitWidth &&(static_cast <bool> (intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type") ? void (0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 955, __extension__ __PRETTY_FUNCTION__ )) | |||
955 | "expected value to have same bitwidth as element type")(static_cast <bool> (intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type") ? void (0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 955, __extension__ __PRETTY_FUNCTION__ )); | |||
956 | writeBits(data.data(), i * storageBitWidth, intVal); | |||
957 | } | |||
958 | ||||
959 | // Handle the special encoding of splat of bool. | |||
960 | if (values.size() == 1 && eltType.isInteger(1)) | |||
961 | data[0] = data[0] ? -1 : 0; | |||
962 | ||||
963 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
964 | } | |||
965 | ||||
966 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
967 | ArrayRef<bool> values) { | |||
968 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 968, __extension__ __PRETTY_FUNCTION__ )); | |||
969 | assert(type.getElementType().isInteger(1))(static_cast <bool> (type.getElementType().isInteger(1) ) ? void (0) : __assert_fail ("type.getElementType().isInteger(1)" , "mlir/lib/IR/BuiltinAttributes.cpp", 969, __extension__ __PRETTY_FUNCTION__ )); | |||
970 | ||||
971 | std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT8)); | |||
972 | ||||
973 | if (!values.empty()) { | |||
974 | bool isSplat = true; | |||
975 | bool firstValue = values[0]; | |||
976 | for (int i = 0, e = values.size(); i != e; ++i) { | |||
977 | isSplat &= values[i] == firstValue; | |||
978 | setBit(buff.data(), i, values[i]); | |||
979 | } | |||
980 | ||||
981 | // Splat of bool is encoded as a byte with all-ones in it. | |||
982 | if (isSplat) { | |||
983 | buff.resize(1); | |||
984 | buff[0] = values[0] ? -1 : 0; | |||
985 | } | |||
986 | } | |||
987 | ||||
988 | return DenseIntOrFPElementsAttr::getRaw(type, buff); | |||
989 | } | |||
990 | ||||
991 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
992 | ArrayRef<StringRef> values) { | |||
993 | assert(!type.getElementType().isIntOrFloat())(static_cast <bool> (!type.getElementType().isIntOrFloat ()) ? void (0) : __assert_fail ("!type.getElementType().isIntOrFloat()" , "mlir/lib/IR/BuiltinAttributes.cpp", 993, __extension__ __PRETTY_FUNCTION__ )); | |||
994 | return DenseStringElementsAttr::get(type, values); | |||
995 | } | |||
996 | ||||
997 | /// Constructs a dense integer elements attribute from an array of APInt | |||
998 | /// values. Each APInt value is expected to have the same bitwidth as the | |||
999 | /// element type of 'type'. | |||
1000 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
1001 | ArrayRef<APInt> values) { | |||
1002 | assert(type.getElementType().isIntOrIndex())(static_cast <bool> (type.getElementType().isIntOrIndex ()) ? void (0) : __assert_fail ("type.getElementType().isIntOrIndex()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1002, __extension__ __PRETTY_FUNCTION__ )); | |||
1003 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1003, __extension__ __PRETTY_FUNCTION__ )); | |||
1004 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); | |||
1005 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); | |||
1006 | } | |||
1007 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
1008 | ArrayRef<std::complex<APInt>> values) { | |||
1009 | ComplexType complex = type.getElementType().cast<ComplexType>(); | |||
1010 | assert(complex.getElementType().isa<IntegerType>())(static_cast <bool> (complex.getElementType().isa<IntegerType >()) ? void (0) : __assert_fail ("complex.getElementType().isa<IntegerType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1010, __extension__ __PRETTY_FUNCTION__ )); | |||
1011 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1011, __extension__ __PRETTY_FUNCTION__ )); | |||
1012 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; | |||
1013 | ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), | |||
1014 | values.size() * 2); | |||
1015 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); | |||
1016 | } | |||
1017 | ||||
1018 | // Constructs a dense float elements attribute from an array of APFloat | |||
1019 | // values. Each APFloat value is expected to have the same bitwidth as the | |||
1020 | // element type of 'type'. | |||
1021 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
1022 | ArrayRef<APFloat> values) { | |||
1023 | assert(type.getElementType().isa<FloatType>())(static_cast <bool> (type.getElementType().isa<FloatType >()) ? void (0) : __assert_fail ("type.getElementType().isa<FloatType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1023, __extension__ __PRETTY_FUNCTION__ )); | |||
1024 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1024, __extension__ __PRETTY_FUNCTION__ )); | |||
1025 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); | |||
1026 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); | |||
1027 | } | |||
1028 | DenseElementsAttr | |||
1029 | DenseElementsAttr::get(ShapedType type, | |||
1030 | ArrayRef<std::complex<APFloat>> values) { | |||
1031 | ComplexType complex = type.getElementType().cast<ComplexType>(); | |||
1032 | assert(complex.getElementType().isa<FloatType>())(static_cast <bool> (complex.getElementType().isa<FloatType >()) ? void (0) : __assert_fail ("complex.getElementType().isa<FloatType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1032, __extension__ __PRETTY_FUNCTION__ )); | |||
1033 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1033, __extension__ __PRETTY_FUNCTION__ )); | |||
1034 | ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), | |||
1035 | values.size() * 2); | |||
1036 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; | |||
1037 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); | |||
1038 | } | |||
1039 | ||||
1040 | /// Construct a dense elements attribute from a raw buffer representing the | |||
1041 | /// data for this attribute. Users should generally not use this methods as | |||
1042 | /// the expected buffer format may not be a form the user expects. | |||
1043 | DenseElementsAttr | |||
1044 | DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { | |||
1045 | return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); | |||
1046 | } | |||
1047 | ||||
1048 | /// Returns true if the given buffer is a valid raw buffer for the given type. | |||
1049 | bool DenseElementsAttr::isValidRawBuffer(ShapedType type, | |||
1050 | ArrayRef<char> rawBuffer, | |||
1051 | bool &detectedSplat) { | |||
1052 | size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); | |||
1053 | size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT8; | |||
1054 | int64_t numElements = type.getNumElements(); | |||
1055 | ||||
1056 | // The initializer is always a splat if the result type has a single element. | |||
1057 | detectedSplat = numElements == 1; | |||
1058 | ||||
1059 | // Storage width of 1 is special as it is packed by the bit. | |||
1060 | if (storageWidth == 1) { | |||
1061 | // Check for a splat, or a buffer equal to the number of elements which | |||
1062 | // consists of either all 0's or all 1's. | |||
1063 | if (rawBuffer.size() == 1) { | |||
1064 | auto rawByte = static_cast<uint8_t>(rawBuffer[0]); | |||
1065 | if (rawByte == 0 || rawByte == 0xff) { | |||
1066 | detectedSplat = true; | |||
1067 | return true; | |||
1068 | } | |||
1069 | } | |||
1070 | ||||
1071 | // This is a valid non-splat buffer if it has the right size. | |||
1072 | return rawBufferWidth == llvm::alignTo<8>(numElements); | |||
1073 | } | |||
1074 | ||||
1075 | // All other types are 8-bit aligned, so we can just check the buffer width | |||
1076 | // to know if only a single initializer element was passed in. | |||
1077 | if (rawBufferWidth == storageWidth) { | |||
1078 | detectedSplat = true; | |||
1079 | return true; | |||
1080 | } | |||
1081 | ||||
1082 | // The raw buffer is valid if it has the right size. | |||
1083 | return rawBufferWidth == storageWidth * numElements; | |||
1084 | } | |||
1085 | ||||
1086 | /// Check the information for a C++ data type, check if this type is valid for | |||
1087 | /// the current attribute. This method is used to verify specific type | |||
1088 | /// invariants that the templatized 'getValues' method cannot. | |||
1089 | static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, | |||
1090 | bool isSigned) { | |||
1091 | // Make sure that the data element size is the same as the type element width. | |||
1092 | if (getDenseElementBitWidth(type) != | |||
1093 | static_cast<size_t>(dataEltSize * CHAR_BIT8)) | |||
1094 | return false; | |||
1095 | ||||
1096 | // Check that the element type is either float or integer or index. | |||
1097 | if (!isInt) | |||
1098 | return type.isa<FloatType>(); | |||
1099 | if (type.isIndex()) | |||
1100 | return true; | |||
1101 | ||||
1102 | auto intType = type.dyn_cast<IntegerType>(); | |||
1103 | if (!intType) | |||
1104 | return false; | |||
1105 | ||||
1106 | // Make sure signedness semantics is consistent. | |||
1107 | if (intType.isSignless()) | |||
1108 | return true; | |||
1109 | return intType.isSigned() ? isSigned : !isSigned; | |||
1110 | } | |||
1111 | ||||
1112 | /// Defaults down the subclass implementation. | |||
1113 | DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, | |||
1114 | ArrayRef<char> data, | |||
1115 | int64_t dataEltSize, | |||
1116 | bool isInt, bool isSigned) { | |||
1117 | return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, | |||
1118 | isSigned); | |||
1119 | } | |||
1120 | DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, | |||
1121 | ArrayRef<char> data, | |||
1122 | int64_t dataEltSize, | |||
1123 | bool isInt, | |||
1124 | bool isSigned) { | |||
1125 | return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, | |||
1126 | isInt, isSigned); | |||
1127 | } | |||
1128 | ||||
1129 | bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, | |||
1130 | bool isSigned) const { | |||
1131 | return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); | |||
1132 | } | |||
1133 | bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, | |||
1134 | bool isSigned) const { | |||
1135 | return ::isValidIntOrFloat( | |||
1136 | getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, | |||
1137 | isInt, isSigned); | |||
1138 | } | |||
1139 | ||||
1140 | /// Returns true if this attribute corresponds to a splat, i.e. if all element | |||
1141 | /// values are the same. | |||
1142 | bool DenseElementsAttr::isSplat() const { | |||
1143 | return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; | |||
1144 | } | |||
1145 | ||||
1146 | /// Return if the given complex type has an integer element type. | |||
1147 | LLVM_ATTRIBUTE_UNUSED__attribute__((__unused__)) static bool isComplexOfIntType(Type type) { | |||
1148 | return type.cast<ComplexType>().getElementType().isa<IntegerType>(); | |||
1149 | } | |||
1150 | ||||
1151 | auto DenseElementsAttr::getComplexIntValues() const | |||
1152 | -> iterator_range_impl<ComplexIntElementIterator> { | |||
1153 | assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1154, __extension__ __PRETTY_FUNCTION__ )) | |||
1154 | "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1154, __extension__ __PRETTY_FUNCTION__ )); | |||
1155 | return {getType(), ComplexIntElementIterator(*this, 0), | |||
1156 | ComplexIntElementIterator(*this, getNumElements())}; | |||
1157 | } | |||
1158 | auto DenseElementsAttr::complex_value_begin() const | |||
1159 | -> ComplexIntElementIterator { | |||
1160 | assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1161, __extension__ __PRETTY_FUNCTION__ )) | |||
1161 | "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1161, __extension__ __PRETTY_FUNCTION__ )); | |||
1162 | return ComplexIntElementIterator(*this, 0); | |||
1163 | } | |||
1164 | auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator { | |||
1165 | assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1166, __extension__ __PRETTY_FUNCTION__ )) | |||
1166 | "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType( )) && "expected complex integral type") ? void (0) : __assert_fail ("isComplexOfIntType(getElementType()) && \"expected complex integral type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1166, __extension__ __PRETTY_FUNCTION__ )); | |||
1167 | return ComplexIntElementIterator(*this, getNumElements()); | |||
1168 | } | |||
1169 | ||||
1170 | /// Return the held element values as a range of APFloat. The element type of | |||
1171 | /// this attribute must be of float type. | |||
1172 | auto DenseElementsAttr::getFloatValues() const | |||
1173 | -> iterator_range_impl<FloatElementIterator> { | |||
1174 | auto elementType = getElementType().cast<FloatType>(); | |||
1175 | const auto &elementSemantics = elementType.getFloatSemantics(); | |||
1176 | return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()), | |||
1177 | FloatElementIterator(elementSemantics, raw_int_end())}; | |||
1178 | } | |||
1179 | auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator { | |||
1180 | auto elementType = getElementType().cast<FloatType>(); | |||
1181 | return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin()); | |||
1182 | } | |||
1183 | auto DenseElementsAttr::float_value_end() const -> FloatElementIterator { | |||
1184 | auto elementType = getElementType().cast<FloatType>(); | |||
1185 | return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end()); | |||
1186 | } | |||
1187 | ||||
1188 | auto DenseElementsAttr::getComplexFloatValues() const | |||
1189 | -> iterator_range_impl<ComplexFloatElementIterator> { | |||
1190 | Type eltTy = getElementType().cast<ComplexType>().getElementType(); | |||
1191 | assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() && "expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1191, __extension__ __PRETTY_FUNCTION__ )); | |||
1192 | const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics(); | |||
1193 | return {getType(), | |||
1194 | {semantics, {*this, 0}}, | |||
1195 | {semantics, {*this, static_cast<size_t>(getNumElements())}}}; | |||
1196 | } | |||
1197 | auto DenseElementsAttr::complex_float_value_begin() const | |||
1198 | -> ComplexFloatElementIterator { | |||
1199 | Type eltTy = getElementType().cast<ComplexType>().getElementType(); | |||
1200 | assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() && "expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1200, __extension__ __PRETTY_FUNCTION__ )); | |||
1201 | return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}}; | |||
1202 | } | |||
1203 | auto DenseElementsAttr::complex_float_value_end() const | |||
1204 | -> ComplexFloatElementIterator { | |||
1205 | Type eltTy = getElementType().cast<ComplexType>().getElementType(); | |||
1206 | assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() && "expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__ )); | |||
1207 | return {eltTy.cast<FloatType>().getFloatSemantics(), | |||
1208 | {*this, static_cast<size_t>(getNumElements())}}; | |||
1209 | } | |||
1210 | ||||
1211 | /// Return the raw storage data held by this attribute. | |||
1212 | ArrayRef<char> DenseElementsAttr::getRawData() const { | |||
1213 | return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; | |||
1214 | } | |||
1215 | ||||
1216 | ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { | |||
1217 | return static_cast<DenseStringElementsAttrStorage *>(impl)->data; | |||
1218 | } | |||
1219 | ||||
1220 | /// Return a new DenseElementsAttr that has the same data as the current | |||
1221 | /// attribute, but has been reshaped to 'newType'. The new type must have the | |||
1222 | /// same total number of elements as well as element type. | |||
1223 | DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { | |||
1224 | ShapedType curType = getType(); | |||
1225 | if (curType == newType) | |||
1226 | return *this; | |||
1227 | ||||
1228 | assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1229, __extension__ __PRETTY_FUNCTION__ )) | |||
1229 | "expected the same element type")(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1229, __extension__ __PRETTY_FUNCTION__ )); | |||
1230 | assert(newType.getNumElements() == curType.getNumElements() &&(static_cast <bool> (newType.getNumElements() == curType .getNumElements() && "expected the same number of elements" ) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1231, __extension__ __PRETTY_FUNCTION__ )) | |||
1231 | "expected the same number of elements")(static_cast <bool> (newType.getNumElements() == curType .getNumElements() && "expected the same number of elements" ) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1231, __extension__ __PRETTY_FUNCTION__ )); | |||
1232 | return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); | |||
1233 | } | |||
1234 | ||||
1235 | DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { | |||
1236 | assert(isSplat() && "expected a splat type")(static_cast <bool> (isSplat() && "expected a splat type" ) ? void (0) : __assert_fail ("isSplat() && \"expected a splat type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1236, __extension__ __PRETTY_FUNCTION__ )); | |||
1237 | ||||
1238 | ShapedType curType = getType(); | |||
1239 | if (curType == newType) | |||
1240 | return *this; | |||
1241 | ||||
1242 | assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1243, __extension__ __PRETTY_FUNCTION__ )) | |||
1243 | "expected the same element type")(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1243, __extension__ __PRETTY_FUNCTION__ )); | |||
1244 | return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); | |||
1245 | } | |||
1246 | ||||
1247 | /// Return a new DenseElementsAttr that has the same data as the current | |||
1248 | /// attribute, but has bitcast elements such that it is now 'newType'. The new | |||
1249 | /// type must have the same shape and element types of the same bitwidth as the | |||
1250 | /// current type. | |||
1251 | DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { | |||
1252 | ShapedType curType = getType(); | |||
1253 | Type curElType = curType.getElementType(); | |||
1254 | if (curElType == newElType) | |||
1255 | return *this; | |||
1256 | ||||
1257 | assert(getDenseElementBitWidth(newElType) ==(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__ )) | |||
1258 | getDenseElementBitWidth(curElType) &&(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__ )) | |||
1259 | "expected element types with the same bitwidth")(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__ )); | |||
1260 | return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), | |||
1261 | getRawData()); | |||
1262 | } | |||
1263 | ||||
1264 | DenseElementsAttr | |||
1265 | DenseElementsAttr::mapValues(Type newElementType, | |||
1266 | function_ref<APInt(const APInt &)> mapping) const { | |||
1267 | return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); | |||
1268 | } | |||
1269 | ||||
1270 | DenseElementsAttr DenseElementsAttr::mapValues( | |||
1271 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { | |||
1272 | return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); | |||
1273 | } | |||
1274 | ||||
1275 | ShapedType DenseElementsAttr::getType() const { | |||
1276 | return static_cast<const DenseElementsAttributeStorage *>(impl)->type; | |||
1277 | } | |||
1278 | ||||
1279 | Type DenseElementsAttr::getElementType() const { | |||
1280 | return getType().getElementType(); | |||
1281 | } | |||
1282 | ||||
1283 | int64_t DenseElementsAttr::getNumElements() const { | |||
1284 | return getType().getNumElements(); | |||
1285 | } | |||
1286 | ||||
1287 | //===----------------------------------------------------------------------===// | |||
1288 | // DenseIntOrFPElementsAttr | |||
1289 | //===----------------------------------------------------------------------===// | |||
1290 | ||||
1291 | /// Utility method to write a range of APInt values to a buffer. | |||
1292 | template <typename APRangeT> | |||
1293 | static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, | |||
1294 | APRangeT &&values) { | |||
1295 | size_t numValues = llvm::size(values); | |||
1296 | data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT8)); | |||
1297 | size_t offset = 0; | |||
1298 | for (auto it = values.begin(), e = values.end(); it != e; | |||
1299 | ++it, offset += storageWidth) { | |||
1300 | assert((*it).getBitWidth() <= storageWidth)(static_cast <bool> ((*it).getBitWidth() <= storageWidth ) ? void (0) : __assert_fail ("(*it).getBitWidth() <= storageWidth" , "mlir/lib/IR/BuiltinAttributes.cpp", 1300, __extension__ __PRETTY_FUNCTION__ )); | |||
1301 | writeBits(data.data(), offset, *it); | |||
1302 | } | |||
1303 | ||||
1304 | // Handle the special encoding of splat of a boolean. | |||
1305 | if (numValues == 1 && (*values.begin()).getBitWidth() == 1) | |||
1306 | data[0] = data[0] ? -1 : 0; | |||
1307 | } | |||
1308 | ||||
1309 | /// Constructs a dense elements attribute from an array of raw APFloat values. | |||
1310 | /// Each APFloat value is expected to have the same bitwidth as the element | |||
1311 | /// type of 'type'. 'type' must be a vector or tensor with static shape. | |||
1312 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1313 | size_t storageWidth, | |||
1314 | ArrayRef<APFloat> values) { | |||
1315 | std::vector<char> data; | |||
1316 | auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; | |||
1317 | writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); | |||
1318 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
1319 | } | |||
1320 | ||||
1321 | /// Constructs a dense elements attribute from an array of raw APInt values. | |||
1322 | /// Each APInt value is expected to have the same bitwidth as the element type | |||
1323 | /// of 'type'. | |||
1324 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1325 | size_t storageWidth, | |||
1326 | ArrayRef<APInt> values) { | |||
1327 | std::vector<char> data; | |||
1328 | writeAPIntsToBuffer(storageWidth, data, values); | |||
1329 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
1330 | } | |||
1331 | ||||
1332 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1333 | ArrayRef<char> data) { | |||
1334 | assert((type.isa<RankedTensorType, VectorType>()) &&(static_cast <bool> ((type.isa<RankedTensorType, VectorType >()) && "type must be ranked tensor or vector") ? void (0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1335, __extension__ __PRETTY_FUNCTION__ )) | |||
1335 | "type must be ranked tensor or vector")(static_cast <bool> ((type.isa<RankedTensorType, VectorType >()) && "type must be ranked tensor or vector") ? void (0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1335, __extension__ __PRETTY_FUNCTION__ )); | |||
1336 | assert(type.hasStaticShape() && "type must have static shape")(static_cast <bool> (type.hasStaticShape() && "type must have static shape" ) ? void (0) : __assert_fail ("type.hasStaticShape() && \"type must have static shape\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1336, __extension__ __PRETTY_FUNCTION__ )); | |||
1337 | bool isSplat = false; | |||
1338 | bool isValid = isValidRawBuffer(type, data, isSplat); | |||
1339 | assert(isValid)(static_cast <bool> (isValid) ? void (0) : __assert_fail ("isValid", "mlir/lib/IR/BuiltinAttributes.cpp", 1339, __extension__ __PRETTY_FUNCTION__)); | |||
1340 | (void)isValid; | |||
1341 | return Base::get(type.getContext(), type, data, isSplat); | |||
1342 | } | |||
1343 | ||||
1344 | /// Overload of the raw 'get' method that asserts that the given type is of | |||
1345 | /// complex type. This method is used to verify type invariants that the | |||
1346 | /// templatized 'get' method cannot. | |||
1347 | DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, | |||
1348 | ArrayRef<char> data, | |||
1349 | int64_t dataEltSize, | |||
1350 | bool isInt, | |||
1351 | bool isSigned) { | |||
1352 | assert(::isValidIntOrFloat((static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__ )) | |||
1353 | type.getElementType().cast<ComplexType>().getElementType(),(static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__ )) | |||
1354 | dataEltSize / 2, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__ )); | |||
1355 | ||||
1356 | int64_t numElements = data.size() / dataEltSize; | |||
1357 | (void)numElements; | |||
1358 | assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements == type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1358, __extension__ __PRETTY_FUNCTION__ )); | |||
1359 | return getRaw(type, data); | |||
1360 | } | |||
1361 | ||||
1362 | /// Overload of the 'getRaw' method that asserts that the given type is of | |||
1363 | /// integer type. This method is used to verify type invariants that the | |||
1364 | /// templatized 'get' method cannot. | |||
1365 | DenseElementsAttr | |||
1366 | DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, | |||
1367 | int64_t dataEltSize, bool isInt, | |||
1368 | bool isSigned) { | |||
1369 | assert((static_cast <bool> (::isValidIntOrFloat(type.getElementType (), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1370, __extension__ __PRETTY_FUNCTION__ )) | |||
1370 | ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat(type.getElementType (), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1370, __extension__ __PRETTY_FUNCTION__ )); | |||
1371 | ||||
1372 | int64_t numElements = data.size() / dataEltSize; | |||
1373 | assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements == type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1373, __extension__ __PRETTY_FUNCTION__ )); | |||
1374 | (void)numElements; | |||
1375 | return getRaw(type, data); | |||
1376 | } | |||
1377 | ||||
1378 | void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
1379 | const char *inRawData, char *outRawData, size_t elementBitWidth, | |||
1380 | size_t numElements) { | |||
1381 | using llvm::support::ulittle16_t; | |||
1382 | using llvm::support::ulittle32_t; | |||
1383 | using llvm::support::ulittle64_t; | |||
1384 | ||||
1385 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 1386, __extension__ __PRETTY_FUNCTION__ )) | |||
1386 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 1386, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
1387 | // NOLINT to avoid warning message about replacing by static_assert() | |||
1388 | ||||
1389 | // Following std::copy_n always converts endianness on BE machine. | |||
1390 | switch (elementBitWidth) { | |||
1391 | case 16: { | |||
1392 | const ulittle16_t *inRawDataPos = | |||
1393 | reinterpret_cast<const ulittle16_t *>(inRawData); | |||
1394 | uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); | |||
1395 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1396 | break; | |||
1397 | } | |||
1398 | case 32: { | |||
1399 | const ulittle32_t *inRawDataPos = | |||
1400 | reinterpret_cast<const ulittle32_t *>(inRawData); | |||
1401 | uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); | |||
1402 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1403 | break; | |||
1404 | } | |||
1405 | case 64: { | |||
1406 | const ulittle64_t *inRawDataPos = | |||
1407 | reinterpret_cast<const ulittle64_t *>(inRawData); | |||
1408 | uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); | |||
1409 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1410 | break; | |||
1411 | } | |||
1412 | default: { | |||
1413 | size_t nBytes = elementBitWidth / CHAR_BIT8; | |||
1414 | for (size_t i = 0; i < nBytes; i++) | |||
1415 | std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); | |||
1416 | break; | |||
1417 | } | |||
1418 | } | |||
1419 | } | |||
1420 | ||||
1421 | void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( | |||
1422 | ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, | |||
1423 | ShapedType type) { | |||
1424 | size_t numElements = type.getNumElements(); | |||
1425 | Type elementType = type.getElementType(); | |||
1426 | if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { | |||
1427 | elementType = complexTy.getElementType(); | |||
1428 | numElements = numElements * 2; | |||
1429 | } | |||
1430 | size_t elementBitWidth = getDenseElementStorageWidth(elementType); | |||
1431 | assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&(static_cast <bool> (numElements * elementBitWidth == inRawData .size() * 8 && inRawData.size() <= outRawData.size ()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1432, __extension__ __PRETTY_FUNCTION__ )) | |||
1432 | inRawData.size() <= outRawData.size())(static_cast <bool> (numElements * elementBitWidth == inRawData .size() * 8 && inRawData.size() <= outRawData.size ()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1432, __extension__ __PRETTY_FUNCTION__ )); | |||
1433 | if (elementBitWidth <= CHAR_BIT8) | |||
1434 | std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); | |||
1435 | else | |||
1436 | convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), | |||
1437 | elementBitWidth, numElements); | |||
1438 | } | |||
1439 | ||||
1440 | //===----------------------------------------------------------------------===// | |||
1441 | // DenseFPElementsAttr | |||
1442 | //===----------------------------------------------------------------------===// | |||
1443 | ||||
1444 | template <typename Fn, typename Attr> | |||
1445 | static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, | |||
1446 | Type newElementType, | |||
1447 | llvm::SmallVectorImpl<char> &data) { | |||
1448 | size_t bitWidth = getDenseElementBitWidth(newElementType); | |||
1449 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); | |||
1450 | ||||
1451 | ShapedType newArrayType; | |||
1452 | if (inType.isa<RankedTensorType>()) | |||
1453 | newArrayType = RankedTensorType::get(inType.getShape(), newElementType); | |||
1454 | else if (inType.isa<UnrankedTensorType>()) | |||
1455 | newArrayType = RankedTensorType::get(inType.getShape(), newElementType); | |||
1456 | else if (auto vType = inType.dyn_cast<VectorType>()) | |||
1457 | newArrayType = VectorType::get(vType.getShape(), newElementType, | |||
1458 | vType.getNumScalableDims()); | |||
1459 | else | |||
1460 | assert(newArrayType && "Unhandled tensor type")(static_cast <bool> (newArrayType && "Unhandled tensor type" ) ? void (0) : __assert_fail ("newArrayType && \"Unhandled tensor type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1460, __extension__ __PRETTY_FUNCTION__ )); | |||
1461 | ||||
1462 | size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); | |||
1463 | data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT8)); | |||
1464 | ||||
1465 | // Functor used to process a single element value of the attribute. | |||
1466 | auto processElt = [&](decltype(*attr.begin()) value, size_t index) { | |||
1467 | auto newInt = mapping(value); | |||
1468 | assert(newInt.getBitWidth() == bitWidth)(static_cast <bool> (newInt.getBitWidth() == bitWidth) ? void (0) : __assert_fail ("newInt.getBitWidth() == bitWidth" , "mlir/lib/IR/BuiltinAttributes.cpp", 1468, __extension__ __PRETTY_FUNCTION__ )); | |||
1469 | writeBits(data.data(), index * storageBitWidth, newInt); | |||
1470 | }; | |||
1471 | ||||
1472 | // Check for the splat case. | |||
1473 | if (attr.isSplat()) { | |||
1474 | processElt(*attr.begin(), /*index=*/0); | |||
1475 | return newArrayType; | |||
1476 | } | |||
1477 | ||||
1478 | // Otherwise, process all of the element values. | |||
1479 | uint64_t elementIdx = 0; | |||
1480 | for (auto value : attr) | |||
1481 | processElt(value, elementIdx++); | |||
1482 | return newArrayType; | |||
1483 | } | |||
1484 | ||||
1485 | DenseElementsAttr DenseFPElementsAttr::mapValues( | |||
1486 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { | |||
1487 | llvm::SmallVector<char, 8> elementData; | |||
1488 | auto newArrayType = | |||
1489 | mappingHelper(mapping, *this, getType(), newElementType, elementData); | |||
1490 | ||||
1491 | return getRaw(newArrayType, elementData); | |||
1492 | } | |||
1493 | ||||
1494 | /// Method for supporting type inquiry through isa, cast and dyn_cast. | |||
1495 | bool DenseFPElementsAttr::classof(Attribute attr) { | |||
1496 | if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) | |||
1497 | return denseAttr.getType().getElementType().isa<FloatType>(); | |||
1498 | return false; | |||
1499 | } | |||
1500 | ||||
1501 | //===----------------------------------------------------------------------===// | |||
1502 | // DenseIntElementsAttr | |||
1503 | //===----------------------------------------------------------------------===// | |||
1504 | ||||
1505 | DenseElementsAttr DenseIntElementsAttr::mapValues( | |||
1506 | Type newElementType, function_ref<APInt(const APInt &)> mapping) const { | |||
1507 | llvm::SmallVector<char, 8> elementData; | |||
1508 | auto newArrayType = | |||
1509 | mappingHelper(mapping, *this, getType(), newElementType, elementData); | |||
1510 | return getRaw(newArrayType, elementData); | |||
1511 | } | |||
1512 | ||||
1513 | /// Method for supporting type inquiry through isa, cast and dyn_cast. | |||
1514 | bool DenseIntElementsAttr::classof(Attribute attr) { | |||
1515 | if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) | |||
1516 | return denseAttr.getType().getElementType().isIntOrIndex(); | |||
1517 | return false; | |||
1518 | } | |||
1519 | ||||
1520 | //===----------------------------------------------------------------------===// | |||
1521 | // DenseResourceElementsAttr | |||
1522 | //===----------------------------------------------------------------------===// | |||
1523 | ||||
1524 | DenseResourceElementsAttr | |||
1525 | DenseResourceElementsAttr::get(ShapedType type, | |||
1526 | DenseResourceElementsHandle handle) { | |||
1527 | return Base::get(type.getContext(), type, handle); | |||
1528 | } | |||
1529 | ||||
1530 | DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, | |||
1531 | StringRef blobName, | |||
1532 | AsmResourceBlob blob) { | |||
1533 | // Extract the builtin dialect resource manager from context and construct a | |||
1534 | // handle by inserting a new resource using the provided blob. | |||
1535 | auto &manager = | |||
1536 | DenseResourceElementsHandle::getManagerInterface(type.getContext()); | |||
1537 | return get(type, manager.insert(blobName, std::move(blob))); | |||
1538 | } | |||
1539 | ||||
1540 | //===----------------------------------------------------------------------===// | |||
1541 | // DenseResourceElementsAttrBase | |||
1542 | ||||
1543 | namespace { | |||
1544 | /// Instantiations of this class provide utilities for interacting with native | |||
1545 | /// data types in the context of DenseResourceElementsAttr. | |||
1546 | template <typename T> | |||
1547 | struct DenseResourceAttrUtil; | |||
1548 | template <size_t width, bool isSigned> | |||
1549 | struct DenseResourceElementsAttrIntUtil { | |||
1550 | static bool checkElementType(Type eltType) { | |||
1551 | IntegerType type = eltType.dyn_cast<IntegerType>(); | |||
1552 | if (!type || type.getWidth() != width) | |||
1553 | return false; | |||
1554 | return isSigned ? !type.isUnsigned() : !type.isSigned(); | |||
1555 | } | |||
1556 | }; | |||
1557 | template <> | |||
1558 | struct DenseResourceAttrUtil<bool> { | |||
1559 | static bool checkElementType(Type eltType) { | |||
1560 | return eltType.isSignlessInteger(1); | |||
1561 | } | |||
1562 | }; | |||
1563 | template <> | |||
1564 | struct DenseResourceAttrUtil<int8_t> | |||
1565 | : public DenseResourceElementsAttrIntUtil<8, true> {}; | |||
1566 | template <> | |||
1567 | struct DenseResourceAttrUtil<uint8_t> | |||
1568 | : public DenseResourceElementsAttrIntUtil<8, false> {}; | |||
1569 | template <> | |||
1570 | struct DenseResourceAttrUtil<int16_t> | |||
1571 | : public DenseResourceElementsAttrIntUtil<16, true> {}; | |||
1572 | template <> | |||
1573 | struct DenseResourceAttrUtil<uint16_t> | |||
1574 | : public DenseResourceElementsAttrIntUtil<16, false> {}; | |||
1575 | template <> | |||
1576 | struct DenseResourceAttrUtil<int32_t> | |||
1577 | : public DenseResourceElementsAttrIntUtil<32, true> {}; | |||
1578 | template <> | |||
1579 | struct DenseResourceAttrUtil<uint32_t> | |||
1580 | : public DenseResourceElementsAttrIntUtil<32, false> {}; | |||
1581 | template <> | |||
1582 | struct DenseResourceAttrUtil<int64_t> | |||
1583 | : public DenseResourceElementsAttrIntUtil<64, true> {}; | |||
1584 | template <> | |||
1585 | struct DenseResourceAttrUtil<uint64_t> | |||
1586 | : public DenseResourceElementsAttrIntUtil<64, false> {}; | |||
1587 | template <> | |||
1588 | struct DenseResourceAttrUtil<float> { | |||
1589 | static bool checkElementType(Type eltType) { return eltType.isF32(); } | |||
1590 | }; | |||
1591 | template <> | |||
1592 | struct DenseResourceAttrUtil<double> { | |||
1593 | static bool checkElementType(Type eltType) { return eltType.isF64(); } | |||
1594 | }; | |||
1595 | } // namespace | |||
1596 | ||||
1597 | template <typename T> | |||
1598 | DenseResourceElementsAttrBase<T> | |||
1599 | DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName, | |||
1600 | AsmResourceBlob blob) { | |||
1601 | // Check that the blob is in the form we were expecting. | |||
1602 | assert(blob.getDataAlignment() == alignof(T) &&(static_cast <bool> (blob.getDataAlignment() == alignof (T) && "alignment mismatch between expected alignment and blob alignment" ) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1603, __extension__ __PRETTY_FUNCTION__ )) | |||
1603 | "alignment mismatch between expected alignment and blob alignment")(static_cast <bool> (blob.getDataAlignment() == alignof (T) && "alignment mismatch between expected alignment and blob alignment" ) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1603, __extension__ __PRETTY_FUNCTION__ )); | |||
1604 | assert(((blob.getData().size() % sizeof(T)) == 0) &&(static_cast <bool> (((blob.getData().size() % sizeof(T )) == 0) && "size mismatch between expected element width and blob size" ) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1605, __extension__ __PRETTY_FUNCTION__ )) | |||
1605 | "size mismatch between expected element width and blob size")(static_cast <bool> (((blob.getData().size() % sizeof(T )) == 0) && "size mismatch between expected element width and blob size" ) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1605, __extension__ __PRETTY_FUNCTION__ )); | |||
1606 | assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType (type.getElementType()) && "invalid shape element type for provided type `T`" ) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1607, __extension__ __PRETTY_FUNCTION__ )) | |||
1607 | "invalid shape element type for provided type `T`")(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType (type.getElementType()) && "invalid shape element type for provided type `T`" ) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1607, __extension__ __PRETTY_FUNCTION__ )); | |||
1608 | return DenseResourceElementsAttr::get(type, blobName, std::move(blob)) | |||
1609 | .template cast<DenseResourceElementsAttrBase<T>>(); | |||
1610 | } | |||
1611 | ||||
1612 | template <typename T> | |||
1613 | Optional<ArrayRef<T>> | |||
1614 | DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const { | |||
1615 | if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) | |||
1616 | return blob->template getDataAs<T>(); | |||
1617 | return llvm::None; | |||
1618 | } | |||
1619 | ||||
1620 | template <typename T> | |||
1621 | bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) { | |||
1622 | auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>(); | |||
1623 | return resourceAttr && DenseResourceAttrUtil<T>::checkElementType( | |||
1624 | resourceAttr.getElementType()); | |||
1625 | } | |||
1626 | ||||
1627 | namespace mlir { | |||
1628 | namespace detail { | |||
1629 | // Explicit instantiation for all the supported DenseResourceElementsAttr. | |||
1630 | template class DenseResourceElementsAttrBase<bool>; | |||
1631 | template class DenseResourceElementsAttrBase<int8_t>; | |||
1632 | template class DenseResourceElementsAttrBase<int16_t>; | |||
1633 | template class DenseResourceElementsAttrBase<int32_t>; | |||
1634 | template class DenseResourceElementsAttrBase<int64_t>; | |||
1635 | template class DenseResourceElementsAttrBase<uint8_t>; | |||
1636 | template class DenseResourceElementsAttrBase<uint16_t>; | |||
1637 | template class DenseResourceElementsAttrBase<uint32_t>; | |||
1638 | template class DenseResourceElementsAttrBase<uint64_t>; | |||
1639 | template class DenseResourceElementsAttrBase<float>; | |||
1640 | template class DenseResourceElementsAttrBase<double>; | |||
1641 | } // namespace detail | |||
1642 | } // namespace mlir | |||
1643 | ||||
1644 | //===----------------------------------------------------------------------===// | |||
1645 | // SparseElementsAttr | |||
1646 | //===----------------------------------------------------------------------===// | |||
1647 | ||||
1648 | /// Get a zero APFloat for the given sparse attribute. | |||
1649 | APFloat SparseElementsAttr::getZeroAPFloat() const { | |||
1650 | auto eltType = getElementType().cast<FloatType>(); | |||
1651 | return APFloat(eltType.getFloatSemantics()); | |||
1652 | } | |||
1653 | ||||
1654 | /// Get a zero APInt for the given sparse attribute. | |||
1655 | APInt SparseElementsAttr::getZeroAPInt() const { | |||
1656 | auto eltType = getElementType().cast<IntegerType>(); | |||
1657 | return APInt::getZero(eltType.getWidth()); | |||
1658 | } | |||
1659 | ||||
1660 | /// Get a zero attribute for the given attribute type. | |||
1661 | Attribute SparseElementsAttr::getZeroAttr() const { | |||
1662 | auto eltType = getElementType(); | |||
1663 | ||||
1664 | // Handle floating point elements. | |||
1665 | if (eltType.isa<FloatType>()) | |||
1666 | return FloatAttr::get(eltType, 0); | |||
1667 | ||||
1668 | // Handle complex elements. | |||
1669 | if (auto complexTy = eltType.dyn_cast<ComplexType>()) { | |||
1670 | auto eltType = complexTy.getElementType(); | |||
1671 | Attribute zero; | |||
1672 | if (eltType.isa<FloatType>()) | |||
1673 | zero = FloatAttr::get(eltType, 0); | |||
1674 | else // must be integer | |||
1675 | zero = IntegerAttr::get(eltType, 0); | |||
1676 | return ArrayAttr::get(complexTy.getContext(), | |||
1677 | ArrayRef<Attribute>{zero, zero}); | |||
1678 | } | |||
1679 | ||||
1680 | // Handle string type. | |||
1681 | if (getValues().isa<DenseStringElementsAttr>()) | |||
1682 | return StringAttr::get("", eltType); | |||
1683 | ||||
1684 | // Otherwise, this is an integer. | |||
1685 | return IntegerAttr::get(eltType, 0); | |||
1686 | } | |||
1687 | ||||
1688 | /// Flatten, and return, all of the sparse indices in this attribute in | |||
1689 | /// row-major order. | |||
1690 | std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { | |||
1691 | std::vector<ptrdiff_t> flatSparseIndices; | |||
1692 | ||||
1693 | // The sparse indices are 64-bit integers, so we can reinterpret the raw data | |||
1694 | // as a 1-D index array. | |||
1695 | auto sparseIndices = getIndices(); | |||
1696 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); | |||
1697 | if (sparseIndices.isSplat()) { | |||
1698 | SmallVector<uint64_t, 8> indices(getType().getRank(), | |||
1699 | *sparseIndexValues.begin()); | |||
1700 | flatSparseIndices.push_back(getFlattenedIndex(indices)); | |||
1701 | return flatSparseIndices; | |||
1702 | } | |||
1703 | ||||
1704 | // Otherwise, reinterpret each index as an ArrayRef when flattening. | |||
1705 | auto numSparseIndices = sparseIndices.getType().getDimSize(0); | |||
1706 | size_t rank = getType().getRank(); | |||
1707 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) | |||
1708 | flatSparseIndices.push_back(getFlattenedIndex( | |||
1709 | {&*std::next(sparseIndexValues.begin(), i * rank), rank})); | |||
1710 | return flatSparseIndices; | |||
1711 | } | |||
1712 | ||||
1713 | LogicalResult | |||
1714 | SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
1715 | ShapedType type, DenseIntElementsAttr sparseIndices, | |||
1716 | DenseElementsAttr values) { | |||
1717 | ShapedType valuesType = values.getType(); | |||
1718 | if (valuesType.getRank() != 1) | |||
1719 | return emitError() << "expected 1-d tensor for sparse element values"; | |||
1720 | ||||
1721 | // Verify the indices and values shape. | |||
1722 | ShapedType indicesType = sparseIndices.getType(); | |||
1723 | auto emitShapeError = [&]() { | |||
1724 | return emitError() << "expected shape ([" << type.getShape() | |||
1725 | << "]); inferred shape of indices literal ([" | |||
1726 | << indicesType.getShape() | |||
1727 | << "]); inferred shape of values literal ([" | |||
1728 | << valuesType.getShape() << "])"; | |||
1729 | }; | |||
1730 | // Verify indices shape. | |||
1731 | size_t rank = type.getRank(), indicesRank = indicesType.getRank(); | |||
1732 | if (indicesRank == 2) { | |||
1733 | if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) | |||
1734 | return emitShapeError(); | |||
1735 | } else if (indicesRank != 1 || rank != 1) { | |||
1736 | return emitShapeError(); | |||
1737 | } | |||
1738 | // Verify the values shape. | |||
1739 | int64_t numSparseIndices = indicesType.getDimSize(0); | |||
1740 | if (numSparseIndices != valuesType.getDimSize(0)) | |||
1741 | return emitShapeError(); | |||
1742 | ||||
1743 | // Verify that the sparse indices are within the value shape. | |||
1744 | auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { | |||
1745 | return emitError() | |||
1746 | << "sparse index #" << indexNum | |||
1747 | << " is not contained within the value shape, with index=[" << index | |||
1748 | << "], and type=" << type; | |||
1749 | }; | |||
1750 | ||||
1751 | // Handle the case where the index values are a splat. | |||
1752 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); | |||
1753 | if (sparseIndices.isSplat()) { | |||
1754 | SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); | |||
1755 | if (!ElementsAttr::isValidIndex(type, indices)) | |||
1756 | return emitIndexError(0, indices); | |||
1757 | return success(); | |||
1758 | } | |||
1759 | ||||
1760 | // Otherwise, reinterpret each index as an ArrayRef. | |||
1761 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) { | |||
1762 | ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), | |||
1763 | rank); | |||
1764 | if (!ElementsAttr::isValidIndex(type, index)) | |||
1765 | return emitIndexError(i, index); | |||
1766 | } | |||
1767 | ||||
1768 | return success(); | |||
1769 | } | |||
1770 | ||||
1771 | //===----------------------------------------------------------------------===// | |||
1772 | // TypeAttr | |||
1773 | //===----------------------------------------------------------------------===// | |||
1774 | ||||
1775 | void TypeAttr::walkImmediateSubElements( | |||
1776 | function_ref<void(Attribute)> walkAttrsFn, | |||
1777 | function_ref<void(Type)> walkTypesFn) const { | |||
1778 | walkTypesFn(getValue()); | |||
1779 | } | |||
1780 | ||||
1781 | Attribute | |||
1782 | TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs, | |||
1783 | ArrayRef<Type> replTypes) const { | |||
1784 | return get(replTypes[0]); | |||
1785 | } |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | class AsmParsedResourceEntry; |
24 | class AsmResourceBuilder; |
25 | class Builder; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // AsmDialectResourceHandle |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// This class represents an opaque handle to a dialect resource entry. |
32 | class AsmDialectResourceHandle { |
33 | public: |
34 | AsmDialectResourceHandle() = default; |
35 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
36 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
37 | bool operator==(const AsmDialectResourceHandle &other) const { |
38 | return resource == other.resource; |
39 | } |
40 | |
41 | /// Return an opaque pointer to the referenced resource. |
42 | void *getResource() const { return resource; } |
43 | |
44 | /// Return the type ID of the resource. |
45 | TypeID getTypeID() const { return opaqueID; } |
46 | |
47 | /// Return the dialect that owns the resource. |
48 | Dialect *getDialect() const { return dialect; } |
49 | |
50 | private: |
51 | /// The opaque handle to the dialect resource. |
52 | void *resource = nullptr; |
53 | /// The type of the resource referenced. |
54 | TypeID opaqueID; |
55 | /// The dialect owning the given resource. |
56 | Dialect *dialect; |
57 | }; |
58 | |
59 | /// This class represents a CRTP base class for dialect resource handles. It |
60 | /// abstracts away various utilities necessary for defined derived resource |
61 | /// handles. |
62 | template <typename DerivedT, typename ResourceT, typename DialectT> |
63 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { |
64 | public: |
65 | using Dialect = DialectT; |
66 | |
67 | /// Construct a handle from a pointer to the resource. The given pointer |
68 | /// should be guaranteed to live beyond the life of this handle. |
69 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) |
70 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} |
71 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) |
72 | : AsmDialectResourceHandle(handle) { |
73 | assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get< DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()" , "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__ __PRETTY_FUNCTION__)); |
74 | } |
75 | |
76 | /// Return the resource referenced by this handle. |
77 | ResourceT *getResource() { |
78 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); |
79 | } |
80 | const ResourceT *getResource() const { |
81 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); |
82 | } |
83 | |
84 | /// Return the dialect that owns the resource. |
85 | DialectT *getDialect() const { |
86 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); |
87 | } |
88 | |
89 | /// Support llvm style casting. |
90 | static bool classof(const AsmDialectResourceHandle *handle) { |
91 | return handle->getTypeID() == TypeID::get<DerivedT>(); |
92 | } |
93 | }; |
94 | |
95 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
96 | return llvm::hash_value(param.getResource()); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // AsmPrinter |
101 | //===----------------------------------------------------------------------===// |
102 | |
103 | /// This base class exposes generic asm printer hooks, usable across the various |
104 | /// derived printers. |
105 | class AsmPrinter { |
106 | public: |
107 | /// This class contains the internal default implementation of the base |
108 | /// printer methods. |
109 | class Impl; |
110 | |
111 | /// Initialize the printer with the given internal implementation. |
112 | AsmPrinter(Impl &impl) : impl(&impl) {} |
113 | virtual ~AsmPrinter(); |
114 | |
115 | /// Return the raw output stream used by this printer. |
116 | virtual raw_ostream &getStream() const; |
117 | |
118 | /// Print the given floating point value in a stabilized form that can be |
119 | /// roundtripped through the IR. This is the companion to the 'parseFloat' |
120 | /// hook on the AsmParser. |
121 | virtual void printFloat(const APFloat &value); |
122 | |
123 | virtual void printType(Type type); |
124 | virtual void printAttribute(Attribute attr); |
125 | |
126 | /// Trait to check if `AttrType` provides a `print` method. |
127 | template <typename AttrOrType> |
128 | using has_print_method = |
129 | decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>())); |
130 | template <typename AttrOrType> |
131 | using detect_has_print_method = |
132 | llvm::is_detected<has_print_method, AttrOrType>; |
133 | |
134 | /// Print the provided attribute in the context of an operation custom |
135 | /// printer/parser: this will invoke directly the print method on the |
136 | /// attribute class and skip the `#dialect.mnemonic` prefix in most cases. |
137 | template <typename AttrOrType, |
138 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
139 | *sfinae = nullptr> |
140 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
141 | if (succeeded(printAlias(attrOrType))) |
142 | return; |
143 | attrOrType.print(*this); |
144 | } |
145 | |
146 | /// Print the provided array of attributes or types in the context of an |
147 | /// operation custom printer/parser: this will invoke directly the print |
148 | /// method on the attribute class and skip the `#dialect.mnemonic` prefix in |
149 | /// most cases. |
150 | template <typename AttrOrType, |
151 | std::enable_if_t<detect_has_print_method<AttrOrType>::value> |
152 | *sfinae = nullptr> |
153 | void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) { |
154 | llvm::interleaveComma( |
155 | attrOrTypes, getStream(), |
156 | [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); }); |
157 | } |
158 | |
159 | /// SFINAE for printing the provided attribute in the context of an operation |
160 | /// custom printer in the case where the attribute does not define a print |
161 | /// method. |
162 | template <typename AttrOrType, |
163 | std::enable_if_t<!detect_has_print_method<AttrOrType>::value> |
164 | *sfinae = nullptr> |
165 | void printStrippedAttrOrType(AttrOrType attrOrType) { |
166 | *this << attrOrType; |
167 | } |
168 | |
169 | /// Print the given attribute without its type. The corresponding parser must |
170 | /// provide a valid type for the attribute. |
171 | virtual void printAttributeWithoutType(Attribute attr); |
172 | |
173 | /// Print the given string as a keyword, or a quoted and escaped string if it |
174 | /// has any special or non-printable characters in it. |
175 | virtual void printKeywordOrString(StringRef keyword); |
176 | |
177 | /// Print the given string as a symbol reference, i.e. a form representable by |
178 | /// a SymbolRefAttr. A symbol reference is represented as a string prefixed |
179 | /// with '@'. The reference is surrounded with ""'s and escaped if it has any |
180 | /// special or non-printable characters in it. |
181 | virtual void printSymbolName(StringRef symbolRef); |
182 | |
183 | /// Print a handle to the given dialect resource. |
184 | void printResourceHandle(const AsmDialectResourceHandle &resource); |
185 | |
186 | /// Print an optional arrow followed by a type list. |
187 | template <typename TypeRange> |
188 | void printOptionalArrowTypeList(TypeRange &&types) { |
189 | if (types.begin() != types.end()) |
190 | printArrowTypeList(types); |
191 | } |
192 | template <typename TypeRange> |
193 | void printArrowTypeList(TypeRange &&types) { |
194 | auto &os = getStream() << " -> "; |
195 | |
196 | bool wrapped = !llvm::hasSingleElement(types) || |
197 | (*types.begin()).template isa<FunctionType>(); |
198 | if (wrapped) |
199 | os << '('; |
200 | llvm::interleaveComma(types, *this); |
201 | if (wrapped) |
202 | os << ')'; |
203 | } |
204 | |
205 | /// Print the two given type ranges in a functional form. |
206 | template <typename InputRangeT, typename ResultRangeT> |
207 | void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) { |
208 | auto &os = getStream(); |
209 | os << '('; |
210 | llvm::interleaveComma(inputs, *this); |
211 | os << ')'; |
212 | printArrowTypeList(results); |
213 | } |
214 | |
215 | protected: |
216 | /// Initialize the printer with no internal implementation. In this case, all |
217 | /// virtual methods of this class must be overriden. |
218 | AsmPrinter() {} |
219 | |
220 | private: |
221 | AsmPrinter(const AsmPrinter &) = delete; |
222 | void operator=(const AsmPrinter &) = delete; |
223 | |
224 | /// Print the alias for the given attribute, return failure if no alias could |
225 | /// be printed. |
226 | virtual LogicalResult printAlias(Attribute attr); |
227 | |
228 | /// Print the alias for the given type, return failure if no alias could |
229 | /// be printed. |
230 | virtual LogicalResult printAlias(Type type); |
231 | |
232 | /// The internal implementation of the printer. |
233 | Impl *impl{nullptr}; |
234 | }; |
235 | |
236 | template <typename AsmPrinterT> |
237 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
238 | AsmPrinterT &> |
239 | operator<<(AsmPrinterT &p, Type type) { |
240 | p.printType(type); |
241 | return p; |
242 | } |
243 | |
244 | template <typename AsmPrinterT> |
245 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
246 | AsmPrinterT &> |
247 | operator<<(AsmPrinterT &p, Attribute attr) { |
248 | p.printAttribute(attr); |
249 | return p; |
250 | } |
251 | |
252 | template <typename AsmPrinterT> |
253 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
254 | AsmPrinterT &> |
255 | operator<<(AsmPrinterT &p, const APFloat &value) { |
256 | p.printFloat(value); |
257 | return p; |
258 | } |
259 | template <typename AsmPrinterT> |
260 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
261 | AsmPrinterT &> |
262 | operator<<(AsmPrinterT &p, float value) { |
263 | return p << APFloat(value); |
264 | } |
265 | template <typename AsmPrinterT> |
266 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
267 | AsmPrinterT &> |
268 | operator<<(AsmPrinterT &p, double value) { |
269 | return p << APFloat(value); |
270 | } |
271 | |
272 | // Support printing anything that isn't convertible to one of the other |
273 | // streamable types, even if it isn't exactly one of them. For example, we want |
274 | // to print FunctionType with the Type version above, not have it match this. |
275 | template < |
276 | typename AsmPrinterT, typename T, |
277 | typename std::enable_if<!std::is_convertible<T &, Value &>::value && |
278 | !std::is_convertible<T &, Type &>::value && |
279 | !std::is_convertible<T &, Attribute &>::value && |
280 | !std::is_convertible<T &, ValueRange>::value && |
281 | !std::is_convertible<T &, APFloat &>::value && |
282 | !llvm::is_one_of<T, bool, float, double>::value, |
283 | T>::type * = nullptr> |
284 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
285 | AsmPrinterT &> |
286 | operator<<(AsmPrinterT &p, const T &other) { |
287 | p.getStream() << other; |
288 | return p; |
289 | } |
290 | |
291 | template <typename AsmPrinterT> |
292 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
293 | AsmPrinterT &> |
294 | operator<<(AsmPrinterT &p, bool value) { |
295 | return p << (value ? StringRef("true") : "false"); |
296 | } |
297 | |
298 | template <typename AsmPrinterT, typename ValueRangeT> |
299 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
300 | AsmPrinterT &> |
301 | operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) { |
302 | llvm::interleaveComma(types, p); |
303 | return p; |
304 | } |
305 | template <typename AsmPrinterT> |
306 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
307 | AsmPrinterT &> |
308 | operator<<(AsmPrinterT &p, const TypeRange &types) { |
309 | llvm::interleaveComma(types, p); |
310 | return p; |
311 | } |
312 | template <typename AsmPrinterT, typename ElementT> |
313 | inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value, |
314 | AsmPrinterT &> |
315 | operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) { |
316 | llvm::interleaveComma(types, p); |
317 | return p; |
318 | } |
319 | |
320 | //===----------------------------------------------------------------------===// |
321 | // OpAsmPrinter |
322 | //===----------------------------------------------------------------------===// |
323 | |
324 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
325 | /// necessary to implement a custom print() method. |
326 | class OpAsmPrinter : public AsmPrinter { |
327 | public: |
328 | using AsmPrinter::AsmPrinter; |
329 | ~OpAsmPrinter() override; |
330 | |
331 | /// Print a newline and indent the printer to the start of the current |
332 | /// operation. |
333 | virtual void printNewline() = 0; |
334 | |
335 | /// Print a block argument in the usual format of: |
336 | /// %ssaName : type {attr1=42} loc("here") |
337 | /// where location printing is controlled by the standard internal option. |
338 | /// You may pass omitType=true to not print a type, and pass an empty |
339 | /// attribute list if you don't care for attributes. |
340 | virtual void printRegionArgument(BlockArgument arg, |
341 | ArrayRef<NamedAttribute> argAttrs = {}, |
342 | bool omitType = false) = 0; |
343 | |
344 | /// Print implementations for various things an operation contains. |
345 | virtual void printOperand(Value value) = 0; |
346 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
347 | |
348 | /// Print a comma separated list of operands. |
349 | template <typename ContainerType> |
350 | void printOperands(const ContainerType &container) { |
351 | printOperands(container.begin(), container.end()); |
352 | } |
353 | |
354 | /// Print a comma separated list of operands. |
355 | template <typename IteratorType> |
356 | void printOperands(IteratorType it, IteratorType end) { |
357 | if (it == end) |
358 | return; |
359 | printOperand(*it); |
360 | for (++it; it != end; ++it) { |
361 | getStream() << ", "; |
362 | printOperand(*it); |
363 | } |
364 | } |
365 | |
/// Print the given successor.
virtual void printSuccessor(Block *successor) = 0;

/// Print the successor and its operands.
virtual void printSuccessorAndUseList(Block *successor,
                                      ValueRange succOperands) = 0;

/// If the specified operation has attributes, print out an attribute
/// dictionary with their values. elidedAttrs allows the client to ignore
/// specific well known attributes, commonly used if the attribute value is
/// printed some other way (like as a fixed operand).
virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                   ArrayRef<StringRef> elidedAttrs = {}) = 0;

/// If the specified operation has attributes, print out an attribute
/// dictionary prefixed with 'attributes'. `elidedAttrs` lists attribute
/// names to leave out of the printed dictionary.
virtual void
printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                 ArrayRef<StringRef> elidedAttrs = {}) = 0;

/// Print the entire operation with the default generic assembly form.
/// If `printOpName` is true, then the operation name is printed (the default)
/// otherwise it is omitted and the print will start with the operand list.
virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;

/// Prints a region.
/// If 'printEntryBlockArgs' is false, the arguments of the entry
/// block are not printed. If 'printBlockTerminators' is false, the terminator
/// operation of the block is not printed. If 'printEmptyBlock' is true, then
/// the block header is printed even if the block is empty.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
                         bool printBlockTerminators = true,
                         bool printEmptyBlock = false) = 0;
399 | |
400 | /// Renumber the arguments for the specified region to the same names as the |
401 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
402 | /// operations. If any entry in namesToUse is null, the corresponding |
403 | /// argument name is left alone. |
404 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
405 | |
406 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
407 | /// of dims/symbols. |
408 | /// Operand values must come from single-result sources, and be valid |
409 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
410 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
411 | ValueRange operands) = 0; |
412 | |
413 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
414 | /// dims and symbols. |
415 | /// Operand values must come from single-result sources, and be valid |
416 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
417 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
418 | ValueRange symOperands) = 0; |
419 | |
420 | /// Print the complete type of an operation in functional form. |
421 | void printFunctionalType(Operation *op); |
422 | using AsmPrinter::printFunctionalType; |
423 | }; |
424 | |
425 | // Make the implementations convenient to use. |
426 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) { |
427 | p.printOperand(value); |
428 | return p; |
429 | } |
430 | |
431 | template <typename T, |
432 | typename std::enable_if<std::is_convertible<T &, ValueRange>::value && |
433 | !std::is_convertible<T &, Value &>::value, |
434 | T>::type * = nullptr> |
435 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) { |
436 | p.printOperands(values); |
437 | return p; |
438 | } |
439 | |
440 | inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) { |
441 | p.printSuccessor(value); |
442 | return p; |
443 | } |
444 | |
445 | //===----------------------------------------------------------------------===// |
446 | // AsmParser |
447 | //===----------------------------------------------------------------------===// |
448 | |
449 | /// This base class exposes generic asm parser hooks, usable across the various |
450 | /// derived parsers. |
451 | class AsmParser { |
452 | public: |
AsmParser() = default;
virtual ~AsmParser();

/// Return the MLIRContext associated with this parser (defined out of line).
MLIRContext *getContext() const;

/// Return the location of the original name token.
virtual SMLoc getNameLoc() const = 0;

//===--------------------------------------------------------------------===//
// Utilities
//===--------------------------------------------------------------------===//

/// Emit a diagnostic at the specified location and return failure.
/// The returned InFlightDiagnostic can be streamed into (`<< ...`) to extend
/// the message before it is converted to a failure result.
virtual InFlightDiagnostic emitError(SMLoc loc,
                                     const Twine &message = {}) = 0;

/// Return a builder which provides useful access to MLIRContext, global
/// objects like types and attributes.
virtual Builder &getBuilder() const = 0;
472 | |
473 | /// Get the location of the next token and store it into the argument. This |
474 | /// always succeeds. |
475 | virtual SMLoc getCurrentLocation() = 0; |
476 | ParseResult getCurrentLocation(SMLoc *loc) { |
477 | *loc = getCurrentLocation(); |
478 | return success(); |
479 | } |
480 | |
481 | /// Re-encode the given source location as an MLIR location and return it. |
482 | /// Note: This method should only be used when a `Location` is necessary, as |
483 | /// the encoding process is not efficient. |
484 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
485 | |
//===--------------------------------------------------------------------===//
// Token Parsing
//===--------------------------------------------------------------------===//
//
// Each `parseX` hook requires the token; each `parseOptionalX` hook consumes
// it only if present.

/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;

/// Parse a '->' token if present.
virtual ParseResult parseOptionalArrow() = 0;

/// Parse a `{` token.
virtual ParseResult parseLBrace() = 0;

/// Parse a `{` token if present.
virtual ParseResult parseOptionalLBrace() = 0;

/// Parse a `}` token.
virtual ParseResult parseRBrace() = 0;

/// Parse a `}` token if present.
virtual ParseResult parseOptionalRBrace() = 0;

/// Parse a `:` token.
virtual ParseResult parseColon() = 0;

/// Parse a `:` token if present.
virtual ParseResult parseOptionalColon() = 0;

/// Parse a `,` token.
virtual ParseResult parseComma() = 0;

/// Parse a `,` token if present.
virtual ParseResult parseOptionalComma() = 0;

/// Parse a `=` token.
virtual ParseResult parseEqual() = 0;

/// Parse a `=` token if present.
virtual ParseResult parseOptionalEqual() = 0;

/// Parse a '<' token.
virtual ParseResult parseLess() = 0;

/// Parse a '<' token if present.
virtual ParseResult parseOptionalLess() = 0;

/// Parse a '>' token.
virtual ParseResult parseGreater() = 0;

/// Parse a '>' token if present.
virtual ParseResult parseOptionalGreater() = 0;

/// Parse a '?' token.
virtual ParseResult parseQuestion() = 0;

/// Parse a '?' token if present.
virtual ParseResult parseOptionalQuestion() = 0;

/// Parse a '+' token.
virtual ParseResult parsePlus() = 0;

/// Parse a '+' token if present.
virtual ParseResult parseOptionalPlus() = 0;

/// Parse a '*' token.
virtual ParseResult parseStar() = 0;

/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;

/// Parse a '|' token.
virtual ParseResult parseVerticalBar() = 0;

/// Parse a '|' token if present.
virtual ParseResult parseOptionalVerticalBar() = 0;
561 | |
562 | /// Parse a quoted string token. |
563 | ParseResult parseString(std::string *string) { |
564 | auto loc = getCurrentLocation(); |
565 | if (parseOptionalString(string)) |
566 | return emitError(loc, "expected string"); |
567 | return success(); |
568 | } |
569 | |
570 | /// Parse a quoted string token if present. |
571 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
572 | |
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;

/// Parse a `(` token if present.
virtual ParseResult parseOptionalLParen() = 0;

/// Parse a `)` token.
virtual ParseResult parseRParen() = 0;

/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;

/// Parse a `[` token.
virtual ParseResult parseLSquare() = 0;

/// Parse a `[` token if present.
virtual ParseResult parseOptionalLSquare() = 0;

/// Parse a `]` token.
virtual ParseResult parseRSquare() = 0;

/// Parse a `]` token if present.
virtual ParseResult parseOptionalRSquare() = 0;

/// Parse a `...` token if present.
virtual ParseResult parseOptionalEllipsis() = 0;

/// Parse a floating point value from the stream into `result`.
virtual ParseResult parseFloat(double &result) = 0;
602 | |
603 | /// Parse an integer value from the stream. |
604 | template <typename IntT> |
605 | ParseResult parseInteger(IntT &result) { |
606 | auto loc = getCurrentLocation(); |
607 | OptionalParseResult parseResult = parseOptionalInteger(result); |
608 | if (!parseResult.has_value()) |
609 | return emitError(loc, "expected integer value"); |
610 | return *parseResult; |
611 | } |
612 | |
613 | /// Parse an optional integer value from the stream. |
614 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
615 | |
616 | template <typename IntT> |
617 | OptionalParseResult parseOptionalInteger(IntT &result) { |
618 | auto loc = getCurrentLocation(); |
619 | |
620 | // Parse the unsigned variant. |
621 | APInt uintResult; |
622 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
623 | if (!parseResult.has_value() || failed(*parseResult)) |
624 | return parseResult; |
625 | |
626 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
627 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
628 | // zero for non-negated integers. |
629 | result = |
630 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); |
631 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
632 | return emitError(loc, "integer value too large"); |
633 | return success(); |
634 | } |
635 | |
/// These are the supported delimiters around operand lists and region
/// argument lists, used by parseOperandList.
enum class Delimiter {
  /// Zero or more operands with no delimiters.
  None,
  /// Parens surrounding zero or more operands.
  Paren,
  /// Square brackets surrounding zero or more operands.
  Square,
  /// <> brackets surrounding zero or more operands.
  LessGreater,
  /// {} brackets surrounding zero or more operands.
  Braces,
  /// Parens supporting zero or more operands, or nothing.
  OptionalParen,
  /// Square brackets supporting zero or more ops, or nothing.
  OptionalSquare,
  /// <> brackets supporting zero or more ops, or nothing.
  OptionalLessGreater,
  /// {} brackets surrounding zero or more operands, or nothing.
  OptionalBraces,
};

/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
///
/// `parseElementFn` is the callback that parses a single element (presumably
/// invoked once per list item).
///
/// contextMessage is an optional message appended to "expected '('" sorts of
/// diagnostics when parsing the delimiters.
virtual ParseResult
parseCommaSeparatedList(Delimiter delimiter,
                        function_ref<ParseResult()> parseElementFn,
                        StringRef contextMessage = StringRef()) = 0;

/// Parse a comma separated list of elements that must have at least one entry
/// in it. Equivalent to the overload above with `Delimiter::None`.
ParseResult
parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
  return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
676 | |
677 | //===--------------------------------------------------------------------===// |
678 | // Keyword Parsing |
679 | //===--------------------------------------------------------------------===// |
680 | |
681 | /// This class represents a StringSwitch like class that is useful for parsing |
682 | /// expected keywords. On construction, it invokes `parseKeyword` and |
683 | /// processes each of the provided cases statements until a match is hit. The |
684 | /// provided `ResultT` must be assignable from `failure()`. |
685 | template <typename ResultT = ParseResult> |
686 | class KeywordSwitch { |
687 | public: |
688 | KeywordSwitch(AsmParser &parser) |
689 | : parser(parser), loc(parser.getCurrentLocation()) { |
690 | if (failed(parser.parseKeywordOrCompletion(&keyword))) |
691 | result = failure(); |
692 | } |
693 | |
694 | /// Case that uses the provided value when true. |
695 | KeywordSwitch &Case(StringLiteral str, ResultT value) { |
696 | return Case(str, [&](StringRef, SMLoc) { return std::move(value); }); |
697 | } |
698 | KeywordSwitch &Default(ResultT value) { |
699 | return Default([&](StringRef, SMLoc) { return std::move(value); }); |
700 | } |
701 | /// Case that invokes the provided functor when true. The parameters passed |
702 | /// to the functor are the keyword, and the location of the keyword (in case |
703 | /// any errors need to be emitted). |
704 | template <typename FnT> |
705 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
706 | Case(StringLiteral str, FnT &&fn) { |
707 | if (result) |
708 | return *this; |
709 | |
710 | // If the word was empty, record this as a completion. |
711 | if (keyword.empty()) |
712 | parser.codeCompleteExpectedTokens(str); |
713 | else if (keyword == str) |
714 | result.emplace(std::move(fn(keyword, loc))); |
715 | return *this; |
716 | } |
717 | template <typename FnT> |
718 | std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &> |
719 | Default(FnT &&fn) { |
720 | if (!result) |
721 | result.emplace(fn(keyword, loc)); |
722 | return *this; |
723 | } |
724 | |
725 | /// Returns true if this switch has a value yet. |
726 | bool hasValue() const { return result.has_value(); } |
727 | |
728 | /// Return the result of the switch. |
729 | [[nodiscard]] operator ResultT() { |
730 | if (!result) |
731 | return parser.emitError(loc, "unexpected keyword: ") << keyword; |
732 | return std::move(*result); |
733 | } |
734 | |
735 | private: |
736 | /// The parser used to construct this switch. |
737 | AsmParser &parser; |
738 | |
739 | /// The location of the keyword, used to emit errors as necessary. |
740 | SMLoc loc; |
741 | |
742 | /// The parsed keyword itself. |
743 | StringRef keyword; |
744 | |
745 | /// The result of the switch statement or none if currently unknown. |
746 | Optional<ResultT> result; |
747 | }; |
748 | |
749 | /// Parse a given keyword. |
750 | ParseResult parseKeyword(StringRef keyword) { |
751 | return parseKeyword(keyword, ""); |
752 | } |
753 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; |
754 | |
755 | /// Parse a keyword into 'keyword'. |
756 | ParseResult parseKeyword(StringRef *keyword) { |
757 | auto loc = getCurrentLocation(); |
758 | if (parseOptionalKeyword(keyword)) |
759 | return emitError(loc, "expected valid keyword"); |
760 | return success(); |
761 | } |
762 | |
763 | /// Parse the given keyword if present. |
764 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
765 | |
766 | /// Parse a keyword, if present, into 'keyword'. |
767 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
768 | |
769 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
770 | /// into 'keyword' |
771 | virtual ParseResult |
772 | parseOptionalKeyword(StringRef *keyword, |
773 | ArrayRef<StringRef> allowedValues) = 0; |
774 | |
775 | /// Parse a keyword or a quoted string. |
776 | ParseResult parseKeywordOrString(std::string *result) { |
777 | if (failed(parseOptionalKeywordOrString(result))) |
778 | return emitError(getCurrentLocation()) |
779 | << "expected valid keyword or string"; |
780 | return success(); |
781 | } |
782 | |
783 | /// Parse an optional keyword or string. |
784 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
785 | |
786 | //===--------------------------------------------------------------------===// |
787 | // Attribute/Type Parsing |
788 | //===--------------------------------------------------------------------===// |
789 | |
790 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
791 | /// the provided location to emit errors in the case of failure. Note that |
792 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
793 | /// context parameter. |
794 | template <typename T, typename... ParamsT> |
795 | auto getChecked(SMLoc loc, ParamsT &&...params) { |
796 | return T::getChecked([&] { return emitError(loc); }, |
797 | std::forward<ParamsT>(params)...); |
798 | } |
799 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
800 | /// errors. |
801 | template <typename T, typename... ParamsT> |
802 | auto getChecked(ParamsT &&...params) { |
803 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
804 | std::forward<ParamsT>(params)...); |
805 | } |
806 | |
807 | //===--------------------------------------------------------------------===// |
808 | // Attribute Parsing |
809 | //===--------------------------------------------------------------------===// |
810 | |
811 | /// Parse an arbitrary attribute of a given type and return it in result. |
812 | virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0; |
813 | |
814 | /// Parse a custom attribute with the provided callback, unless the next |
815 | /// token is `#`, in which case the generic parser is invoked. |
816 | virtual ParseResult parseCustomAttributeWithFallback( |
817 | Attribute &result, Type type, |
818 | function_ref<ParseResult(Attribute &result, Type type)> |
819 | parseAttribute) = 0; |
820 | |
821 | /// Parse an attribute of a specific kind and type. |
822 | template <typename AttrType> |
823 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
824 | SMLoc loc = getCurrentLocation(); |
825 | |
826 | // Parse any kind of attribute. |
827 | Attribute attr; |
828 | if (parseAttribute(attr, type)) |
829 | return failure(); |
830 | |
831 | // Check for the right kind of attribute. |
832 | if (!(result = attr.dyn_cast<AttrType>())) |
833 | return emitError(loc, "invalid kind of attribute specified"); |
834 | |
835 | return success(); |
836 | } |
837 | |
838 | /// Parse an arbitrary attribute and return it in result. This also adds the |
839 | /// attribute to the specified attribute list with the specified name. |
840 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
841 | NamedAttrList &attrs) { |
842 | return parseAttribute(result, Type(), attrName, attrs); |
843 | } |
844 | |
845 | /// Parse an attribute of a specific kind and type. |
846 | template <typename AttrType> |
847 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
848 | NamedAttrList &attrs) { |
849 | return parseAttribute(result, Type(), attrName, attrs); |
850 | } |
851 | |
852 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
853 | /// This also adds the attribute to the specified attribute list with the |
854 | /// specified name. |
855 | template <typename AttrType> |
856 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
857 | NamedAttrList &attrs) { |
858 | SMLoc loc = getCurrentLocation(); |
859 | |
860 | // Parse any kind of attribute. |
861 | Attribute attr; |
862 | if (parseAttribute(attr, type)) |
863 | return failure(); |
864 | |
865 | // Check for the right kind of attribute. |
866 | result = attr.dyn_cast<AttrType>(); |
867 | if (!result) |
868 | return emitError(loc, "invalid kind of attribute specified"); |
869 | |
870 | attrs.append(attrName, result); |
871 | return success(); |
872 | } |
873 | |
874 | /// Trait to check if `AttrType` provides a `parse` method. |
875 | template <typename AttrType> |
876 | using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(), |
877 | std::declval<Type>())); |
878 | template <typename AttrType> |
879 | using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>; |
880 | |
881 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
882 | /// which case the generic parser is invoked. The parsed attribute is |
883 | /// populated in `result` and also added to the specified attribute list with |
884 | /// the specified name. |
885 | template <typename AttrType> |
886 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
887 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
888 | StringRef attrName, NamedAttrList &attrs) { |
889 | SMLoc loc = getCurrentLocation(); |
890 | |
891 | // Parse any kind of attribute. |
892 | Attribute attr; |
893 | if (parseCustomAttributeWithFallback( |
894 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
895 | result = AttrType::parse(*this, type); |
896 | if (!result) |
897 | return failure(); |
898 | return success(); |
899 | })) |
900 | return failure(); |
901 | |
902 | // Check for the right kind of attribute. |
903 | result = attr.dyn_cast<AttrType>(); |
904 | if (!result) |
905 | return emitError(loc, "invalid kind of attribute specified"); |
906 | |
907 | attrs.append(attrName, result); |
908 | return success(); |
909 | } |
910 | |
911 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
912 | template <typename AttrType> |
913 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
914 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
915 | StringRef attrName, NamedAttrList &attrs) { |
916 | return parseAttribute(result, type, attrName, attrs); |
917 | } |
918 | |
919 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
920 | /// which case the generic parser is invoked. The parsed attribute is |
921 | /// populated in `result`. |
922 | template <typename AttrType> |
923 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
924 | parseCustomAttributeWithFallback(AttrType &result) { |
925 | SMLoc loc = getCurrentLocation(); |
926 | |
927 | // Parse any kind of attribute. |
928 | Attribute attr; |
929 | if (parseCustomAttributeWithFallback( |
930 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
931 | result = AttrType::parse(*this, type); |
932 | return success(!!result); |
933 | })) |
934 | return failure(); |
935 | |
936 | // Check for the right kind of attribute. |
937 | result = attr.dyn_cast<AttrType>(); |
938 | if (!result) |
939 | return emitError(loc, "invalid kind of attribute specified"); |
940 | return success(); |
941 | } |
942 | |
943 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
944 | template <typename AttrType> |
945 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
946 | parseCustomAttributeWithFallback(AttrType &result) { |
947 | return parseAttribute(result); |
948 | } |
949 | |
950 | /// Parse an arbitrary optional attribute of a given type and return it in |
951 | /// result. |
952 | virtual OptionalParseResult parseOptionalAttribute(Attribute &result, |
953 | Type type = {}) = 0; |
954 | |
955 | /// Parse an optional array attribute and return it in result. |
956 | virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result, |
957 | Type type = {}) = 0; |
958 | |
959 | /// Parse an optional string attribute and return it in result. |
960 | virtual OptionalParseResult parseOptionalAttribute(StringAttr &result, |
961 | Type type = {}) = 0; |
962 | |
963 | /// Parse an optional attribute of a specific type and add it to the list with |
964 | /// the specified name. |
965 | template <typename AttrType> |
966 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
967 | StringRef attrName, |
968 | NamedAttrList &attrs) { |
969 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
970 | } |
971 | |
972 | /// Parse an optional attribute of a specific type and add it to the list with |
973 | /// the specified name. |
974 | template <typename AttrType> |
975 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
976 | StringRef attrName, |
977 | NamedAttrList &attrs) { |
978 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
979 | if (parseResult.has_value() && succeeded(*parseResult)) |
980 | attrs.append(attrName, result); |
981 | return parseResult; |
982 | } |
983 | |
984 | /// Parse a named dictionary into 'result' if it is present. |
985 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
986 | |
987 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
988 | /// present. |
989 | virtual ParseResult |
990 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
991 | |
992 | /// Parse an affine map instance into 'map'. |
993 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
994 | |
995 | /// Parse an integer set instance into 'set'. |
996 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
997 | |
998 | //===--------------------------------------------------------------------===// |
999 | // Identifier Parsing |
1000 | //===--------------------------------------------------------------------===// |
1001 | |
1002 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1003 | /// attribute named 'attrName'. |
1004 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1005 | NamedAttrList &attrs) { |
1006 | if (failed(parseOptionalSymbolName(result, attrName, attrs))) |
1007 | return emitError(getCurrentLocation()) |
1008 | << "expected valid '@'-identifier for symbol name"; |
1009 | return success(); |
1010 | } |
1011 | |
1012 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1013 | /// string attribute named 'attrName'. |
1014 | virtual ParseResult parseOptionalSymbolName(StringAttr &result, |
1015 | StringRef attrName, |
1016 | NamedAttrList &attrs) = 0; |
1017 | |
1018 | //===--------------------------------------------------------------------===// |
1019 | // Resource Parsing |
1020 | //===--------------------------------------------------------------------===// |
1021 | |
1022 | /// Parse a handle to a resource within the assembly format. |
1023 | template <typename ResourceT> |
1024 | FailureOr<ResourceT> parseResourceHandle() { |
1025 | SMLoc handleLoc = getCurrentLocation(); |
1026 | |
1027 | // Try to load the dialect that owns the handle. |
1028 | auto *dialect = |
1029 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); |
1030 | if (!dialect) { |
1031 | return emitError(handleLoc) |
1032 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() |
1033 | << "' is unknown"; |
1034 | } |
1035 | |
1036 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); |
1037 | if (failed(handle)) |
1038 | return failure(); |
1039 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
1040 | return std::move(*result); |
1041 | return emitError(handleLoc) << "provided resource handle differs from the " |
1042 | "expected resource type"; |
1043 | } |
1044 | |
1045 | //===--------------------------------------------------------------------===// |
1046 | // Type Parsing |
1047 | //===--------------------------------------------------------------------===// |
1048 | |
1049 | /// Parse a type. |
1050 | virtual ParseResult parseType(Type &result) = 0; |
1051 | |
1052 | /// Parse a custom type with the provided callback, unless the next |
1053 | /// token is `#`, in which case the generic parser is invoked. |
1054 | virtual ParseResult parseCustomTypeWithFallback( |
1055 | Type &result, function_ref<ParseResult(Type &result)> parseType) = 0; |
1056 | |
1057 | /// Parse an optional type. |
1058 | virtual OptionalParseResult parseOptionalType(Type &result) = 0; |
1059 | |
1060 | /// Parse a type of a specific type. |
1061 | template <typename TypeT> |
1062 | ParseResult parseType(TypeT &result) { |
1063 | SMLoc loc = getCurrentLocation(); |
1064 | |
1065 | // Parse any kind of type. |
1066 | Type type; |
1067 | if (parseType(type)) |
1068 | return failure(); |
1069 | |
1070 | // Check for the right kind of type. |
1071 | result = type.dyn_cast<TypeT>(); |
1072 | if (!result) |
1073 | return emitError(loc, "invalid kind of type specified"); |
1074 | |
1075 | return success(); |
1076 | } |
1077 | |
1078 | /// Trait to check if `TypeT` provides a `parse` method. |
1079 | template <typename TypeT> |
1080 | using type_has_parse_method = |
1081 | decltype(TypeT::parse(std::declval<AsmParser &>())); |
1082 | template <typename TypeT> |
1083 | using detect_type_has_parse_method = |
1084 | llvm::is_detected<type_has_parse_method, TypeT>; |
1085 | |
1086 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1087 | /// which case the generic parser is invoked. The parsed Type is |
1088 | /// populated in `result`. |
1089 | template <typename TypeT> |
1090 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1091 | parseCustomTypeWithFallback(TypeT &result) { |
1092 | SMLoc loc = getCurrentLocation(); |
1093 | |
1094 | // Parse any kind of Type. |
1095 | Type type; |
1096 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1097 | result = TypeT::parse(*this); |
1098 | return success(!!result); |
1099 | })) |
1100 | return failure(); |
1101 | |
1102 | // Check for the right kind of Type. |
1103 | result = type.dyn_cast<TypeT>(); |
1104 | if (!result) |
1105 | return emitError(loc, "invalid kind of Type specified"); |
1106 | return success(); |
1107 | } |
1108 | |
1109 | /// SFINAE parsing method for Type that don't implement a parse method. |
1110 | template <typename TypeT> |
1111 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
1112 | parseCustomTypeWithFallback(TypeT &result) { |
1113 | return parseType(result); |
1114 | } |
1115 | |
1116 | /// Parse a type list. |
1117 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
1118 | return parseCommaSeparatedList( |
1119 | [&]() { return parseType(result.emplace_back()); }); |
1120 | } |
1121 | |
1122 | /// Parse an arrow followed by a type list. |
1123 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1124 | |
1125 | /// Parse an optional arrow followed by a type list. |
1126 | virtual ParseResult |
1127 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1128 | |
1129 | /// Parse a colon followed by a type. |
1130 | virtual ParseResult parseColonType(Type &result) = 0; |
1131 | |
1132 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1133 | template <typename TypeType> |
1134 | ParseResult parseColonType(TypeType &result) { |
1135 | SMLoc loc = getCurrentLocation(); |
1136 | |
1137 | // Parse any kind of type. |
1138 | Type type; |
1139 | if (parseColonType(type)) |
1140 | return failure(); |
1141 | |
1142 | // Check for the right kind of type. |
1143 | result = type.dyn_cast<TypeType>(); |
1144 | if (!result) |
1145 | return emitError(loc, "invalid kind of type specified"); |
1146 | |
1147 | return success(); |
1148 | } |
1149 | |
1150 | /// Parse a colon followed by a type list, which must have at least one type. |
1151 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1152 | |
1153 | /// Parse an optional colon followed by a type list, which if present must |
1154 | /// have at least one type. |
1155 | virtual ParseResult |
1156 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1157 | |
1158 | /// Parse a keyword followed by a type. |
1159 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1160 | return failure(parseKeyword(keyword) || parseType(result)); |
1161 | } |
1162 | |
1163 | /// Add the specified type to the end of the specified type list and return |
1164 | /// success. This is a helper designed to allow parse methods to be simple |
1165 | /// and chain through || operators. |
1166 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
1167 | result.push_back(type); |
1168 | return success(); |
1169 | } |
1170 | |
1171 | /// Add the specified types to the end of the specified type list and return |
1172 | /// success. This is a helper designed to allow parse methods to be simple |
1173 | /// and chain through || operators. |
1174 | ParseResult addTypesToList(ArrayRef<Type> types, |
1175 | SmallVectorImpl<Type> &result) { |
1176 | result.append(types.begin(), types.end()); |
1177 | return success(); |
1178 | } |
1179 | |
1180 | /// Parse a dimension list of a tensor or memref type. This populates the |
1181 | /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set |
1182 | /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable. |
1183 | /// |
1184 | /// dimension-list ::= eps | dimension (`x` dimension)* |
1185 | /// dimension-list-with-trailing-x ::= (dimension `x`)* |
1186 | /// dimension ::= `?` | decimal-literal |
1187 | /// |
1188 | /// When `allowDynamic` is not set, this is used to parse: |
1189 | /// |
1190 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* |
1191 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* |
1192 | virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions, |
1193 | bool allowDynamic = true, |
1194 | bool withTrailingX = true) = 0; |
1195 | |
1196 | /// Parse an 'x' token in a dimension list, handling the case where the x is |
1197 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the |
1198 | /// next token. |
1199 | virtual ParseResult parseXInDimensionList() = 0; |
1200 | |
1201 | protected: |
1202 | /// Parse a handle to a resource within the assembly format for the given |
1203 | /// dialect. |
1204 | virtual FailureOr<AsmDialectResourceHandle> |
1205 | parseResourceHandle(Dialect *dialect) = 0; |
1206 | |
1207 | //===--------------------------------------------------------------------===// |
1208 | // Code Completion |
1209 | //===--------------------------------------------------------------------===// |
1210 | |
1211 | /// Parse a keyword, or an empty string if the current location signals a code |
1212 | /// completion. |
1213 | virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0; |
1214 | |
1215 | /// Signal the code completion of a set of expected tokens. |
1216 | virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0; |
1217 | |
1218 | private: |
1219 | AsmParser(const AsmParser &) = delete; |
1220 | void operator=(const AsmParser &) = delete; |
1221 | }; |
1222 | |
1223 | //===----------------------------------------------------------------------===// |
1224 | // OpAsmParser |
1225 | //===----------------------------------------------------------------------===// |
1226 | |
1227 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1228 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1229 | /// that is designed to reduce/constrain syntax innovation in individual |
1230 | /// operations. |
1231 | /// |
1232 | /// For example, consider an op like this: |
1233 | /// |
1234 | /// %x = load %p[%1, %2] : memref<...> |
1235 | /// |
1236 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
1237 | /// custom op parser. This can be supported by calling `parseOperandList` to |
1238 | /// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to |
1239 | /// parse the indices, then calling `parseColonTypeList` to parse the result |
1240 | /// type. |
1241 | /// |
1242 | class OpAsmParser : public AsmParser { |
1243 | public: |
1244 | using AsmParser::AsmParser; |
1245 | ~OpAsmParser() override; |
1246 | |
1247 | /// Parse a loc(...) specifier if present, filling in result if so. |
1248 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1249 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1250 | /// completes. |
1251 | virtual ParseResult |
1252 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1253 | |
1254 | /// Return the name of the specified result in the specified syntax, as well |
1255 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1256 | /// invalid result numbers. For example, in this operation: |
1257 | /// |
1258 | /// %x, %y:2, %z = foo.op |
1259 | /// |
1260 | /// getResultName(0) == {"x", 0 } |
1261 | /// getResultName(1) == {"y", 0 } |
1262 | /// getResultName(2) == {"y", 1 } |
1263 | /// getResultName(3) == {"z", 0 } |
1264 | /// getResultName(4) == {"", ~0U } |
1265 | virtual std::pair<StringRef, unsigned> |
1266 | getResultName(unsigned resultNo) const = 0; |
1267 | |
1268 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1269 | /// example in the comment for `getResultName`. |
1270 | virtual size_t getNumResults() const = 0; |
1271 | |
1272 | // These methods emit an error and return failure or success. This allows |
1273 | // these to be chained together into a linear sequence of || expressions in |
1274 | // many cases. |
1275 | |
1276 | /// Parse an operation in its generic form. |
1277 | /// The parsed operation is parsed in the current context and inserted in the |
1278 | /// provided block and insertion point. The results produced by this operation |
1279 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1280 | /// failure. |
1281 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1282 | Block::iterator insertPt) = 0; |
1283 | |
1284 | /// Parse the name of an operation, in the custom form. On success, return a |
1285 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1286 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1287 | |
1288 | //===--------------------------------------------------------------------===// |
1289 | // Operand Parsing |
1290 | //===--------------------------------------------------------------------===// |
1291 | |
1292 | /// This is the representation of an operand reference. |
1293 | struct UnresolvedOperand { |
1294 | SMLoc location; // Location of the token. |
1295 | StringRef name; // Value name, e.g. %42 or %abc |
1296 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1297 | }; |
1298 | |
1299 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1300 | /// region(s), attribute(s) and function-type, of the generic form of an |
1301 | /// operation instance and populate the input operation-state 'result' with |
1302 | /// those components. If any of the components is explicitly provided, then |
1303 | /// skip parsing that component. |
1304 | virtual ParseResult parseGenericOperationAfterOpName( |
1305 | OperationState &result, |
1306 | Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None, |
1307 | Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None, |
1308 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1309 | llvm::None, |
1310 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None, |
1311 | Optional<FunctionType> parsedFnType = llvm::None) = 0; |
1312 | |
1313 | /// Parse a single SSA value operand name along with a result number if |
1314 | /// `allowResultNumber` is true. |
1315 | virtual ParseResult parseOperand(UnresolvedOperand &result, |
1316 | bool allowResultNumber = true) = 0; |
1317 | |
1318 | /// Parse a single operand if present. |
1319 | virtual OptionalParseResult |
1320 | parseOptionalOperand(UnresolvedOperand &result, |
1321 | bool allowResultNumber = true) = 0; |
1322 | |
1323 | /// Parse zero or more SSA comma-separated operand references with a specified |
1324 | /// surrounding delimiter, and an optional required operand count. |
1325 | virtual ParseResult |
1326 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1327 | Delimiter delimiter = Delimiter::None, |
1328 | bool allowResultNumber = true, |
1329 | int requiredOperandCount = -1) = 0; |
1330 | |
1331 | /// Parse a specified number of comma separated operands. |
1332 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1333 | int requiredOperandCount, |
1334 | Delimiter delimiter = Delimiter::None) { |
1335 | return parseOperandList(result, delimiter, |
1336 | /*allowResultNumber=*/true, requiredOperandCount); |
1337 | } |
1338 | |
1339 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1340 | /// references with a specified surrounding delimiter, and an optional |
1341 | /// required operand count. A leading comma is expected before the |
1342 | /// operands. |
1343 | ParseResult |
1344 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1345 | Delimiter delimiter = Delimiter::None) { |
1346 | if (failed(parseOptionalComma())) |
1347 | return success(); // The comma is optional. |
1348 | return parseOperandList(result, delimiter); |
1349 | } |
1350 | |
1351 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1352 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1353 | Type type, |
1354 | SmallVectorImpl<Value> &result) = 0; |
1355 | |
1356 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1357 | /// appending the results to the list on success. This method should be used |
1358 | /// when all operands have the same type. |
1359 | ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type, |
1360 | SmallVectorImpl<Value> &result) { |
1361 | for (auto elt : operands) |
1362 | if (resolveOperand(elt, type, result)) |
1363 | return failure(); |
1364 | return success(); |
1365 | } |
1366 | |
1367 | /// Resolve a list of operands and a list of operand types to SSA values, |
1368 | /// emitting an error and returning failure, or appending the results |
1369 | /// to the list on success. |
1370 | ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, |
1371 | ArrayRef<Type> types, SMLoc loc, |
1372 | SmallVectorImpl<Value> &result) { |
1373 | if (operands.size() != types.size()) |
1374 | return emitError(loc) |
1375 | << operands.size() << " operands present, but expected " |
1376 | << types.size(); |
1377 | |
1378 | for (unsigned i = 0, e = operands.size(); i != e; ++i) |
1379 | if (resolveOperand(operands[i], types[i], result)) |
1380 | return failure(); |
1381 | return success(); |
1382 | } |
1383 | template <typename Operands> |
1384 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1385 | SmallVectorImpl<Value> &result) { |
1386 | return resolveOperands(std::forward<Operands>(operands), |
1387 | ArrayRef<Type>(type), loc, result); |
1388 | } |
1389 | template <typename Operands, typename Types> |
1390 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1391 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1392 | SmallVectorImpl<Value> &result) { |
1393 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1394 | size_t typeSize = std::distance(types.begin(), types.end()); |
1395 | if (operandSize != typeSize) |
1396 | return emitError(loc) |
1397 | << operandSize << " operands present, but expected " << typeSize; |
1398 | |
1399 | for (auto it : llvm::zip(operands, types)) |
1400 | if (resolveOperand(std::get<0>(it), std::get<1>(it), result)) |
1401 | return failure(); |
1402 | return success(); |
1403 | } |
1404 | |
1405 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1406 | /// Operand values must come from single-result sources, and be valid |
1407 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1408 | virtual ParseResult |
1409 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1410 | Attribute &map, StringRef attrName, |
1411 | NamedAttrList &attrs, |
1412 | Delimiter delimiter = Delimiter::Square) = 0; |
1413 | |
1414 | /// Parses an affine expression where dims and symbols are SSA operands. |
1415 | /// Operand values must come from single-result sources, and be valid |
1416 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1417 | virtual ParseResult |
1418 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1419 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1420 | AffineExpr &expr) = 0; |
1421 | |
1422 | //===--------------------------------------------------------------------===// |
1423 | // Argument Parsing |
1424 | //===--------------------------------------------------------------------===// |
1425 | |
1426 | struct Argument { |
1427 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1428 | Type type; // Type. |
1429 | DictionaryAttr attrs; // Attributes if present. |
1430 | Optional<Location> sourceLoc; // Source location specifier if present. |
1431 | }; |
1432 | |
1433 | /// Parse a single argument with the following syntax: |
1434 | /// |
1435 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` |
1436 | /// |
1437 | /// If `allowType` is false or `allowAttrs` are false then the respective |
1438 | /// parts of the grammar are not parsed. |
1439 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, |
1440 | bool allowAttrs = false) = 0; |
1441 | |
1442 | /// Parse a single argument if present. |
1443 | virtual OptionalParseResult |
1444 | parseOptionalArgument(Argument &result, bool allowType = false, |
1445 | bool allowAttrs = false) = 0; |
1446 | |
1447 | /// Parse zero or more arguments with a specified surrounding delimiter. |
1448 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, |
1449 | Delimiter delimiter = Delimiter::None, |
1450 | bool allowType = false, |
1451 | bool allowAttrs = false) = 0; |
1452 | |
1453 | //===--------------------------------------------------------------------===// |
1454 | // Region Parsing |
1455 | //===--------------------------------------------------------------------===// |
1456 | |
1457 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1458 | /// moved to the op regions after the op is created. The first block of the |
1459 | /// region takes 'arguments'. |
1460 | /// |
1461 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1462 | /// shadow the names of other existing SSA values defined above the region |
1463 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1464 | /// to operations that are 'IsolatedFromAbove'. |
1465 | virtual ParseResult parseRegion(Region ®ion, |
1466 | ArrayRef<Argument> arguments = {}, |
1467 | bool enableNameShadowing = false) = 0; |
1468 | |
1469 | /// Parses a region if present. |
1470 | virtual OptionalParseResult |
1471 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, |
1472 | bool enableNameShadowing = false) = 0; |
1473 | |
1474 | /// Parses a region if present. If the region is present, a new region is |
1475 | /// allocated and placed in `region`. If no region is present or on failure, |
1476 | /// `region` remains untouched. |
1477 | virtual OptionalParseResult |
1478 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1479 | ArrayRef<Argument> arguments = {}, |
1480 | bool enableNameShadowing = false) = 0; |
1481 | |
1482 | //===--------------------------------------------------------------------===// |
1483 | // Successor Parsing |
1484 | //===--------------------------------------------------------------------===// |
1485 | |
1486 | /// Parse a single operation successor. |
1487 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1488 | |
1489 | /// Parse an optional operation successor. |
1490 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1491 | |
1492 | /// Parse a single operation successor and its operand list. |
1493 | virtual ParseResult |
1494 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1495 | |
1496 | //===--------------------------------------------------------------------===// |
1497 | // Type Parsing |
1498 | //===--------------------------------------------------------------------===// |
1499 | |
1500 | /// Parse a list of assignments of the form |
1501 | /// (%x1 = %y1, %x2 = %y2, ...) |
1502 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1503 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1504 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1505 | if (!result.has_value()) |
1506 | return emitError(getCurrentLocation(), "expected '('"); |
1507 | return result.value(); |
1508 | } |
1509 | |
1510 | virtual OptionalParseResult |
1511 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, |
1512 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1513 | }; |
1514 | |
1515 | //===--------------------------------------------------------------------===// |
1516 | // Dialect OpAsm interface. |
1517 | //===--------------------------------------------------------------------===// |
1518 | |
1519 | /// A functor used to set the name of the start of a result group of an |
1520 | /// operation. See 'getAsmResultNames' below for more details. |
1521 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1522 | |
1523 | /// A functor used to set the name of blocks in regions directly nested under |
1524 | /// an operation. |
1525 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1526 | |
1527 | class OpAsmDialectInterface |
1528 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1529 | public: |
1530 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1531 | |
1532 | //===------------------------------------------------------------------===// |
1533 | // Aliases |
1534 | //===------------------------------------------------------------------===// |
1535 | |
1536 | /// Holds the result of `getAlias` hook call. |
1537 | enum class AliasResult { |
1538 | /// The object (type or attribute) is not supported by the hook |
1539 | /// and an alias was not provided. |
1540 | NoAlias, |
1541 | /// An alias was provided, but it might be overriden by other hook. |
1542 | OverridableAlias, |
1543 | /// An alias was provided and it should be used |
1544 | /// (no other hooks will be checked). |
1545 | FinalAlias |
1546 | }; |
1547 | |
1548 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1549 | /// not necessarily a part of this dialect. The identifier is used in place of |
1550 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1551 | /// end with a numeric digit([0-9]+). |
1552 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1553 | return AliasResult::NoAlias; |
1554 | } |
1555 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1556 | return AliasResult::NoAlias; |
1557 | } |
1558 | |
1559 | //===--------------------------------------------------------------------===// |
1560 | // Resources |
1561 | //===--------------------------------------------------------------------===// |
1562 | |
1563 | /// Declare a resource with the given key, returning a handle to use for any |
1564 | /// references of this resource key within the IR during parsing. The result |
1565 | /// of `getResourceKey` on the returned handle is permitted to be different |
1566 | /// than `key`. |
1567 | virtual FailureOr<AsmDialectResourceHandle> |
1568 | declareResource(StringRef key) const { |
1569 | return failure(); |
1570 | } |
1571 | |
1572 | /// Return a key to use for the given resource. This key should uniquely |
1573 | /// identify this resource within the dialect. |
1574 | virtual std::string |
1575 | getResourceKey(const AsmDialectResourceHandle &handle) const { |
1576 | llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1577) |
1577 | "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1577); |
1578 | } |
1579 | |
1580 | /// Hook for parsing resource entries. Returns failure if the entry was not |
1581 | /// valid, or could otherwise not be processed correctly. Any necessary errors |
1582 | /// can be emitted via the provided entry. |
1583 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; |
1584 | |
1585 | /// Hook for building resources to use during printing. The given `op` may be |
1586 | /// inspected to help determine what information to include. |
1587 | /// `referencedResources` contains all of the resources detected when printing |
1588 | /// 'op'. |
1589 | virtual void |
1590 | buildResources(Operation *op, |
1591 | const SetVector<AsmDialectResourceHandle> &referencedResources, |
1592 | AsmResourceBuilder &builder) const {} |
1593 | }; |
1594 | } // namespace mlir |
1595 | |
1596 | //===--------------------------------------------------------------------===// |
1597 | // Operation OpAsm interface. |
1598 | //===--------------------------------------------------------------------===// |
1599 | |
1600 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1601 | #include "mlir/IR/OpAsmInterface.h.inc" |
1602 | |
1603 | namespace llvm { |
1604 | template <> |
1605 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { |
1606 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { |
1607 | return {DenseMapInfo<void *>::getEmptyKey(), |
1608 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; |
1609 | } |
1610 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { |
1611 | return {DenseMapInfo<void *>::getTombstoneKey(), |
1612 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; |
1613 | } |
1614 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { |
1615 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); |
1616 | } |
1617 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, |
1618 | const mlir::AsmDialectResourceHandle &rhs) { |
1619 | return lhs.getResource() == rhs.getResource(); |
1620 | } |
1621 | }; |
1622 | } // namespace llvm |
1623 | |
1624 | #endif |