File: | build/source/mlir/lib/IR/BuiltinAttributes.cpp |
Warning: | line 819, column 9 1st function call argument is an uninitialized value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | #include "mlir/IR/BuiltinAttributes.h" | |||
10 | #include "AttributeDetail.h" | |||
11 | #include "mlir/IR/AffineMap.h" | |||
12 | #include "mlir/IR/BuiltinDialect.h" | |||
13 | #include "mlir/IR/Dialect.h" | |||
14 | #include "mlir/IR/DialectResourceBlobManager.h" | |||
15 | #include "mlir/IR/IntegerSet.h" | |||
16 | #include "mlir/IR/OpImplementation.h" | |||
17 | #include "mlir/IR/Operation.h" | |||
18 | #include "mlir/IR/SymbolTable.h" | |||
19 | #include "mlir/IR/Types.h" | |||
20 | #include "llvm/ADT/APSInt.h" | |||
21 | #include "llvm/ADT/Sequence.h" | |||
22 | #include "llvm/ADT/TypeSwitch.h" | |||
23 | #include "llvm/Support/Endian.h" | |||
24 | ||||
25 | using namespace mlir; | |||
26 | using namespace mlir::detail; | |||
27 | ||||
28 | //===----------------------------------------------------------------------===// | |||
29 | /// Tablegen Attribute Definitions | |||
30 | //===----------------------------------------------------------------------===// | |||
31 | ||||
32 | #define GET_ATTRDEF_CLASSES | |||
33 | #include "mlir/IR/BuiltinAttributes.cpp.inc" | |||
34 | ||||
35 | //===----------------------------------------------------------------------===// | |||
36 | // BuiltinDialect | |||
37 | //===----------------------------------------------------------------------===// | |||
38 | ||||
39 | void BuiltinDialect::registerAttributes() { | |||
40 | addAttributes< | |||
41 | #define GET_ATTRDEF_LIST | |||
42 | #include "mlir/IR/BuiltinAttributes.cpp.inc" | |||
43 | >(); | |||
44 | } | |||
45 | ||||
46 | //===----------------------------------------------------------------------===// | |||
47 | // DictionaryAttr | |||
48 | //===----------------------------------------------------------------------===// | |||
49 | ||||
50 | /// Helper function that does either an in place sort or sorts from source array | |||
51 | /// into destination. If inPlace then storage is both the source and the | |||
52 | /// destination, else value is the source and storage destination. Returns | |||
53 | /// whether source was sorted. | |||
54 | template <bool inPlace> | |||
55 | static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value, | |||
56 | SmallVectorImpl<NamedAttribute> &storage) { | |||
57 | // Specialize for the common case. | |||
58 | switch (value.size()) { | |||
59 | case 0: | |||
60 | // Zero already sorted. | |||
61 | if (!inPlace) | |||
62 | storage.clear(); | |||
63 | break; | |||
64 | case 1: | |||
65 | // One already sorted but may need to be copied. | |||
66 | if (!inPlace) | |||
67 | storage.assign({value[0]}); | |||
68 | break; | |||
69 | case 2: { | |||
70 | bool isSorted = value[0] < value[1]; | |||
71 | if (inPlace) { | |||
72 | if (!isSorted) | |||
73 | std::swap(storage[0], storage[1]); | |||
74 | } else if (isSorted) { | |||
75 | storage.assign({value[0], value[1]}); | |||
76 | } else { | |||
77 | storage.assign({value[1], value[0]}); | |||
78 | } | |||
79 | return !isSorted; | |||
80 | } | |||
81 | default: | |||
82 | if (!inPlace) | |||
83 | storage.assign(value.begin(), value.end()); | |||
84 | // Check to see they are sorted already. | |||
85 | bool isSorted = llvm::is_sorted(value); | |||
86 | // If not, do a general sort. | |||
87 | if (!isSorted) | |||
88 | llvm::array_pod_sort(storage.begin(), storage.end()); | |||
89 | return !isSorted; | |||
90 | } | |||
91 | return false; | |||
92 | } | |||
93 | ||||
94 | /// Returns an entry with a duplicate name from the given sorted array of named | |||
95 | /// attributes. Returns std::nullopt if all elements have unique names. | |||
96 | static Optional<NamedAttribute> | |||
97 | findDuplicateElement(ArrayRef<NamedAttribute> value) { | |||
98 | const Optional<NamedAttribute> none{std::nullopt}; | |||
99 | if (value.size() < 2) | |||
100 | return none; | |||
101 | ||||
102 | if (value.size() == 2) | |||
103 | return value[0].getName() == value[1].getName() ? value[0] : none; | |||
104 | ||||
105 | const auto *it = std::adjacent_find(value.begin(), value.end(), | |||
106 | [](NamedAttribute l, NamedAttribute r) { | |||
107 | return l.getName() == r.getName(); | |||
108 | }); | |||
109 | return it != value.end() ? *it : none; | |||
110 | } | |||
111 | ||||
112 | bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value, | |||
113 | SmallVectorImpl<NamedAttribute> &storage) { | |||
114 | bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage); | |||
115 | assert(!findDuplicateElement(storage) &&(static_cast <bool> (!findDuplicateElement(storage) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 116, __extension__ __PRETTY_FUNCTION__ )) | |||
116 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(storage) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 116, __extension__ __PRETTY_FUNCTION__ )); | |||
117 | return isSorted; | |||
118 | } | |||
119 | ||||
120 | bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) { | |||
121 | bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array); | |||
122 | assert(!findDuplicateElement(array) &&(static_cast <bool> (!findDuplicateElement(array) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 123, __extension__ __PRETTY_FUNCTION__ )) | |||
123 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(array) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 123, __extension__ __PRETTY_FUNCTION__ )); | |||
124 | return isSorted; | |||
125 | } | |||
126 | ||||
127 | Optional<NamedAttribute> | |||
128 | DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array, | |||
129 | bool isSorted) { | |||
130 | if (!isSorted) | |||
131 | dictionaryAttrSort</*inPlace=*/true>(array, array); | |||
132 | return findDuplicateElement(array); | |||
133 | } | |||
134 | ||||
135 | DictionaryAttr DictionaryAttr::get(MLIRContext *context, | |||
136 | ArrayRef<NamedAttribute> value) { | |||
137 | if (value.empty()) | |||
138 | return DictionaryAttr::getEmpty(context); | |||
139 | ||||
140 | // We need to sort the element list to canonicalize it. | |||
141 | SmallVector<NamedAttribute, 8> storage; | |||
142 | if (dictionaryAttrSort</*inPlace=*/false>(value, storage)) | |||
143 | value = storage; | |||
144 | assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 145, __extension__ __PRETTY_FUNCTION__ )) | |||
145 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 145, __extension__ __PRETTY_FUNCTION__ )); | |||
146 | return Base::get(context, value); | |||
147 | } | |||
148 | /// Construct a dictionary with an array of values that is known to already be | |||
149 | /// sorted by name and uniqued. | |||
150 | DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context, | |||
151 | ArrayRef<NamedAttribute> value) { | |||
152 | if (value.empty()) | |||
153 | return DictionaryAttr::getEmpty(context); | |||
154 | // Ensure that the attribute elements are unique and sorted. | |||
155 | assert(llvm::is_sorted((static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__ )) | |||
156 | value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__ )) | |||
157 | "expected attribute values to be sorted")(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted" ) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__ )); | |||
158 | assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 159, __extension__ __PRETTY_FUNCTION__ )) | |||
159 | "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) && "DictionaryAttr element names must be unique") ? void (0) : __assert_fail ("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 159, __extension__ __PRETTY_FUNCTION__ )); | |||
160 | return Base::get(context, value); | |||
161 | } | |||
162 | ||||
163 | /// Return the specified attribute if present, null otherwise. | |||
164 | Attribute DictionaryAttr::get(StringRef name) const { | |||
165 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
166 | return it.second ? it.first->getValue() : Attribute(); | |||
167 | } | |||
168 | Attribute DictionaryAttr::get(StringAttr name) const { | |||
169 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
170 | return it.second ? it.first->getValue() : Attribute(); | |||
171 | } | |||
172 | ||||
173 | /// Return the specified named attribute if present, std::nullopt otherwise. | |||
174 | Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const { | |||
175 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
176 | return it.second ? *it.first : Optional<NamedAttribute>(); | |||
177 | } | |||
178 | Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const { | |||
179 | auto it = impl::findAttrSorted(begin(), end(), name); | |||
180 | return it.second ? *it.first : Optional<NamedAttribute>(); | |||
181 | } | |||
182 | ||||
183 | /// Return whether the specified attribute is present. | |||
184 | bool DictionaryAttr::contains(StringRef name) const { | |||
185 | return impl::findAttrSorted(begin(), end(), name).second; | |||
186 | } | |||
187 | bool DictionaryAttr::contains(StringAttr name) const { | |||
188 | return impl::findAttrSorted(begin(), end(), name).second; | |||
189 | } | |||
190 | ||||
191 | DictionaryAttr::iterator DictionaryAttr::begin() const { | |||
192 | return getValue().begin(); | |||
193 | } | |||
194 | DictionaryAttr::iterator DictionaryAttr::end() const { | |||
195 | return getValue().end(); | |||
196 | } | |||
197 | size_t DictionaryAttr::size() const { return getValue().size(); } | |||
198 | ||||
199 | DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) { | |||
200 | return Base::get(context, ArrayRef<NamedAttribute>()); | |||
201 | } | |||
202 | ||||
203 | //===----------------------------------------------------------------------===// | |||
204 | // StridedLayoutAttr | |||
205 | //===----------------------------------------------------------------------===// | |||
206 | ||||
207 | /// Prints a strided layout attribute. | |||
208 | void StridedLayoutAttr::print(llvm::raw_ostream &os) const { | |||
209 | auto printIntOrQuestion = [&](int64_t value) { | |||
210 | if (ShapedType::isDynamic(value)) | |||
211 | os << "?"; | |||
212 | else | |||
213 | os << value; | |||
214 | }; | |||
215 | ||||
216 | os << "strided<["; | |||
217 | llvm::interleaveComma(getStrides(), os, printIntOrQuestion); | |||
218 | os << "]"; | |||
219 | ||||
220 | if (getOffset() != 0) { | |||
221 | os << ", offset: "; | |||
222 | printIntOrQuestion(getOffset()); | |||
223 | } | |||
224 | os << ">"; | |||
225 | } | |||
226 | ||||
227 | /// Returns the strided layout as an affine map. | |||
228 | AffineMap StridedLayoutAttr::getAffineMap() const { | |||
229 | return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext()); | |||
230 | } | |||
231 | ||||
232 | /// Checks that the type-agnostic strided layout invariants are satisfied. | |||
233 | LogicalResult | |||
234 | StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
235 | int64_t offset, ArrayRef<int64_t> strides) { | |||
236 | if (llvm::any_of(strides, [&](int64_t stride) { return stride == 0; })) | |||
237 | return emitError() << "strides must not be zero"; | |||
238 | ||||
239 | return success(); | |||
240 | } | |||
241 | ||||
242 | /// Checks that the type-specific strided layout invariants are satisfied. | |||
243 | LogicalResult StridedLayoutAttr::verifyLayout( | |||
244 | ArrayRef<int64_t> shape, | |||
245 | function_ref<InFlightDiagnostic()> emitError) const { | |||
246 | if (shape.size() != getStrides().size()) | |||
247 | return emitError() << "expected the number of strides to match the rank"; | |||
248 | ||||
249 | return success(); | |||
250 | } | |||
251 | ||||
252 | //===----------------------------------------------------------------------===// | |||
253 | // StringAttr | |||
254 | //===----------------------------------------------------------------------===// | |||
255 | ||||
256 | StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) { | |||
257 | return Base::get(context, "", NoneType::get(context)); | |||
258 | } | |||
259 | ||||
260 | /// Twine support for StringAttr. | |||
261 | StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) { | |||
262 | // Fast-path empty twine. | |||
263 | if (twine.isTriviallyEmpty()) | |||
264 | return get(context); | |||
265 | SmallVector<char, 32> tempStr; | |||
266 | return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context)); | |||
267 | } | |||
268 | ||||
269 | /// Twine support for StringAttr. | |||
270 | StringAttr StringAttr::get(const Twine &twine, Type type) { | |||
271 | SmallVector<char, 32> tempStr; | |||
272 | return Base::get(type.getContext(), twine.toStringRef(tempStr), type); | |||
273 | } | |||
274 | ||||
275 | StringRef StringAttr::getValue() const { return getImpl()->value; } | |||
276 | ||||
277 | Type StringAttr::getType() const { return getImpl()->type; } | |||
278 | ||||
279 | Dialect *StringAttr::getReferencedDialect() const { | |||
280 | return getImpl()->referencedDialect; | |||
281 | } | |||
282 | ||||
283 | //===----------------------------------------------------------------------===// | |||
284 | // FloatAttr | |||
285 | //===----------------------------------------------------------------------===// | |||
286 | ||||
287 | double FloatAttr::getValueAsDouble() const { | |||
288 | return getValueAsDouble(getValue()); | |||
289 | } | |||
290 | double FloatAttr::getValueAsDouble(APFloat value) { | |||
291 | if (&value.getSemantics() != &APFloat::IEEEdouble()) { | |||
292 | bool losesInfo = false; | |||
293 | value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven, | |||
294 | &losesInfo); | |||
295 | } | |||
296 | return value.convertToDouble(); | |||
297 | } | |||
298 | ||||
299 | LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
300 | Type type, APFloat value) { | |||
301 | // Verify that the type is correct. | |||
302 | if (!type.isa<FloatType>()) | |||
303 | return emitError() << "expected floating point type"; | |||
304 | ||||
305 | // Verify that the type semantics match that of the value. | |||
306 | if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) { | |||
307 | return emitError() | |||
308 | << "FloatAttr type doesn't match the type implied by its value"; | |||
309 | } | |||
310 | return success(); | |||
311 | } | |||
312 | ||||
313 | //===----------------------------------------------------------------------===// | |||
314 | // SymbolRefAttr | |||
315 | //===----------------------------------------------------------------------===// | |||
316 | ||||
317 | SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value, | |||
318 | ArrayRef<FlatSymbolRefAttr> nestedRefs) { | |||
319 | return get(StringAttr::get(ctx, value), nestedRefs); | |||
320 | } | |||
321 | ||||
322 | FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) { | |||
323 | return get(ctx, value, {}).cast<FlatSymbolRefAttr>(); | |||
324 | } | |||
325 | ||||
326 | FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) { | |||
327 | return get(value, {}).cast<FlatSymbolRefAttr>(); | |||
328 | } | |||
329 | ||||
330 | FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) { | |||
331 | auto symName = | |||
332 | symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName()); | |||
333 | assert(symName && "value does not have a valid symbol name")(static_cast <bool> (symName && "value does not have a valid symbol name" ) ? void (0) : __assert_fail ("symName && \"value does not have a valid symbol name\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 333, __extension__ __PRETTY_FUNCTION__ )); | |||
334 | return SymbolRefAttr::get(symName); | |||
335 | } | |||
336 | ||||
337 | StringAttr SymbolRefAttr::getLeafReference() const { | |||
338 | ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences(); | |||
339 | return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr(); | |||
340 | } | |||
341 | ||||
342 | //===----------------------------------------------------------------------===// | |||
343 | // IntegerAttr | |||
344 | //===----------------------------------------------------------------------===// | |||
345 | ||||
346 | int64_t IntegerAttr::getInt() const { | |||
347 | assert((getType().isIndex() || getType().isSignlessInteger()) &&(static_cast <bool> ((getType().isIndex() || getType(). isSignlessInteger()) && "must be signless integer") ? void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 348, __extension__ __PRETTY_FUNCTION__ )) | |||
348 | "must be signless integer")(static_cast <bool> ((getType().isIndex() || getType(). isSignlessInteger()) && "must be signless integer") ? void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 348, __extension__ __PRETTY_FUNCTION__ )); | |||
349 | return getValue().getSExtValue(); | |||
350 | } | |||
351 | ||||
352 | int64_t IntegerAttr::getSInt() const { | |||
353 | assert(getType().isSignedInteger() && "must be signed integer")(static_cast <bool> (getType().isSignedInteger() && "must be signed integer") ? void (0) : __assert_fail ("getType().isSignedInteger() && \"must be signed integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__ )); | |||
354 | return getValue().getSExtValue(); | |||
355 | } | |||
356 | ||||
357 | uint64_t IntegerAttr::getUInt() const { | |||
358 | assert(getType().isUnsignedInteger() && "must be unsigned integer")(static_cast <bool> (getType().isUnsignedInteger() && "must be unsigned integer") ? void (0) : __assert_fail ("getType().isUnsignedInteger() && \"must be unsigned integer\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 358, __extension__ __PRETTY_FUNCTION__ )); | |||
359 | return getValue().getZExtValue(); | |||
360 | } | |||
361 | ||||
362 | /// Return the value as an APSInt which carries the signed from the type of | |||
363 | /// the attribute. This traps on signless integers types! | |||
364 | APSInt IntegerAttr::getAPSInt() const { | |||
365 | assert(!getType().isSignlessInteger() &&(static_cast <bool> (!getType().isSignlessInteger() && "Signless integers don't carry a sign for APSInt") ? void (0 ) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 366, __extension__ __PRETTY_FUNCTION__ )) | |||
366 | "Signless integers don't carry a sign for APSInt")(static_cast <bool> (!getType().isSignlessInteger() && "Signless integers don't carry a sign for APSInt") ? void (0 ) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 366, __extension__ __PRETTY_FUNCTION__ )); | |||
367 | return APSInt(getValue(), getType().isUnsignedInteger()); | |||
368 | } | |||
369 | ||||
370 | LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
371 | Type type, APInt value) { | |||
372 | if (IntegerType integerType = type.dyn_cast<IntegerType>()) { | |||
373 | if (integerType.getWidth() != value.getBitWidth()) | |||
374 | return emitError() << "integer type bit width (" << integerType.getWidth() | |||
375 | << ") doesn't match value bit width (" | |||
376 | << value.getBitWidth() << ")"; | |||
377 | return success(); | |||
378 | } | |||
379 | if (type.isa<IndexType>()) { | |||
380 | if (value.getBitWidth() != IndexType::kInternalStorageBitWidth) | |||
381 | return emitError() | |||
382 | << "value bit width (" << value.getBitWidth() | |||
383 | << ") doesn't match index type internal storage bit width (" | |||
384 | << IndexType::kInternalStorageBitWidth << ")"; | |||
385 | return success(); | |||
386 | } | |||
387 | return emitError() << "expected integer or index type"; | |||
388 | } | |||
389 | ||||
390 | BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) { | |||
391 | auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value)); | |||
392 | return attr.cast<BoolAttr>(); | |||
393 | } | |||
394 | ||||
395 | //===----------------------------------------------------------------------===// | |||
396 | // BoolAttr | |||
397 | //===----------------------------------------------------------------------===// | |||
398 | ||||
399 | bool BoolAttr::getValue() const { | |||
400 | auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl); | |||
401 | return storage->value.getBoolValue(); | |||
402 | } | |||
403 | ||||
404 | bool BoolAttr::classof(Attribute attr) { | |||
405 | IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>(); | |||
406 | return intAttr && intAttr.getType().isSignlessInteger(1); | |||
407 | } | |||
408 | ||||
409 | //===----------------------------------------------------------------------===// | |||
410 | // OpaqueAttr | |||
411 | //===----------------------------------------------------------------------===// | |||
412 | ||||
413 | LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
414 | StringAttr dialect, StringRef attrData, | |||
415 | Type type) { | |||
416 | if (!Dialect::isValidNamespace(dialect.strref())) | |||
417 | return emitError() << "invalid dialect namespace '" << dialect << "'"; | |||
418 | ||||
419 | // Check that the dialect is actually registered. | |||
420 | MLIRContext *context = dialect.getContext(); | |||
421 | if (!context->allowsUnregisteredDialects() && | |||
422 | !context->getLoadedDialect(dialect.strref())) { | |||
423 | return emitError() | |||
424 | << "#" << dialect << "<\"" << attrData << "\"> : " << type | |||
425 | << " attribute created with unregistered dialect. If this is " | |||
426 | "intended, please call allowUnregisteredDialects() on the " | |||
427 | "MLIRContext, or use -allow-unregistered-dialect with " | |||
428 | "the MLIR opt tool used"; | |||
429 | } | |||
430 | ||||
431 | return success(); | |||
432 | } | |||
433 | ||||
434 | //===----------------------------------------------------------------------===// | |||
435 | // DenseElementsAttr Utilities | |||
436 | //===----------------------------------------------------------------------===// | |||
437 | ||||
438 | /// Get the bitwidth of a dense element type within the buffer. | |||
439 | /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. | |||
440 | static size_t getDenseElementStorageWidth(size_t origWidth) { | |||
441 | return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); | |||
442 | } | |||
443 | static size_t getDenseElementStorageWidth(Type elementType) { | |||
444 | return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); | |||
445 | } | |||
446 | ||||
447 | /// Set a bit to a specific value. | |||
448 | static void setBit(char *rawData, size_t bitPos, bool value) { | |||
449 | if (value) | |||
450 | rawData[bitPos / CHAR_BIT8] |= (1 << (bitPos % CHAR_BIT8)); | |||
451 | else | |||
452 | rawData[bitPos / CHAR_BIT8] &= ~(1 << (bitPos % CHAR_BIT8)); | |||
453 | } | |||
454 | ||||
455 | /// Return the value of the specified bit. | |||
456 | static bool getBit(const char *rawData, size_t bitPos) { | |||
457 | return (rawData[bitPos / CHAR_BIT8] & (1 << (bitPos % CHAR_BIT8))) != 0; | |||
458 | } | |||
459 | ||||
460 | /// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for | |||
461 | /// BE format. | |||
462 | static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes, | |||
463 | char *result) { | |||
464 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 465, __extension__ __PRETTY_FUNCTION__ )) | |||
465 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 465, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
466 | assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes) ? void (0) : __assert_fail ("value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes" , "mlir/lib/IR/BuiltinAttributes.cpp", 466, __extension__ __PRETTY_FUNCTION__ )); | |||
467 | ||||
468 | // Copy the words filled with data. | |||
469 | // For example, when `value` has 2 words, the first word is filled with data. | |||
470 | // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--| | |||
471 | size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE; | |||
472 | std::copy_n(reinterpret_cast<const char *>(value.getRawData()), | |||
473 | numFilledWords, result); | |||
474 | // Convert last word of APInt to LE format and store it in char | |||
475 | // array(`valueLE`). | |||
476 | // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------| | |||
477 | size_t lastWordPos = numFilledWords; | |||
478 | SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE); | |||
479 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
480 | reinterpret_cast<const char *>(value.getRawData()) + lastWordPos, | |||
481 | valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1); | |||
482 | // Extract actual APInt data from `valueLE`, convert endianness to BE format, | |||
483 | // and store it in `result`. | |||
484 | // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij| | |||
485 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
486 | valueLE.begin(), result + lastWordPos, | |||
487 | (numBytes - lastWordPos) * CHAR_BIT8, 1); | |||
488 | } | |||
489 | ||||
490 | /// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE | |||
491 | /// format. | |||
492 | static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, | |||
493 | APInt &result) { | |||
494 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 495, __extension__ __PRETTY_FUNCTION__ )) | |||
495 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 495, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
496 | assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes) ? void (0) : __assert_fail ("result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes" , "mlir/lib/IR/BuiltinAttributes.cpp", 496, __extension__ __PRETTY_FUNCTION__ )); | |||
497 | ||||
498 | // Copy the data that fills the word of `result` from `inArray`. | |||
499 | // For example, when `result` has 2 words, the first word will be filled with | |||
500 | // data. So, the first 8 bytes are copied from `inArray` here. | |||
501 | // `inArray` (10 bytes, BE): |abcdefgh|ij| | |||
502 | // ==> `result` (2 words, BE): |abcdefgh|--------| | |||
503 | size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE; | |||
504 | std::copy_n( | |||
505 | inArray, numFilledWords, | |||
506 | const_cast<char *>(reinterpret_cast<const char *>(result.getRawData()))); | |||
507 | ||||
508 | // Convert array data which will be last word of `result` to LE format, and | |||
509 | // store it in char array(`inArrayLE`). | |||
510 | // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------| | |||
511 | size_t lastWordPos = numFilledWords; | |||
512 | SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE); | |||
513 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
514 | inArray + lastWordPos, inArrayLE.begin(), | |||
515 | (numBytes - lastWordPos) * CHAR_BIT8, 1); | |||
516 | ||||
517 | // Convert `inArrayLE` to BE format, and store it in last word of `result`. | |||
518 | // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij| | |||
519 | DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
520 | inArrayLE.begin(), | |||
521 | const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) + | |||
522 | lastWordPos, | |||
523 | APInt::APINT_BITS_PER_WORD, 1); | |||
524 | } | |||
525 | ||||
526 | /// Writes value to the bit position `bitPos` in array `rawData`. | |||
527 | static void writeBits(char *rawData, size_t bitPos, APInt value) { | |||
528 | size_t bitWidth = value.getBitWidth(); | |||
529 | ||||
530 | // If the bitwidth is 1 we just toggle the specific bit. | |||
531 | if (bitWidth == 1) | |||
532 | return setBit(rawData, bitPos, value.isOneValue()); | |||
533 | ||||
534 | // Otherwise, the bit position is guaranteed to be byte aligned. | |||
535 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned" ) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 535, __extension__ __PRETTY_FUNCTION__ )); | |||
536 | if (llvm::support::endian::system_endianness() == | |||
537 | llvm::support::endianness::big) { | |||
538 | // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`. | |||
539 | // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't | |||
540 | // work correctly in BE format. | |||
541 | // ex. `value` (2 words including 10 bytes) | |||
542 | // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| | |||
543 | copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
544 | rawData + (bitPos / CHAR_BIT8)); | |||
545 | } else { | |||
546 | std::copy_n(reinterpret_cast<const char *>(value.getRawData()), | |||
547 | llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
548 | rawData + (bitPos / CHAR_BIT8)); | |||
549 | } | |||
550 | } | |||
551 | ||||
552 | /// Reads the next `bitWidth` bits from the bit position `bitPos` in array | |||
553 | /// `rawData`. | |||
554 | static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { | |||
555 | // Handle a boolean bit position. | |||
556 | if (bitWidth == 1) | |||
557 | return APInt(1, getBit(rawData, bitPos) ? 1 : 0); | |||
558 | ||||
559 | // Otherwise, the bit position must be 8-bit aligned. | |||
560 | assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned" ) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 560, __extension__ __PRETTY_FUNCTION__ )); | |||
561 | APInt result(bitWidth, 0); | |||
562 | if (llvm::support::endian::system_endianness() == | |||
563 | llvm::support::endianness::big) { | |||
564 | // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`. | |||
565 | // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't | |||
566 | // work correctly in BE format. | |||
567 | // ex. `result` (2 words including 10 bytes) | |||
568 | // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function | |||
569 | copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT8), | |||
570 | llvm::divideCeil(bitWidth, CHAR_BIT8), result); | |||
571 | } else { | |||
572 | std::copy_n(rawData + (bitPos / CHAR_BIT8), | |||
573 | llvm::divideCeil(bitWidth, CHAR_BIT8), | |||
574 | const_cast<char *>( | |||
575 | reinterpret_cast<const char *>(result.getRawData()))); | |||
576 | } | |||
577 | return result; | |||
578 | } | |||
579 | ||||
580 | /// Returns true if 'values' corresponds to a splat, i.e. one element, or has | |||
581 | /// the same element count as 'type'. | |||
582 | template <typename Values> | |||
583 | static bool hasSameElementsOrSplat(ShapedType type, const Values &values) { | |||
584 | return (values.size() == 1) || | |||
585 | (type.getNumElements() == static_cast<int64_t>(values.size())); | |||
586 | } | |||
587 | ||||
588 | //===----------------------------------------------------------------------===// | |||
589 | // DenseElementsAttr Iterators | |||
590 | //===----------------------------------------------------------------------===// | |||
591 | ||||
592 | //===----------------------------------------------------------------------===// | |||
593 | // AttributeElementIterator | |||
594 | ||||
595 | DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( | |||
596 | DenseElementsAttr attr, size_t index) | |||
597 | : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *, | |||
598 | Attribute, Attribute, Attribute>( | |||
599 | attr.getAsOpaquePointer(), index) {} | |||
600 | ||||
601 | Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { | |||
602 | auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>(); | |||
603 | Type eltTy = owner.getElementType(); | |||
604 | if (auto intEltTy = eltTy.dyn_cast<IntegerType>()) | |||
605 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); | |||
606 | if (eltTy.isa<IndexType>()) | |||
607 | return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); | |||
608 | if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) { | |||
609 | IntElementIterator intIt(owner, index); | |||
610 | FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); | |||
611 | return FloatAttr::get(eltTy, *floatIt); | |||
612 | } | |||
613 | if (auto complexTy = eltTy.dyn_cast<ComplexType>()) { | |||
614 | auto complexEltTy = complexTy.getElementType(); | |||
615 | ComplexIntElementIterator complexIntIt(owner, index); | |||
616 | if (complexEltTy.isa<IntegerType>()) { | |||
617 | auto value = *complexIntIt; | |||
618 | auto real = IntegerAttr::get(complexEltTy, value.real()); | |||
619 | auto imag = IntegerAttr::get(complexEltTy, value.imag()); | |||
620 | return ArrayAttr::get(complexTy.getContext(), | |||
621 | ArrayRef<Attribute>{real, imag}); | |||
622 | } | |||
623 | ||||
624 | ComplexFloatElementIterator complexFloatIt( | |||
625 | complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt); | |||
626 | auto value = *complexFloatIt; | |||
627 | auto real = FloatAttr::get(complexEltTy, value.real()); | |||
628 | auto imag = FloatAttr::get(complexEltTy, value.imag()); | |||
629 | return ArrayAttr::get(complexTy.getContext(), | |||
630 | ArrayRef<Attribute>{real, imag}); | |||
631 | } | |||
632 | if (owner.isa<DenseStringElementsAttr>()) { | |||
633 | ArrayRef<StringRef> vals = owner.getRawStringData(); | |||
634 | return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); | |||
635 | } | |||
636 | llvm_unreachable("unexpected element type")::llvm::llvm_unreachable_internal("unexpected element type", "mlir/lib/IR/BuiltinAttributes.cpp" , 636); | |||
637 | } | |||
638 | ||||
639 | //===----------------------------------------------------------------------===// | |||
640 | // BoolElementIterator | |||
641 | ||||
642 | DenseElementsAttr::BoolElementIterator::BoolElementIterator( | |||
643 | DenseElementsAttr attr, size_t dataIndex) | |||
644 | : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>( | |||
645 | attr.getRawData().data(), attr.isSplat(), dataIndex) {} | |||
646 | ||||
647 | bool DenseElementsAttr::BoolElementIterator::operator*() const { | |||
648 | return getBit(getData(), getDataIndex()); | |||
649 | } | |||
650 | ||||
651 | //===----------------------------------------------------------------------===// | |||
652 | // IntElementIterator | |||
653 | ||||
654 | DenseElementsAttr::IntElementIterator::IntElementIterator( | |||
655 | DenseElementsAttr attr, size_t dataIndex) | |||
656 | : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>( | |||
657 | attr.getRawData().data(), attr.isSplat(), dataIndex), | |||
658 | bitWidth(getDenseElementBitWidth(attr.getElementType())) {} | |||
659 | ||||
660 | APInt DenseElementsAttr::IntElementIterator::operator*() const { | |||
661 | return readBits(getData(), | |||
662 | getDataIndex() * getDenseElementStorageWidth(bitWidth), | |||
663 | bitWidth); | |||
664 | } | |||
665 | ||||
666 | //===----------------------------------------------------------------------===// | |||
667 | // ComplexIntElementIterator | |||
668 | ||||
669 | DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator( | |||
670 | DenseElementsAttr attr, size_t dataIndex) | |||
671 | : DenseElementIndexedIteratorImpl<ComplexIntElementIterator, | |||
672 | std::complex<APInt>, std::complex<APInt>, | |||
673 | std::complex<APInt>>( | |||
674 | attr.getRawData().data(), attr.isSplat(), dataIndex) { | |||
675 | auto complexType = attr.getElementType().cast<ComplexType>(); | |||
676 | bitWidth = getDenseElementBitWidth(complexType.getElementType()); | |||
677 | } | |||
678 | ||||
679 | std::complex<APInt> | |||
680 | DenseElementsAttr::ComplexIntElementIterator::operator*() const { | |||
681 | size_t storageWidth = getDenseElementStorageWidth(bitWidth); | |||
682 | size_t offset = getDataIndex() * storageWidth * 2; | |||
683 | return {readBits(getData(), offset, bitWidth), | |||
684 | readBits(getData(), offset + storageWidth, bitWidth)}; | |||
685 | } | |||
686 | ||||
687 | //===----------------------------------------------------------------------===// | |||
688 | // DenseArrayAttr | |||
689 | //===----------------------------------------------------------------------===// | |||
690 | ||||
691 | LogicalResult | |||
692 | DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
693 | Type elementType, int64_t size, ArrayRef<char> rawData) { | |||
694 | if (!elementType.isIntOrIndexOrFloat()) | |||
695 | return emitError() << "expected integer or floating point element type"; | |||
696 | int64_t dataSize = rawData.size(); | |||
697 | int64_t elementSize = | |||
698 | llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT8); | |||
699 | if (size * elementSize != dataSize) { | |||
700 | return emitError() << "expected data size (" << size << " elements, " | |||
701 | << elementSize | |||
702 | << " bytes each) does not match: " << dataSize | |||
703 | << " bytes"; | |||
704 | } | |||
705 | return success(); | |||
706 | } | |||
707 | ||||
708 | namespace { | |||
709 | /// Instantiations of this class provide utilities for interacting with native | |||
710 | /// data types in the context of DenseArrayAttr. | |||
711 | template <size_t width, | |||
712 | IntegerType::SignednessSemantics signedness = IntegerType::Signless> | |||
713 | struct DenseArrayAttrIntUtil { | |||
714 | static bool checkElementType(Type eltType) { | |||
715 | auto type = eltType.dyn_cast<IntegerType>(); | |||
716 | if (!type || type.getWidth() != width) | |||
717 | return false; | |||
718 | return type.getSignedness() == signedness; | |||
719 | } | |||
720 | ||||
721 | static Type getElementType(MLIRContext *ctx) { | |||
722 | return IntegerType::get(ctx, width, signedness); | |||
723 | } | |||
724 | ||||
725 | template <typename T> | |||
726 | static void printElement(raw_ostream &os, T value) { | |||
727 | os << value; | |||
728 | } | |||
729 | ||||
730 | template <typename T> | |||
731 | static ParseResult parseElement(AsmParser &parser, T &value) { | |||
732 | return parser.parseInteger(value); | |||
733 | } | |||
734 | }; | |||
735 | template <typename T> | |||
736 | struct DenseArrayAttrUtil; | |||
737 | ||||
738 | /// Specialization for boolean elements to print 'true' and 'false' literals for | |||
739 | /// elements. | |||
740 | template <> | |||
741 | struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> { | |||
742 | static void printElement(raw_ostream &os, bool value) { | |||
743 | os << (value ? "true" : "false"); | |||
744 | } | |||
745 | }; | |||
746 | ||||
747 | /// Specialization for 8-bit integers to ensure values are printed as integers | |||
748 | /// and not characters. | |||
749 | template <> | |||
750 | struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> { | |||
751 | static void printElement(raw_ostream &os, int8_t value) { | |||
752 | os << static_cast<int>(value); | |||
753 | } | |||
754 | }; | |||
755 | template <> | |||
756 | struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {}; | |||
757 | template <> | |||
758 | struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {}; | |||
759 | template <> | |||
760 | struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {}; | |||
761 | ||||
762 | /// Specialization for 32-bit floats. | |||
763 | template <> | |||
764 | struct DenseArrayAttrUtil<float> { | |||
765 | static bool checkElementType(Type eltType) { return eltType.isF32(); } | |||
766 | static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); } | |||
767 | static void printElement(raw_ostream &os, float value) { os << value; } | |||
768 | ||||
769 | /// Parse a double and cast it to a float. | |||
770 | static ParseResult parseElement(AsmParser &parser, float &value) { | |||
771 | double doubleVal; | |||
772 | if (parser.parseFloat(doubleVal)) | |||
773 | return failure(); | |||
774 | value = doubleVal; | |||
775 | return success(); | |||
776 | } | |||
777 | }; | |||
778 | ||||
779 | /// Specialization for 64-bit floats. | |||
780 | template <> | |||
781 | struct DenseArrayAttrUtil<double> { | |||
782 | static bool checkElementType(Type eltType) { return eltType.isF64(); } | |||
783 | static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); } | |||
784 | static void printElement(raw_ostream &os, float value) { os << value; } | |||
785 | static ParseResult parseElement(AsmParser &parser, double &value) { | |||
786 | return parser.parseFloat(value); | |||
787 | } | |||
788 | }; | |||
789 | } // namespace | |||
790 | ||||
791 | template <typename T> | |||
792 | void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const { | |||
793 | print(printer.getStream()); | |||
794 | } | |||
795 | ||||
796 | template <typename T> | |||
797 | void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const { | |||
798 | llvm::interleaveComma(asArrayRef(), os, [&](T value) { | |||
799 | DenseArrayAttrUtil<T>::printElement(os, value); | |||
800 | }); | |||
801 | } | |||
802 | ||||
803 | template <typename T> | |||
804 | void DenseArrayAttrImpl<T>::print(raw_ostream &os) const { | |||
805 | os << "["; | |||
806 | printWithoutBraces(os); | |||
807 | os << "]"; | |||
808 | } | |||
809 | ||||
810 | /// Parse a DenseArrayAttr without the braces: `1, 2, 3` | |||
811 | template <typename T> | |||
812 | Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser, | |||
813 | Type odsType) { | |||
814 | SmallVector<T> data; | |||
815 | if (failed(parser.parseCommaSeparatedList([&]() { | |||
816 | T value; | |||
| ||||
817 | if (DenseArrayAttrUtil<T>::parseElement(parser, value)) | |||
818 | return failure(); | |||
819 | data.push_back(value); | |||
| ||||
820 | return success(); | |||
821 | }))) | |||
822 | return {}; | |||
823 | return get(parser.getContext(), data); | |||
824 | } | |||
825 | ||||
826 | /// Parse a DenseArrayAttr: `[ 1, 2, 3 ]` | |||
827 | template <typename T> | |||
828 | Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) { | |||
829 | if (parser.parseLSquare()) | |||
830 | return {}; | |||
831 | // Handle empty list case. | |||
832 | if (succeeded(parser.parseOptionalRSquare())) | |||
833 | return get(parser.getContext(), {}); | |||
834 | Attribute result = parseWithoutBraces(parser, odsType); | |||
835 | if (parser.parseRSquare()) | |||
836 | return {}; | |||
837 | return result; | |||
838 | } | |||
839 | ||||
840 | /// Conversion from DenseArrayAttr<T> to ArrayRef<T>. | |||
841 | template <typename T> | |||
842 | DenseArrayAttrImpl<T>::operator ArrayRef<T>() const { | |||
843 | ArrayRef<char> raw = getRawData(); | |||
844 | assert((raw.size() % sizeof(T)) == 0)(static_cast <bool> ((raw.size() % sizeof(T)) == 0) ? void (0) : __assert_fail ("(raw.size() % sizeof(T)) == 0", "mlir/lib/IR/BuiltinAttributes.cpp" , 844, __extension__ __PRETTY_FUNCTION__)); | |||
845 | return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()), | |||
846 | raw.size() / sizeof(T)); | |||
847 | } | |||
848 | ||||
849 | /// Builds a DenseArrayAttr<T> from an ArrayRef<T>. | |||
850 | template <typename T> | |||
851 | DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context, | |||
852 | ArrayRef<T> content) { | |||
853 | Type elementType = DenseArrayAttrUtil<T>::getElementType(context); | |||
854 | auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), | |||
855 | content.size() * sizeof(T)); | |||
856 | return llvm::cast<DenseArrayAttrImpl<T>>( | |||
857 | Base::get(context, elementType, content.size(), rawArray)); | |||
858 | } | |||
859 | ||||
860 | template <typename T> | |||
861 | bool DenseArrayAttrImpl<T>::classof(Attribute attr) { | |||
862 | if (auto denseArray = attr.dyn_cast<DenseArrayAttr>()) | |||
863 | return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType()); | |||
864 | return false; | |||
865 | } | |||
866 | ||||
867 | namespace mlir { | |||
868 | namespace detail { | |||
869 | // Explicit instantiation for all the supported DenseArrayAttr. | |||
870 | template class DenseArrayAttrImpl<bool>; | |||
871 | template class DenseArrayAttrImpl<int8_t>; | |||
872 | template class DenseArrayAttrImpl<int16_t>; | |||
873 | template class DenseArrayAttrImpl<int32_t>; | |||
874 | template class DenseArrayAttrImpl<int64_t>; | |||
875 | template class DenseArrayAttrImpl<float>; | |||
876 | template class DenseArrayAttrImpl<double>; | |||
877 | } // namespace detail | |||
878 | } // namespace mlir | |||
879 | ||||
880 | //===----------------------------------------------------------------------===// | |||
881 | // DenseElementsAttr | |||
882 | //===----------------------------------------------------------------------===// | |||
883 | ||||
884 | /// Method for support type inquiry through isa, cast and dyn_cast. | |||
885 | bool DenseElementsAttr::classof(Attribute attr) { | |||
886 | return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>(); | |||
887 | } | |||
888 | ||||
889 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
890 | ArrayRef<Attribute> values) { | |||
891 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 891, __extension__ __PRETTY_FUNCTION__ )); | |||
892 | ||||
893 | // If the element type is not based on int/float/index, assume it is a string | |||
894 | // type. | |||
895 | Type eltType = type.getElementType(); | |||
896 | if (!eltType.isIntOrIndexOrFloat()) { | |||
897 | SmallVector<StringRef, 8> stringValues; | |||
898 | stringValues.reserve(values.size()); | |||
899 | for (Attribute attr : values) { | |||
900 | assert(attr.isa<StringAttr>() &&(static_cast <bool> (attr.isa<StringAttr>() && "expected string value for non integer/index/float element") ? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 901, __extension__ __PRETTY_FUNCTION__ )) | |||
901 | "expected string value for non integer/index/float element")(static_cast <bool> (attr.isa<StringAttr>() && "expected string value for non integer/index/float element") ? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 901, __extension__ __PRETTY_FUNCTION__ )); | |||
902 | stringValues.push_back(attr.cast<StringAttr>().getValue()); | |||
903 | } | |||
904 | return get(type, stringValues); | |||
905 | } | |||
906 | ||||
907 | // Otherwise, get the raw storage width to use for the allocation. | |||
908 | size_t bitWidth = getDenseElementBitWidth(eltType); | |||
909 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); | |||
910 | ||||
911 | // Compress the attribute values into a character buffer. | |||
912 | SmallVector<char, 8> data( | |||
913 | llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT8)); | |||
914 | APInt intVal; | |||
915 | for (unsigned i = 0, e = values.size(); i < e; ++i) { | |||
916 | if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) { | |||
917 | assert(floatAttr.getType() == eltType &&(static_cast <bool> (floatAttr.getType() == eltType && "expected float attribute type to equal element type") ? void (0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__ )) | |||
918 | "expected float attribute type to equal element type")(static_cast <bool> (floatAttr.getType() == eltType && "expected float attribute type to equal element type") ? void (0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__ )); | |||
919 | intVal = floatAttr.getValue().bitcastToAPInt(); | |||
920 | } else { | |||
921 | auto intAttr = values[i].cast<IntegerAttr>(); | |||
922 | assert(intAttr.getType() == eltType &&(static_cast <bool> (intAttr.getType() == eltType && "expected integer attribute type to equal element type") ? void (0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 923, __extension__ __PRETTY_FUNCTION__ )) | |||
923 | "expected integer attribute type to equal element type")(static_cast <bool> (intAttr.getType() == eltType && "expected integer attribute type to equal element type") ? void (0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 923, __extension__ __PRETTY_FUNCTION__ )); | |||
924 | intVal = intAttr.getValue(); | |||
925 | } | |||
926 | ||||
927 | assert(intVal.getBitWidth() == bitWidth &&(static_cast <bool> (intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type") ? void (0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__ )) | |||
928 | "expected value to have same bitwidth as element type")(static_cast <bool> (intVal.getBitWidth() == bitWidth && "expected value to have same bitwidth as element type") ? void (0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__ )); | |||
929 | writeBits(data.data(), i * storageBitWidth, intVal); | |||
930 | } | |||
931 | ||||
932 | // Handle the special encoding of splat of bool. | |||
933 | if (values.size() == 1 && eltType.isInteger(1)) | |||
934 | data[0] = data[0] ? -1 : 0; | |||
935 | ||||
936 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
937 | } | |||
938 | ||||
939 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
940 | ArrayRef<bool> values) { | |||
941 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 941, __extension__ __PRETTY_FUNCTION__ )); | |||
942 | assert(type.getElementType().isInteger(1))(static_cast <bool> (type.getElementType().isInteger(1) ) ? void (0) : __assert_fail ("type.getElementType().isInteger(1)" , "mlir/lib/IR/BuiltinAttributes.cpp", 942, __extension__ __PRETTY_FUNCTION__ )); | |||
943 | ||||
944 | std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT8)); | |||
945 | ||||
946 | if (!values.empty()) { | |||
947 | bool isSplat = true; | |||
948 | bool firstValue = values[0]; | |||
949 | for (int i = 0, e = values.size(); i != e; ++i) { | |||
950 | isSplat &= values[i] == firstValue; | |||
951 | setBit(buff.data(), i, values[i]); | |||
952 | } | |||
953 | ||||
954 | // Splat of bool is encoded as a byte with all-ones in it. | |||
955 | if (isSplat) { | |||
956 | buff.resize(1); | |||
957 | buff[0] = values[0] ? -1 : 0; | |||
958 | } | |||
959 | } | |||
960 | ||||
961 | return DenseIntOrFPElementsAttr::getRaw(type, buff); | |||
962 | } | |||
963 | ||||
964 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
965 | ArrayRef<StringRef> values) { | |||
966 | assert(!type.getElementType().isIntOrFloat())(static_cast <bool> (!type.getElementType().isIntOrFloat ()) ? void (0) : __assert_fail ("!type.getElementType().isIntOrFloat()" , "mlir/lib/IR/BuiltinAttributes.cpp", 966, __extension__ __PRETTY_FUNCTION__ )); | |||
967 | return DenseStringElementsAttr::get(type, values); | |||
968 | } | |||
969 | ||||
970 | /// Constructs a dense integer elements attribute from an array of APInt | |||
971 | /// values. Each APInt value is expected to have the same bitwidth as the | |||
972 | /// element type of 'type'. | |||
973 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
974 | ArrayRef<APInt> values) { | |||
975 | assert(type.getElementType().isIntOrIndex())(static_cast <bool> (type.getElementType().isIntOrIndex ()) ? void (0) : __assert_fail ("type.getElementType().isIntOrIndex()" , "mlir/lib/IR/BuiltinAttributes.cpp", 975, __extension__ __PRETTY_FUNCTION__ )); | |||
976 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 976, __extension__ __PRETTY_FUNCTION__ )); | |||
977 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); | |||
978 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); | |||
979 | } | |||
980 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
981 | ArrayRef<std::complex<APInt>> values) { | |||
982 | ComplexType complex = type.getElementType().cast<ComplexType>(); | |||
983 | assert(complex.getElementType().isa<IntegerType>())(static_cast <bool> (complex.getElementType().isa<IntegerType >()) ? void (0) : __assert_fail ("complex.getElementType().isa<IntegerType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 983, __extension__ __PRETTY_FUNCTION__ )); | |||
984 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 984, __extension__ __PRETTY_FUNCTION__ )); | |||
985 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; | |||
986 | ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()), | |||
987 | values.size() * 2); | |||
988 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals); | |||
989 | } | |||
990 | ||||
991 | // Constructs a dense float elements attribute from an array of APFloat | |||
992 | // values. Each APFloat value is expected to have the same bitwidth as the | |||
993 | // element type of 'type'. | |||
994 | DenseElementsAttr DenseElementsAttr::get(ShapedType type, | |||
995 | ArrayRef<APFloat> values) { | |||
996 | assert(type.getElementType().isa<FloatType>())(static_cast <bool> (type.getElementType().isa<FloatType >()) ? void (0) : __assert_fail ("type.getElementType().isa<FloatType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 996, __extension__ __PRETTY_FUNCTION__ )); | |||
997 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 997, __extension__ __PRETTY_FUNCTION__ )); | |||
998 | size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType()); | |||
999 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values); | |||
1000 | } | |||
1001 | DenseElementsAttr | |||
1002 | DenseElementsAttr::get(ShapedType type, | |||
1003 | ArrayRef<std::complex<APFloat>> values) { | |||
1004 | ComplexType complex = type.getElementType().cast<ComplexType>(); | |||
1005 | assert(complex.getElementType().isa<FloatType>())(static_cast <bool> (complex.getElementType().isa<FloatType >()) ? void (0) : __assert_fail ("complex.getElementType().isa<FloatType>()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1005, __extension__ __PRETTY_FUNCTION__ )); | |||
1006 | assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values )) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1006, __extension__ __PRETTY_FUNCTION__ )); | |||
1007 | ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()), | |||
1008 | values.size() * 2); | |||
1009 | size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2; | |||
1010 | return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals); | |||
1011 | } | |||
1012 | ||||
1013 | /// Construct a dense elements attribute from a raw buffer representing the | |||
1014 | /// data for this attribute. Users should generally not use this methods as | |||
1015 | /// the expected buffer format may not be a form the user expects. | |||
1016 | DenseElementsAttr | |||
1017 | DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) { | |||
1018 | return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer); | |||
1019 | } | |||
1020 | ||||
1021 | /// Returns true if the given buffer is a valid raw buffer for the given type. | |||
1022 | bool DenseElementsAttr::isValidRawBuffer(ShapedType type, | |||
1023 | ArrayRef<char> rawBuffer, | |||
1024 | bool &detectedSplat) { | |||
1025 | size_t storageWidth = getDenseElementStorageWidth(type.getElementType()); | |||
1026 | size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT8; | |||
1027 | int64_t numElements = type.getNumElements(); | |||
1028 | ||||
1029 | // The initializer is always a splat if the result type has a single element. | |||
1030 | detectedSplat = numElements == 1; | |||
1031 | ||||
1032 | // Storage width of 1 is special as it is packed by the bit. | |||
1033 | if (storageWidth == 1) { | |||
1034 | // Check for a splat, or a buffer equal to the number of elements which | |||
1035 | // consists of either all 0's or all 1's. | |||
1036 | if (rawBuffer.size() == 1) { | |||
1037 | auto rawByte = static_cast<uint8_t>(rawBuffer[0]); | |||
1038 | if (rawByte == 0 || rawByte == 0xff) { | |||
1039 | detectedSplat = true; | |||
1040 | return true; | |||
1041 | } | |||
1042 | } | |||
1043 | ||||
1044 | // This is a valid non-splat buffer if it has the right size. | |||
1045 | return rawBufferWidth == llvm::alignTo<8>(numElements); | |||
1046 | } | |||
1047 | ||||
1048 | // All other types are 8-bit aligned, so we can just check the buffer width | |||
1049 | // to know if only a single initializer element was passed in. | |||
1050 | if (rawBufferWidth == storageWidth) { | |||
1051 | detectedSplat = true; | |||
1052 | return true; | |||
1053 | } | |||
1054 | ||||
1055 | // The raw buffer is valid if it has the right size. | |||
1056 | return rawBufferWidth == storageWidth * numElements; | |||
1057 | } | |||
1058 | ||||
1059 | /// Check the information for a C++ data type, check if this type is valid for | |||
1060 | /// the current attribute. This method is used to verify specific type | |||
1061 | /// invariants that the templatized 'getValues' method cannot. | |||
1062 | static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt, | |||
1063 | bool isSigned) { | |||
1064 | // Make sure that the data element size is the same as the type element width. | |||
1065 | if (getDenseElementBitWidth(type) != | |||
1066 | static_cast<size_t>(dataEltSize * CHAR_BIT8)) | |||
1067 | return false; | |||
1068 | ||||
1069 | // Check that the element type is either float or integer or index. | |||
1070 | if (!isInt) | |||
1071 | return type.isa<FloatType>(); | |||
1072 | if (type.isIndex()) | |||
1073 | return true; | |||
1074 | ||||
1075 | auto intType = type.dyn_cast<IntegerType>(); | |||
1076 | if (!intType) | |||
1077 | return false; | |||
1078 | ||||
1079 | // Make sure signedness semantics is consistent. | |||
1080 | if (intType.isSignless()) | |||
1081 | return true; | |||
1082 | return intType.isSigned() ? isSigned : !isSigned; | |||
1083 | } | |||
1084 | ||||
1085 | /// Defaults down the subclass implementation. | |||
1086 | DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type, | |||
1087 | ArrayRef<char> data, | |||
1088 | int64_t dataEltSize, | |||
1089 | bool isInt, bool isSigned) { | |||
1090 | return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt, | |||
1091 | isSigned); | |||
1092 | } | |||
1093 | DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type, | |||
1094 | ArrayRef<char> data, | |||
1095 | int64_t dataEltSize, | |||
1096 | bool isInt, | |||
1097 | bool isSigned) { | |||
1098 | return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize, | |||
1099 | isInt, isSigned); | |||
1100 | } | |||
1101 | ||||
1102 | bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt, | |||
1103 | bool isSigned) const { | |||
1104 | return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned); | |||
1105 | } | |||
1106 | bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt, | |||
1107 | bool isSigned) const { | |||
1108 | return ::isValidIntOrFloat( | |||
1109 | getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, | |||
1110 | isInt, isSigned); | |||
1111 | } | |||
1112 | ||||
1113 | /// Returns true if this attribute corresponds to a splat, i.e. if all element | |||
1114 | /// values are the same. | |||
1115 | bool DenseElementsAttr::isSplat() const { | |||
1116 | return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat; | |||
1117 | } | |||
1118 | ||||
1119 | /// Return if the given complex type has an integer element type. | |||
1120 | static bool isComplexOfIntType(Type type) { | |||
1121 | return type.cast<ComplexType>().getElementType().isa<IntegerType>(); | |||
1122 | } | |||
1123 | ||||
1124 | auto DenseElementsAttr::tryGetComplexIntValues() const | |||
1125 | -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> { | |||
1126 | if (!isComplexOfIntType(getElementType())) | |||
1127 | return failure(); | |||
1128 | return iterator_range_impl<ComplexIntElementIterator>( | |||
1129 | getType(), ComplexIntElementIterator(*this, 0), | |||
1130 | ComplexIntElementIterator(*this, getNumElements())); | |||
1131 | } | |||
1132 | ||||
1133 | auto DenseElementsAttr::tryGetFloatValues() const | |||
1134 | -> FailureOr<iterator_range_impl<FloatElementIterator>> { | |||
1135 | auto eltTy = getElementType().dyn_cast<FloatType>(); | |||
1136 | if (!eltTy) | |||
1137 | return failure(); | |||
1138 | const auto &elementSemantics = eltTy.getFloatSemantics(); | |||
1139 | return iterator_range_impl<FloatElementIterator>( | |||
1140 | getType(), FloatElementIterator(elementSemantics, raw_int_begin()), | |||
1141 | FloatElementIterator(elementSemantics, raw_int_end())); | |||
1142 | } | |||
1143 | ||||
1144 | auto DenseElementsAttr::tryGetComplexFloatValues() const | |||
1145 | -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> { | |||
1146 | auto complexTy = getElementType().dyn_cast<ComplexType>(); | |||
1147 | if (!complexTy) | |||
1148 | return failure(); | |||
1149 | auto eltTy = complexTy.getElementType().dyn_cast<FloatType>(); | |||
1150 | if (!eltTy) | |||
1151 | return failure(); | |||
1152 | const auto &semantics = eltTy.getFloatSemantics(); | |||
1153 | return iterator_range_impl<ComplexFloatElementIterator>( | |||
1154 | getType(), {semantics, {*this, 0}}, | |||
1155 | {semantics, {*this, static_cast<size_t>(getNumElements())}}); | |||
1156 | } | |||
1157 | ||||
1158 | /// Return the raw storage data held by this attribute. | |||
1159 | ArrayRef<char> DenseElementsAttr::getRawData() const { | |||
1160 | return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data; | |||
1161 | } | |||
1162 | ||||
1163 | ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const { | |||
1164 | return static_cast<DenseStringElementsAttrStorage *>(impl)->data; | |||
1165 | } | |||
1166 | ||||
1167 | /// Return a new DenseElementsAttr that has the same data as the current | |||
1168 | /// attribute, but has been reshaped to 'newType'. The new type must have the | |||
1169 | /// same total number of elements as well as element type. | |||
1170 | DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) { | |||
1171 | ShapedType curType = getType(); | |||
1172 | if (curType == newType) | |||
1173 | return *this; | |||
1174 | ||||
1175 | assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1176, __extension__ __PRETTY_FUNCTION__ )) | |||
1176 | "expected the same element type")(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1176, __extension__ __PRETTY_FUNCTION__ )); | |||
1177 | assert(newType.getNumElements() == curType.getNumElements() &&(static_cast <bool> (newType.getNumElements() == curType .getNumElements() && "expected the same number of elements" ) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1178, __extension__ __PRETTY_FUNCTION__ )) | |||
1178 | "expected the same number of elements")(static_cast <bool> (newType.getNumElements() == curType .getNumElements() && "expected the same number of elements" ) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1178, __extension__ __PRETTY_FUNCTION__ )); | |||
1179 | return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); | |||
1180 | } | |||
1181 | ||||
1182 | DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) { | |||
1183 | assert(isSplat() && "expected a splat type")(static_cast <bool> (isSplat() && "expected a splat type" ) ? void (0) : __assert_fail ("isSplat() && \"expected a splat type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1183, __extension__ __PRETTY_FUNCTION__ )); | |||
1184 | ||||
1185 | ShapedType curType = getType(); | |||
1186 | if (curType == newType) | |||
1187 | return *this; | |||
1188 | ||||
1189 | assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1190, __extension__ __PRETTY_FUNCTION__ )) | |||
1190 | "expected the same element type")(static_cast <bool> (newType.getElementType() == curType .getElementType() && "expected the same element type" ) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1190, __extension__ __PRETTY_FUNCTION__ )); | |||
1191 | return DenseIntOrFPElementsAttr::getRaw(newType, getRawData()); | |||
1192 | } | |||
1193 | ||||
1194 | /// Return a new DenseElementsAttr that has the same data as the current | |||
1195 | /// attribute, but has bitcast elements such that it is now 'newType'. The new | |||
1196 | /// type must have the same shape and element types of the same bitwidth as the | |||
1197 | /// current type. | |||
1198 | DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) { | |||
1199 | ShapedType curType = getType(); | |||
1200 | Type curElType = curType.getElementType(); | |||
1201 | if (curElType == newElType) | |||
1202 | return *this; | |||
1203 | ||||
1204 | assert(getDenseElementBitWidth(newElType) ==(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__ )) | |||
1205 | getDenseElementBitWidth(curElType) &&(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__ )) | |||
1206 | "expected element types with the same bitwidth")(static_cast <bool> (getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth" ) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__ )); | |||
1207 | return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType), | |||
1208 | getRawData()); | |||
1209 | } | |||
1210 | ||||
1211 | DenseElementsAttr | |||
1212 | DenseElementsAttr::mapValues(Type newElementType, | |||
1213 | function_ref<APInt(const APInt &)> mapping) const { | |||
1214 | return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping); | |||
1215 | } | |||
1216 | ||||
1217 | DenseElementsAttr DenseElementsAttr::mapValues( | |||
1218 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { | |||
1219 | return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping); | |||
1220 | } | |||
1221 | ||||
1222 | ShapedType DenseElementsAttr::getType() const { | |||
1223 | return static_cast<const DenseElementsAttributeStorage *>(impl)->type; | |||
1224 | } | |||
1225 | ||||
1226 | Type DenseElementsAttr::getElementType() const { | |||
1227 | return getType().getElementType(); | |||
1228 | } | |||
1229 | ||||
1230 | int64_t DenseElementsAttr::getNumElements() const { | |||
1231 | return getType().getNumElements(); | |||
1232 | } | |||
1233 | ||||
1234 | //===----------------------------------------------------------------------===// | |||
1235 | // DenseIntOrFPElementsAttr | |||
1236 | //===----------------------------------------------------------------------===// | |||
1237 | ||||
1238 | /// Utility method to write a range of APInt values to a buffer. | |||
1239 | template <typename APRangeT> | |||
1240 | static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data, | |||
1241 | APRangeT &&values) { | |||
1242 | size_t numValues = llvm::size(values); | |||
1243 | data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT8)); | |||
1244 | size_t offset = 0; | |||
1245 | for (auto it = values.begin(), e = values.end(); it != e; | |||
1246 | ++it, offset += storageWidth) { | |||
1247 | assert((*it).getBitWidth() <= storageWidth)(static_cast <bool> ((*it).getBitWidth() <= storageWidth ) ? void (0) : __assert_fail ("(*it).getBitWidth() <= storageWidth" , "mlir/lib/IR/BuiltinAttributes.cpp", 1247, __extension__ __PRETTY_FUNCTION__ )); | |||
1248 | writeBits(data.data(), offset, *it); | |||
1249 | } | |||
1250 | ||||
1251 | // Handle the special encoding of splat of a boolean. | |||
1252 | if (numValues == 1 && (*values.begin()).getBitWidth() == 1) | |||
1253 | data[0] = data[0] ? -1 : 0; | |||
1254 | } | |||
1255 | ||||
1256 | /// Constructs a dense elements attribute from an array of raw APFloat values. | |||
1257 | /// Each APFloat value is expected to have the same bitwidth as the element | |||
1258 | /// type of 'type'. 'type' must be a vector or tensor with static shape. | |||
1259 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1260 | size_t storageWidth, | |||
1261 | ArrayRef<APFloat> values) { | |||
1262 | std::vector<char> data; | |||
1263 | auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); }; | |||
1264 | writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat)); | |||
1265 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
1266 | } | |||
1267 | ||||
1268 | /// Constructs a dense elements attribute from an array of raw APInt values. | |||
1269 | /// Each APInt value is expected to have the same bitwidth as the element type | |||
1270 | /// of 'type'. | |||
1271 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1272 | size_t storageWidth, | |||
1273 | ArrayRef<APInt> values) { | |||
1274 | std::vector<char> data; | |||
1275 | writeAPIntsToBuffer(storageWidth, data, values); | |||
1276 | return DenseIntOrFPElementsAttr::getRaw(type, data); | |||
1277 | } | |||
1278 | ||||
1279 | DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type, | |||
1280 | ArrayRef<char> data) { | |||
1281 | assert(type.hasStaticShape() && "type must have static shape")(static_cast <bool> (type.hasStaticShape() && "type must have static shape" ) ? void (0) : __assert_fail ("type.hasStaticShape() && \"type must have static shape\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1281, __extension__ __PRETTY_FUNCTION__ )); | |||
1282 | bool isSplat = false; | |||
1283 | bool isValid = isValidRawBuffer(type, data, isSplat); | |||
1284 | assert(isValid)(static_cast <bool> (isValid) ? void (0) : __assert_fail ("isValid", "mlir/lib/IR/BuiltinAttributes.cpp", 1284, __extension__ __PRETTY_FUNCTION__)); | |||
1285 | (void)isValid; | |||
1286 | return Base::get(type.getContext(), type, data, isSplat); | |||
1287 | } | |||
1288 | ||||
1289 | /// Overload of the raw 'get' method that asserts that the given type is of | |||
1290 | /// complex type. This method is used to verify type invariants that the | |||
1291 | /// templatized 'get' method cannot. | |||
1292 | DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type, | |||
1293 | ArrayRef<char> data, | |||
1294 | int64_t dataEltSize, | |||
1295 | bool isInt, | |||
1296 | bool isSigned) { | |||
1297 | assert(::isValidIntOrFloat((static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__ )) | |||
1298 | type.getElementType().cast<ComplexType>().getElementType(),(static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__ )) | |||
1299 | dataEltSize / 2, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat( type.getElementType ().cast<ComplexType>().getElementType(), dataEltSize / 2 , isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__ )); | |||
1300 | ||||
1301 | int64_t numElements = data.size() / dataEltSize; | |||
1302 | (void)numElements; | |||
1303 | assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements == type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1303, __extension__ __PRETTY_FUNCTION__ )); | |||
1304 | return getRaw(type, data); | |||
1305 | } | |||
1306 | ||||
1307 | /// Overload of the 'getRaw' method that asserts that the given type is of | |||
1308 | /// integer type. This method is used to verify type invariants that the | |||
1309 | /// templatized 'get' method cannot. | |||
1310 | DenseElementsAttr | |||
1311 | DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data, | |||
1312 | int64_t dataEltSize, bool isInt, | |||
1313 | bool isSigned) { | |||
1314 | assert((static_cast <bool> (::isValidIntOrFloat(type.getElementType (), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1315, __extension__ __PRETTY_FUNCTION__ )) | |||
1315 | ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat(type.getElementType (), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)" , "mlir/lib/IR/BuiltinAttributes.cpp", 1315, __extension__ __PRETTY_FUNCTION__ )); | |||
1316 | ||||
1317 | int64_t numElements = data.size() / dataEltSize; | |||
1318 | assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements == type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1318, __extension__ __PRETTY_FUNCTION__ )); | |||
1319 | (void)numElements; | |||
1320 | return getRaw(type, data); | |||
1321 | } | |||
1322 | ||||
1323 | void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine( | |||
1324 | const char *inRawData, char *outRawData, size_t elementBitWidth, | |||
1325 | size_t numElements) { | |||
1326 | using llvm::support::ulittle16_t; | |||
1327 | using llvm::support::ulittle32_t; | |||
1328 | using llvm::support::ulittle64_t; | |||
1329 | ||||
1330 | assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 1331, __extension__ __PRETTY_FUNCTION__ )) | |||
1331 | llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness () == llvm::support::endianness::big) ? void (0) : __assert_fail ("llvm::support::endian::system_endianness() == llvm::support::endianness::big" , "mlir/lib/IR/BuiltinAttributes.cpp", 1331, __extension__ __PRETTY_FUNCTION__ )); // NOLINT | |||
1332 | // NOLINT to avoid warning message about replacing by static_assert() | |||
1333 | ||||
1334 | // Following std::copy_n always converts endianness on BE machine. | |||
1335 | switch (elementBitWidth) { | |||
1336 | case 16: { | |||
1337 | const ulittle16_t *inRawDataPos = | |||
1338 | reinterpret_cast<const ulittle16_t *>(inRawData); | |||
1339 | uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData); | |||
1340 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1341 | break; | |||
1342 | } | |||
1343 | case 32: { | |||
1344 | const ulittle32_t *inRawDataPos = | |||
1345 | reinterpret_cast<const ulittle32_t *>(inRawData); | |||
1346 | uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData); | |||
1347 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1348 | break; | |||
1349 | } | |||
1350 | case 64: { | |||
1351 | const ulittle64_t *inRawDataPos = | |||
1352 | reinterpret_cast<const ulittle64_t *>(inRawData); | |||
1353 | uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData); | |||
1354 | std::copy_n(inRawDataPos, numElements, outDataPos); | |||
1355 | break; | |||
1356 | } | |||
1357 | default: { | |||
1358 | size_t nBytes = elementBitWidth / CHAR_BIT8; | |||
1359 | for (size_t i = 0; i < nBytes; i++) | |||
1360 | std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i); | |||
1361 | break; | |||
1362 | } | |||
1363 | } | |||
1364 | } | |||
1365 | ||||
1366 | void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine( | |||
1367 | ArrayRef<char> inRawData, MutableArrayRef<char> outRawData, | |||
1368 | ShapedType type) { | |||
1369 | size_t numElements = type.getNumElements(); | |||
1370 | Type elementType = type.getElementType(); | |||
1371 | if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) { | |||
1372 | elementType = complexTy.getElementType(); | |||
1373 | numElements = numElements * 2; | |||
1374 | } | |||
1375 | size_t elementBitWidth = getDenseElementStorageWidth(elementType); | |||
1376 | assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&(static_cast <bool> (numElements * elementBitWidth == inRawData .size() * 8 && inRawData.size() <= outRawData.size ()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1377, __extension__ __PRETTY_FUNCTION__ )) | |||
1377 | inRawData.size() <= outRawData.size())(static_cast <bool> (numElements * elementBitWidth == inRawData .size() * 8 && inRawData.size() <= outRawData.size ()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()" , "mlir/lib/IR/BuiltinAttributes.cpp", 1377, __extension__ __PRETTY_FUNCTION__ )); | |||
1378 | if (elementBitWidth <= CHAR_BIT8) | |||
1379 | std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size()); | |||
1380 | else | |||
1381 | convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(), | |||
1382 | elementBitWidth, numElements); | |||
1383 | } | |||
1384 | ||||
1385 | //===----------------------------------------------------------------------===// | |||
1386 | // DenseFPElementsAttr | |||
1387 | //===----------------------------------------------------------------------===// | |||
1388 | ||||
1389 | template <typename Fn, typename Attr> | |||
1390 | static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType, | |||
1391 | Type newElementType, | |||
1392 | llvm::SmallVectorImpl<char> &data) { | |||
1393 | size_t bitWidth = getDenseElementBitWidth(newElementType); | |||
1394 | size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); | |||
1395 | ||||
1396 | ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType); | |||
1397 | ||||
1398 | size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements(); | |||
1399 | data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT8)); | |||
1400 | ||||
1401 | // Functor used to process a single element value of the attribute. | |||
1402 | auto processElt = [&](decltype(*attr.begin()) value, size_t index) { | |||
1403 | auto newInt = mapping(value); | |||
1404 | assert(newInt.getBitWidth() == bitWidth)(static_cast <bool> (newInt.getBitWidth() == bitWidth) ? void (0) : __assert_fail ("newInt.getBitWidth() == bitWidth" , "mlir/lib/IR/BuiltinAttributes.cpp", 1404, __extension__ __PRETTY_FUNCTION__ )); | |||
1405 | writeBits(data.data(), index * storageBitWidth, newInt); | |||
1406 | }; | |||
1407 | ||||
1408 | // Check for the splat case. | |||
1409 | if (attr.isSplat()) { | |||
1410 | if (bitWidth == 1) { | |||
1411 | // Handle the special encoding of splat of bool. | |||
1412 | data[0] = mapping(*attr.begin()).isZero() ? 0 : -1; | |||
1413 | } else { | |||
1414 | processElt(*attr.begin(), /*index=*/0); | |||
1415 | } | |||
1416 | return newArrayType; | |||
1417 | } | |||
1418 | ||||
1419 | // Otherwise, process all of the element values. | |||
1420 | uint64_t elementIdx = 0; | |||
1421 | for (auto value : attr) | |||
1422 | processElt(value, elementIdx++); | |||
1423 | return newArrayType; | |||
1424 | } | |||
1425 | ||||
1426 | DenseElementsAttr DenseFPElementsAttr::mapValues( | |||
1427 | Type newElementType, function_ref<APInt(const APFloat &)> mapping) const { | |||
1428 | llvm::SmallVector<char, 8> elementData; | |||
1429 | auto newArrayType = | |||
1430 | mappingHelper(mapping, *this, getType(), newElementType, elementData); | |||
1431 | ||||
1432 | return getRaw(newArrayType, elementData); | |||
1433 | } | |||
1434 | ||||
1435 | /// Method for supporting type inquiry through isa, cast and dyn_cast. | |||
1436 | bool DenseFPElementsAttr::classof(Attribute attr) { | |||
1437 | if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) | |||
1438 | return denseAttr.getType().getElementType().isa<FloatType>(); | |||
1439 | return false; | |||
1440 | } | |||
1441 | ||||
1442 | //===----------------------------------------------------------------------===// | |||
1443 | // DenseIntElementsAttr | |||
1444 | //===----------------------------------------------------------------------===// | |||
1445 | ||||
1446 | DenseElementsAttr DenseIntElementsAttr::mapValues( | |||
1447 | Type newElementType, function_ref<APInt(const APInt &)> mapping) const { | |||
1448 | llvm::SmallVector<char, 8> elementData; | |||
1449 | auto newArrayType = | |||
1450 | mappingHelper(mapping, *this, getType(), newElementType, elementData); | |||
1451 | return getRaw(newArrayType, elementData); | |||
1452 | } | |||
1453 | ||||
1454 | /// Method for supporting type inquiry through isa, cast and dyn_cast. | |||
1455 | bool DenseIntElementsAttr::classof(Attribute attr) { | |||
1456 | if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) | |||
1457 | return denseAttr.getType().getElementType().isIntOrIndex(); | |||
1458 | return false; | |||
1459 | } | |||
1460 | ||||
1461 | //===----------------------------------------------------------------------===// | |||
1462 | // DenseResourceElementsAttr | |||
1463 | //===----------------------------------------------------------------------===// | |||
1464 | ||||
1465 | DenseResourceElementsAttr | |||
1466 | DenseResourceElementsAttr::get(ShapedType type, | |||
1467 | DenseResourceElementsHandle handle) { | |||
1468 | return Base::get(type.getContext(), type, handle); | |||
1469 | } | |||
1470 | ||||
1471 | DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type, | |||
1472 | StringRef blobName, | |||
1473 | AsmResourceBlob blob) { | |||
1474 | // Extract the builtin dialect resource manager from context and construct a | |||
1475 | // handle by inserting a new resource using the provided blob. | |||
1476 | auto &manager = | |||
1477 | DenseResourceElementsHandle::getManagerInterface(type.getContext()); | |||
1478 | return get(type, manager.insert(blobName, std::move(blob))); | |||
1479 | } | |||
1480 | ||||
1481 | //===----------------------------------------------------------------------===// | |||
1482 | // DenseResourceElementsAttrBase | |||
1483 | ||||
1484 | namespace { | |||
1485 | /// Instantiations of this class provide utilities for interacting with native | |||
1486 | /// data types in the context of DenseResourceElementsAttr. | |||
1487 | template <typename T> | |||
1488 | struct DenseResourceAttrUtil; | |||
1489 | template <size_t width, bool isSigned> | |||
1490 | struct DenseResourceElementsAttrIntUtil { | |||
1491 | static bool checkElementType(Type eltType) { | |||
1492 | IntegerType type = eltType.dyn_cast<IntegerType>(); | |||
1493 | if (!type || type.getWidth() != width) | |||
1494 | return false; | |||
1495 | return isSigned ? !type.isUnsigned() : !type.isSigned(); | |||
1496 | } | |||
1497 | }; | |||
1498 | template <> | |||
1499 | struct DenseResourceAttrUtil<bool> { | |||
1500 | static bool checkElementType(Type eltType) { | |||
1501 | return eltType.isSignlessInteger(1); | |||
1502 | } | |||
1503 | }; | |||
1504 | template <> | |||
1505 | struct DenseResourceAttrUtil<int8_t> | |||
1506 | : public DenseResourceElementsAttrIntUtil<8, true> {}; | |||
1507 | template <> | |||
1508 | struct DenseResourceAttrUtil<uint8_t> | |||
1509 | : public DenseResourceElementsAttrIntUtil<8, false> {}; | |||
1510 | template <> | |||
1511 | struct DenseResourceAttrUtil<int16_t> | |||
1512 | : public DenseResourceElementsAttrIntUtil<16, true> {}; | |||
1513 | template <> | |||
1514 | struct DenseResourceAttrUtil<uint16_t> | |||
1515 | : public DenseResourceElementsAttrIntUtil<16, false> {}; | |||
1516 | template <> | |||
1517 | struct DenseResourceAttrUtil<int32_t> | |||
1518 | : public DenseResourceElementsAttrIntUtil<32, true> {}; | |||
1519 | template <> | |||
1520 | struct DenseResourceAttrUtil<uint32_t> | |||
1521 | : public DenseResourceElementsAttrIntUtil<32, false> {}; | |||
1522 | template <> | |||
1523 | struct DenseResourceAttrUtil<int64_t> | |||
1524 | : public DenseResourceElementsAttrIntUtil<64, true> {}; | |||
1525 | template <> | |||
1526 | struct DenseResourceAttrUtil<uint64_t> | |||
1527 | : public DenseResourceElementsAttrIntUtil<64, false> {}; | |||
1528 | template <> | |||
1529 | struct DenseResourceAttrUtil<float> { | |||
1530 | static bool checkElementType(Type eltType) { return eltType.isF32(); } | |||
1531 | }; | |||
1532 | template <> | |||
1533 | struct DenseResourceAttrUtil<double> { | |||
1534 | static bool checkElementType(Type eltType) { return eltType.isF64(); } | |||
1535 | }; | |||
1536 | } // namespace | |||
1537 | ||||
1538 | template <typename T> | |||
1539 | DenseResourceElementsAttrBase<T> | |||
1540 | DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName, | |||
1541 | AsmResourceBlob blob) { | |||
1542 | // Check that the blob is in the form we were expecting. | |||
1543 | assert(blob.getDataAlignment() == alignof(T) &&(static_cast <bool> (blob.getDataAlignment() == alignof (T) && "alignment mismatch between expected alignment and blob alignment" ) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1544, __extension__ __PRETTY_FUNCTION__ )) | |||
1544 | "alignment mismatch between expected alignment and blob alignment")(static_cast <bool> (blob.getDataAlignment() == alignof (T) && "alignment mismatch between expected alignment and blob alignment" ) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1544, __extension__ __PRETTY_FUNCTION__ )); | |||
1545 | assert(((blob.getData().size() % sizeof(T)) == 0) &&(static_cast <bool> (((blob.getData().size() % sizeof(T )) == 0) && "size mismatch between expected element width and blob size" ) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1546, __extension__ __PRETTY_FUNCTION__ )) | |||
1546 | "size mismatch between expected element width and blob size")(static_cast <bool> (((blob.getData().size() % sizeof(T )) == 0) && "size mismatch between expected element width and blob size" ) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1546, __extension__ __PRETTY_FUNCTION__ )); | |||
1547 | assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType (type.getElementType()) && "invalid shape element type for provided type `T`" ) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1548, __extension__ __PRETTY_FUNCTION__ )) | |||
1548 | "invalid shape element type for provided type `T`")(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType (type.getElementType()) && "invalid shape element type for provided type `T`" ) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1548, __extension__ __PRETTY_FUNCTION__ )); | |||
1549 | return DenseResourceElementsAttr::get(type, blobName, std::move(blob)) | |||
1550 | .template cast<DenseResourceElementsAttrBase<T>>(); | |||
1551 | } | |||
1552 | ||||
1553 | template <typename T> | |||
1554 | Optional<ArrayRef<T>> | |||
1555 | DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const { | |||
1556 | if (AsmResourceBlob *blob = this->getRawHandle().getBlob()) | |||
1557 | return blob->template getDataAs<T>(); | |||
1558 | return std::nullopt; | |||
1559 | } | |||
1560 | ||||
1561 | template <typename T> | |||
1562 | bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) { | |||
1563 | auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>(); | |||
1564 | return resourceAttr && DenseResourceAttrUtil<T>::checkElementType( | |||
1565 | resourceAttr.getElementType()); | |||
1566 | } | |||
1567 | ||||
namespace mlir {
namespace detail {
// Explicit instantiation for all the supported DenseResourceElementsAttr.
// The member definitions above (get, tryGetAsArrayRef, classof) live in this
// .cpp file, so only the element types listed here can be used with
// DenseResourceElementsAttrBase<T> from other translation units.
template class DenseResourceElementsAttrBase<bool>;
template class DenseResourceElementsAttrBase<int8_t>;
template class DenseResourceElementsAttrBase<int16_t>;
template class DenseResourceElementsAttrBase<int32_t>;
template class DenseResourceElementsAttrBase<int64_t>;
template class DenseResourceElementsAttrBase<uint8_t>;
template class DenseResourceElementsAttrBase<uint16_t>;
template class DenseResourceElementsAttrBase<uint32_t>;
template class DenseResourceElementsAttrBase<uint64_t>;
template class DenseResourceElementsAttrBase<float>;
template class DenseResourceElementsAttrBase<double>;
} // namespace detail
} // namespace mlir
1584 | ||||
1585 | //===----------------------------------------------------------------------===// | |||
1586 | // SparseElementsAttr | |||
1587 | //===----------------------------------------------------------------------===// | |||
1588 | ||||
1589 | /// Get a zero APFloat for the given sparse attribute. | |||
1590 | APFloat SparseElementsAttr::getZeroAPFloat() const { | |||
1591 | auto eltType = getElementType().cast<FloatType>(); | |||
1592 | return APFloat(eltType.getFloatSemantics()); | |||
1593 | } | |||
1594 | ||||
1595 | /// Get a zero APInt for the given sparse attribute. | |||
1596 | APInt SparseElementsAttr::getZeroAPInt() const { | |||
1597 | auto eltType = getElementType().cast<IntegerType>(); | |||
1598 | return APInt::getZero(eltType.getWidth()); | |||
1599 | } | |||
1600 | ||||
1601 | /// Get a zero attribute for the given attribute type. | |||
1602 | Attribute SparseElementsAttr::getZeroAttr() const { | |||
1603 | auto eltType = getElementType(); | |||
1604 | ||||
1605 | // Handle floating point elements. | |||
1606 | if (eltType.isa<FloatType>()) | |||
1607 | return FloatAttr::get(eltType, 0); | |||
1608 | ||||
1609 | // Handle complex elements. | |||
1610 | if (auto complexTy = eltType.dyn_cast<ComplexType>()) { | |||
1611 | auto eltType = complexTy.getElementType(); | |||
1612 | Attribute zero; | |||
1613 | if (eltType.isa<FloatType>()) | |||
1614 | zero = FloatAttr::get(eltType, 0); | |||
1615 | else // must be integer | |||
1616 | zero = IntegerAttr::get(eltType, 0); | |||
1617 | return ArrayAttr::get(complexTy.getContext(), | |||
1618 | ArrayRef<Attribute>{zero, zero}); | |||
1619 | } | |||
1620 | ||||
1621 | // Handle string type. | |||
1622 | if (getValues().isa<DenseStringElementsAttr>()) | |||
1623 | return StringAttr::get("", eltType); | |||
1624 | ||||
1625 | // Otherwise, this is an integer. | |||
1626 | return IntegerAttr::get(eltType, 0); | |||
1627 | } | |||
1628 | ||||
1629 | /// Flatten, and return, all of the sparse indices in this attribute in | |||
1630 | /// row-major order. | |||
1631 | std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const { | |||
1632 | std::vector<ptrdiff_t> flatSparseIndices; | |||
1633 | ||||
1634 | // The sparse indices are 64-bit integers, so we can reinterpret the raw data | |||
1635 | // as a 1-D index array. | |||
1636 | auto sparseIndices = getIndices(); | |||
1637 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); | |||
1638 | if (sparseIndices.isSplat()) { | |||
1639 | SmallVector<uint64_t, 8> indices(getType().getRank(), | |||
1640 | *sparseIndexValues.begin()); | |||
1641 | flatSparseIndices.push_back(getFlattenedIndex(indices)); | |||
1642 | return flatSparseIndices; | |||
1643 | } | |||
1644 | ||||
1645 | // Otherwise, reinterpret each index as an ArrayRef when flattening. | |||
1646 | auto numSparseIndices = sparseIndices.getType().getDimSize(0); | |||
1647 | size_t rank = getType().getRank(); | |||
1648 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) | |||
1649 | flatSparseIndices.push_back(getFlattenedIndex( | |||
1650 | {&*std::next(sparseIndexValues.begin(), i * rank), rank})); | |||
1651 | return flatSparseIndices; | |||
1652 | } | |||
1653 | ||||
1654 | LogicalResult | |||
1655 | SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError, | |||
1656 | ShapedType type, DenseIntElementsAttr sparseIndices, | |||
1657 | DenseElementsAttr values) { | |||
1658 | ShapedType valuesType = values.getType(); | |||
1659 | if (valuesType.getRank() != 1) | |||
1660 | return emitError() << "expected 1-d tensor for sparse element values"; | |||
1661 | ||||
1662 | // Verify the indices and values shape. | |||
1663 | ShapedType indicesType = sparseIndices.getType(); | |||
1664 | auto emitShapeError = [&]() { | |||
1665 | return emitError() << "expected shape ([" << type.getShape() | |||
1666 | << "]); inferred shape of indices literal ([" | |||
1667 | << indicesType.getShape() | |||
1668 | << "]); inferred shape of values literal ([" | |||
1669 | << valuesType.getShape() << "])"; | |||
1670 | }; | |||
1671 | // Verify indices shape. | |||
1672 | size_t rank = type.getRank(), indicesRank = indicesType.getRank(); | |||
1673 | if (indicesRank == 2) { | |||
1674 | if (indicesType.getDimSize(1) != static_cast<int64_t>(rank)) | |||
1675 | return emitShapeError(); | |||
1676 | } else if (indicesRank != 1 || rank != 1) { | |||
1677 | return emitShapeError(); | |||
1678 | } | |||
1679 | // Verify the values shape. | |||
1680 | int64_t numSparseIndices = indicesType.getDimSize(0); | |||
1681 | if (numSparseIndices != valuesType.getDimSize(0)) | |||
1682 | return emitShapeError(); | |||
1683 | ||||
1684 | // Verify that the sparse indices are within the value shape. | |||
1685 | auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) { | |||
1686 | return emitError() | |||
1687 | << "sparse index #" << indexNum | |||
1688 | << " is not contained within the value shape, with index=[" << index | |||
1689 | << "], and type=" << type; | |||
1690 | }; | |||
1691 | ||||
1692 | // Handle the case where the index values are a splat. | |||
1693 | auto sparseIndexValues = sparseIndices.getValues<uint64_t>(); | |||
1694 | if (sparseIndices.isSplat()) { | |||
1695 | SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin()); | |||
1696 | if (!ElementsAttr::isValidIndex(type, indices)) | |||
1697 | return emitIndexError(0, indices); | |||
1698 | return success(); | |||
1699 | } | |||
1700 | ||||
1701 | // Otherwise, reinterpret each index as an ArrayRef. | |||
1702 | for (size_t i = 0, e = numSparseIndices; i != e; ++i) { | |||
1703 | ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank), | |||
1704 | rank); | |||
1705 | if (!ElementsAttr::isValidIndex(type, index)) | |||
1706 | return emitIndexError(i, index); | |||
1707 | } | |||
1708 | ||||
1709 | return success(); | |||
1710 | } | |||
1711 | ||||
1712 | //===----------------------------------------------------------------------===// | |||
1713 | // Attribute Utilities | |||
1714 | //===----------------------------------------------------------------------===// | |||
1715 | ||||
1716 | AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, | |||
1717 | int64_t offset, | |||
1718 | MLIRContext *context) { | |||
1719 | AffineExpr expr; | |||
1720 | unsigned nSymbols = 0; | |||
1721 | ||||
1722 | // AffineExpr for offset. | |||
1723 | // Static case. | |||
1724 | if (!ShapedType::isDynamic(offset)) { | |||
1725 | auto cst = getAffineConstantExpr(offset, context); | |||
1726 | expr = cst; | |||
1727 | } else { | |||
1728 | // Dynamic case, new symbol for the offset. | |||
1729 | auto sym = getAffineSymbolExpr(nSymbols++, context); | |||
1730 | expr = sym; | |||
1731 | } | |||
1732 | ||||
1733 | // AffineExpr for strides. | |||
1734 | for (const auto &en : llvm::enumerate(strides)) { | |||
1735 | auto dim = en.index(); | |||
1736 | auto stride = en.value(); | |||
1737 | assert(stride != 0 && "Invalid stride specification")(static_cast <bool> (stride != 0 && "Invalid stride specification" ) ? void (0) : __assert_fail ("stride != 0 && \"Invalid stride specification\"" , "mlir/lib/IR/BuiltinAttributes.cpp", 1737, __extension__ __PRETTY_FUNCTION__ )); | |||
1738 | auto d = getAffineDimExpr(dim, context); | |||
1739 | AffineExpr mult; | |||
1740 | // Static case. | |||
1741 | if (!ShapedType::isDynamic(stride)) | |||
1742 | mult = getAffineConstantExpr(stride, context); | |||
1743 | else | |||
1744 | // Dynamic case, new symbol for each new stride. | |||
1745 | mult = getAffineSymbolExpr(nSymbols++, context); | |||
1746 | expr = expr + d * mult; | |||
1747 | } | |||
1748 | ||||
1749 | return AffineMap::get(strides.size(), nSymbols, expr); | |||
1750 | } |
1 | //===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This classes used by the implementation details of Op types. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPIMPLEMENTATION_H |
14 | #define MLIR_IR_OPIMPLEMENTATION_H |
15 | |
16 | #include "mlir/IR/BuiltinTypes.h" |
17 | #include "mlir/IR/DialectInterface.h" |
18 | #include "mlir/IR/OpDefinition.h" |
19 | #include "llvm/ADT/Twine.h" |
20 | #include "llvm/Support/SMLoc.h" |
21 | |
22 | namespace mlir { |
23 | class AsmParsedResourceEntry; |
24 | class AsmResourceBuilder; |
25 | class Builder; |
26 | |
27 | //===----------------------------------------------------------------------===// |
28 | // AsmDialectResourceHandle |
29 | //===----------------------------------------------------------------------===// |
30 | |
31 | /// This class represents an opaque handle to a dialect resource entry. |
32 | class AsmDialectResourceHandle { |
33 | public: |
34 | AsmDialectResourceHandle() = default; |
35 | AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect) |
36 | : resource(resource), opaqueID(resourceID), dialect(dialect) {} |
37 | bool operator==(const AsmDialectResourceHandle &other) const { |
38 | return resource == other.resource; |
39 | } |
40 | |
41 | /// Return an opaque pointer to the referenced resource. |
42 | void *getResource() const { return resource; } |
43 | |
44 | /// Return the type ID of the resource. |
45 | TypeID getTypeID() const { return opaqueID; } |
46 | |
47 | /// Return the dialect that owns the resource. |
48 | Dialect *getDialect() const { return dialect; } |
49 | |
50 | private: |
51 | /// The opaque handle to the dialect resource. |
52 | void *resource = nullptr; |
53 | /// The type of the resource referenced. |
54 | TypeID opaqueID; |
55 | /// The dialect owning the given resource. |
56 | Dialect *dialect; |
57 | }; |
58 | |
59 | /// This class represents a CRTP base class for dialect resource handles. It |
60 | /// abstracts away various utilities necessary for defined derived resource |
61 | /// handles. |
62 | template <typename DerivedT, typename ResourceT, typename DialectT> |
63 | class AsmDialectResourceHandleBase : public AsmDialectResourceHandle { |
64 | public: |
65 | using Dialect = DialectT; |
66 | |
67 | /// Construct a handle from a pointer to the resource. The given pointer |
68 | /// should be guaranteed to live beyond the life of this handle. |
69 | AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect) |
70 | : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {} |
71 | AsmDialectResourceHandleBase(AsmDialectResourceHandle handle) |
72 | : AsmDialectResourceHandle(handle) { |
73 | assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get< DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()" , "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__ __PRETTY_FUNCTION__)); |
74 | } |
75 | |
76 | /// Return the resource referenced by this handle. |
77 | ResourceT *getResource() { |
78 | return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource()); |
79 | } |
80 | const ResourceT *getResource() const { |
81 | return const_cast<AsmDialectResourceHandleBase *>(this)->getResource(); |
82 | } |
83 | |
84 | /// Return the dialect that owns the resource. |
85 | DialectT *getDialect() const { |
86 | return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect()); |
87 | } |
88 | |
89 | /// Support llvm style casting. |
90 | static bool classof(const AsmDialectResourceHandle *handle) { |
91 | return handle->getTypeID() == TypeID::get<DerivedT>(); |
92 | } |
93 | }; |
94 | |
95 | inline llvm::hash_code hash_value(const AsmDialectResourceHandle ¶m) { |
96 | return llvm::hash_value(param.getResource()); |
97 | } |
98 | |
99 | //===----------------------------------------------------------------------===// |
100 | // AsmPrinter |
101 | //===----------------------------------------------------------------------===// |
102 | |
/// This base class exposes generic asm printer hooks, usable across the various
/// derived printers.
///
/// Most hooks are virtual. They are expected to be served by the `Impl` passed
/// to the public constructor, or else overridden by a subclass that uses the
/// protected default constructor. NOTE: the order of the virtual declarations
/// below defines the vtable layout; do not reorder them.
class AsmPrinter {
public:
  /// This class contains the internal default implementation of the base
  /// printer methods.
  class Impl;

  /// Initialize the printer with the given internal implementation.
  AsmPrinter(Impl &impl) : impl(&impl) {}
  virtual ~AsmPrinter();

  /// Return the raw output stream used by this printer.
  virtual raw_ostream &getStream() const;

  /// Print the given floating point value in a stabilized form that can be
  /// roundtripped through the IR. This is the companion to the 'parseFloat'
  /// hook on the AsmParser.
  virtual void printFloat(const APFloat &value);

  /// Print the given type / attribute. These are the hooks used by the
  /// `operator<<` overloads for Type and Attribute defined after this class.
  virtual void printType(Type type);
  virtual void printAttribute(Attribute attr);

  /// Trait to check if `AttrOrType` provides a `print(AsmPrinter &)` method.
  template <typename AttrOrType>
  using has_print_method =
      decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
  template <typename AttrOrType>
  using detect_has_print_method =
      llvm::is_detected<has_print_method, AttrOrType>;

  /// Print the provided attribute in the context of an operation custom
  /// printer/parser: this will invoke directly the print method on the
  /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
  /// If an alias for the value can be printed, the alias is printed instead.
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    if (succeeded(printAlias(attrOrType)))
      return;
    attrOrType.print(*this);
  }

  /// Print the provided array of attributes or types in the context of an
  /// operation custom printer/parser: this will invoke directly the print
  /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
  /// most cases. Elements are separated by ", ".
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
    llvm::interleaveComma(
        attrOrTypes, getStream(),
        [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
  }

  /// SFINAE for printing the provided attribute in the context of an operation
  /// custom printer in the case where the attribute does not define a print
  /// method. Falls back to the regular `operator<<` overloads.
  template <typename AttrOrType,
            std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    *this << attrOrType;
  }

  /// Print the given attribute without its type. The corresponding parser must
  /// provide a valid type for the attribute.
  virtual void printAttributeWithoutType(Attribute attr);

  /// Print the given string as a keyword, or a quoted and escaped string if it
  /// has any special or non-printable characters in it.
  virtual void printKeywordOrString(StringRef keyword);

  /// Print the given string as a symbol reference, i.e. a form representable by
  /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
  /// with '@'. The reference is surrounded with ""'s and escaped if it has any
  /// special or non-printable characters in it.
  virtual void printSymbolName(StringRef symbolRef);

  /// Print a handle to the given dialect resource.
  virtual void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Print an optional arrow followed by a type list. Nothing is printed when
  /// `types` is empty.
  template <typename TypeRange>
  void printOptionalArrowTypeList(TypeRange &&types) {
    if (types.begin() != types.end())
      printArrowTypeList(types);
  }
  /// Print " -> " followed by a comma-separated type list. The list is wrapped
  /// in parentheses unless it holds exactly one type that is not a
  /// FunctionType.
  template <typename TypeRange>
  void printArrowTypeList(TypeRange &&types) {
    auto &os = getStream() << " -> ";

    bool wrapped = !llvm::hasSingleElement(types) ||
                   (*types.begin()).template isa<FunctionType>();
    if (wrapped)
      os << '(';
    llvm::interleaveComma(types, *this);
    if (wrapped)
      os << ')';
  }

  /// Print the two given type ranges in a functional form:
  /// `(inputs) -> results`, with results formatted by printArrowTypeList.
  template <typename InputRangeT, typename ResultRangeT>
  void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
    auto &os = getStream();
    os << '(';
    llvm::interleaveComma(inputs, *this);
    os << ')';
    printArrowTypeList(results);
  }

protected:
  /// Initialize the printer with no internal implementation. In this case, all
  /// virtual methods of this class must be overridden.
  AsmPrinter() = default;

private:
  AsmPrinter(const AsmPrinter &) = delete;
  void operator=(const AsmPrinter &) = delete;

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Attribute attr);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Type type);

  /// The internal implementation of the printer (nullptr when the protected
  /// default constructor was used).
  Impl *impl{nullptr};
};
235 | |
// Stream operators for any printer derived from AsmPrinter. Each overload is
// constrained with enable_if so that it only participates for AsmPrinter
// subclasses, and each returns the printer it was given so calls can be
// chained.

/// Print a Type via AsmPrinter::printType.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, Type type) {
  p.printType(type);
  return p;
}

/// Print an Attribute via AsmPrinter::printAttribute.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, Attribute attr) {
  p.printAttribute(attr);
  return p;
}

/// Print an APFloat via AsmPrinter::printFloat.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const APFloat &value) {
  p.printFloat(value);
  return p;
}
/// Print a float by converting it to APFloat first.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, float value) {
  return p << APFloat(value);
}
/// Print a double by converting it to APFloat first.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, double value) {
  return p << APFloat(value);
}

// Support printing anything that isn't convertible to one of the other
// streamable types, even if it isn't exactly one of them. For example, we want
// to print FunctionType with the Type version above, not have it match this.
// Such values are forwarded directly to the underlying raw_ostream.
template <typename AsmPrinterT, typename T,
          std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
                               !std::is_convertible<T &, Type &>::value &&
                               !std::is_convertible<T &, Attribute &>::value &&
                               !std::is_convertible<T &, ValueRange>::value &&
                               !std::is_convertible<T &, APFloat &>::value &&
                               !llvm::is_one_of<T, bool, float, double>::value,
                           T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const T &other) {
  p.getStream() << other;
  return p;
}

/// Print a bool as the keyword `true` or `false`.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, bool value) {
  return p << (value ? StringRef("true") : "false");
}

/// Print a comma-separated list of value types.
template <typename AsmPrinterT, typename ValueRangeT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
  llvm::interleaveComma(types, p);
  return p;
}
/// Print a comma-separated TypeRange.
template <typename AsmPrinterT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const TypeRange &types) {
  llvm::interleaveComma(types, p);
  return p;
}
/// Print a comma-separated ArrayRef of any printable element type.
template <typename AsmPrinterT, typename ElementT>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
  llvm::interleaveComma(types, p);
  return p;
}
318 | |
319 | //===----------------------------------------------------------------------===// |
320 | // OpAsmPrinter |
321 | //===----------------------------------------------------------------------===// |
322 | |
323 | /// This is a pure-virtual base class that exposes the asmprinter hooks |
324 | /// necessary to implement a custom print() method. |
325 | class OpAsmPrinter : public AsmPrinter { |
326 | public: |
327 | using AsmPrinter::AsmPrinter; |
328 | ~OpAsmPrinter() override; |
329 | |
330 | /// Print a loc(...) specifier if printing debug info is enabled. |
331 | virtual void printOptionalLocationSpecifier(Location loc) = 0; |
332 | |
333 | /// Print a newline and indent the printer to the start of the current |
334 | /// operation. |
335 | virtual void printNewline() = 0; |
336 | |
337 | /// Increase indentation. |
338 | virtual void increaseIndent() = 0; |
339 | |
340 | /// Decrease indentation. |
341 | virtual void decreaseIndent() = 0; |
342 | |
343 | /// Print a block argument in the usual format of: |
344 | /// %ssaName : type {attr1=42} loc("here") |
345 | /// where location printing is controlled by the standard internal option. |
346 | /// You may pass omitType=true to not print a type, and pass an empty |
347 | /// attribute list if you don't care for attributes. |
348 | virtual void printRegionArgument(BlockArgument arg, |
349 | ArrayRef<NamedAttribute> argAttrs = {}, |
350 | bool omitType = false) = 0; |
351 | |
352 | /// Print implementations for various things an operation contains. |
353 | virtual void printOperand(Value value) = 0; |
354 | virtual void printOperand(Value value, raw_ostream &os) = 0; |
355 | |
356 | /// Print a comma separated list of operands. |
357 | template <typename ContainerType> |
358 | void printOperands(const ContainerType &container) { |
359 | printOperands(container.begin(), container.end()); |
360 | } |
361 | |
362 | /// Print a comma separated list of operands. |
363 | template <typename IteratorType> |
364 | void printOperands(IteratorType it, IteratorType end) { |
365 | llvm::interleaveComma(llvm::make_range(it, end), getStream(), |
366 | [this](Value value) { printOperand(value); }); |
367 | } |
368 | |
369 | /// Print the given successor. |
370 | virtual void printSuccessor(Block *successor) = 0; |
371 | |
372 | /// Print the successor and its operands. |
373 | virtual void printSuccessorAndUseList(Block *successor, |
374 | ValueRange succOperands) = 0; |
375 | |
376 | /// If the specified operation has attributes, print out an attribute |
377 | /// dictionary with their values. elidedAttrs allows the client to ignore |
378 | /// specific well known attributes, commonly used if the attribute value is |
379 | /// printed some other way (like as a fixed operand). |
380 | virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs, |
381 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
382 | |
383 | /// If the specified operation has attributes, print out an attribute |
384 | /// dictionary prefixed with 'attributes'. |
385 | virtual void |
386 | printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs, |
387 | ArrayRef<StringRef> elidedAttrs = {}) = 0; |
388 | |
389 | /// Prints the entire operation with the custom assembly form, if available, |
390 | /// or the generic assembly form, otherwise. |
391 | virtual void printCustomOrGenericOp(Operation *op) = 0; |
392 | |
393 | /// Print the entire operation with the default generic assembly form. |
394 | /// If `printOpName` is true, then the operation name is printed (the default) |
395 | /// otherwise it is omitted and the print will start with the operand list. |
396 | virtual void printGenericOp(Operation *op, bool printOpName = true) = 0; |
397 | |
398 | /// Prints a region. |
399 | /// If 'printEntryBlockArgs' is false, the arguments of the |
400 | /// block are not printed. If 'printBlockTerminator' is false, the terminator |
401 | /// operation of the block is not printed. If printEmptyBlock is true, then |
402 | /// the block header is printed even if the block is empty. |
403 | virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true, |
404 | bool printBlockTerminators = true, |
405 | bool printEmptyBlock = false) = 0; |
406 | |
407 | /// Renumber the arguments for the specified region to the same names as the |
408 | /// SSA values in namesToUse. This may only be used for IsolatedFromAbove |
409 | /// operations. If any entry in namesToUse is null, the corresponding |
410 | /// argument name is left alone. |
411 | virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; |
412 | |
413 | /// Prints an affine map of SSA ids, where SSA id names are used in place |
414 | /// of dims/symbols. |
415 | /// Operand values must come from single-result sources, and be valid |
416 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
417 | virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, |
418 | ValueRange operands) = 0; |
419 | |
420 | /// Prints an affine expression of SSA ids with SSA id names used instead of |
421 | /// dims and symbols. |
422 | /// Operand values must come from single-result sources, and be valid |
423 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
424 | virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands, |
425 | ValueRange symOperands) = 0; |
426 | |
427 | /// Print the complete type of an operation in functional form. |
428 | void printFunctionalType(Operation *op); |
429 | using AsmPrinter::printFunctionalType; |
430 | }; |
431 | |
// Make the implementations convenient to use.

/// Print a single SSA value via OpAsmPrinter::printOperand.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
  p.printOperand(value);
  return p;
}

/// Print any range convertible to ValueRange (but not itself a Value) as a
/// comma-separated operand list via OpAsmPrinter::printOperands.
template <typename T,
          std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
                               !std::is_convertible<T &, Value &>::value,
                           T> * = nullptr>
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
  p.printOperands(values);
  return p;
}

/// Print a block reference via OpAsmPrinter::printSuccessor.
inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
  p.printSuccessor(value);
  return p;
}
451 | |
452 | //===----------------------------------------------------------------------===// |
453 | // AsmParser |
454 | //===----------------------------------------------------------------------===// |
455 | |
456 | /// This base class exposes generic asm parser hooks, usable across the various |
457 | /// derived parsers. |
458 | class AsmParser { |
459 | public: |
460 | AsmParser() = default; |
461 | virtual ~AsmParser(); |
462 | |
463 | MLIRContext *getContext() const; |
464 | |
465 | /// Return the location of the original name token. |
466 | virtual SMLoc getNameLoc() const = 0; |
467 | |
468 | //===--------------------------------------------------------------------===// |
469 | // Utilities |
470 | //===--------------------------------------------------------------------===// |
471 | |
472 | /// Emit a diagnostic at the specified location and return failure. |
473 | virtual InFlightDiagnostic emitError(SMLoc loc, |
474 | const Twine &message = {}) = 0; |
475 | |
476 | /// Return a builder which provides useful access to MLIRContext, global |
477 | /// objects like types and attributes. |
478 | virtual Builder &getBuilder() const = 0; |
479 | |
480 | /// Get the location of the next token and store it into the argument. This |
481 | /// always succeeds. |
482 | virtual SMLoc getCurrentLocation() = 0; |
483 | ParseResult getCurrentLocation(SMLoc *loc) { |
484 | *loc = getCurrentLocation(); |
485 | return success(); |
486 | } |
487 | |
488 | /// Re-encode the given source location as an MLIR location and return it. |
489 | /// Note: This method should only be used when a `Location` is necessary, as |
490 | /// the encoding process is not efficient. |
491 | virtual Location getEncodedSourceLoc(SMLoc loc) = 0; |
492 | |
493 | //===--------------------------------------------------------------------===// |
494 | // Token Parsing |
495 | //===--------------------------------------------------------------------===// |
496 | |
/// Parse a '->' token.
virtual ParseResult parseArrow() = 0;

/// Parse a '->' token if present.
virtual ParseResult parseOptionalArrow() = 0;

/// Parse a `{` token.
virtual ParseResult parseLBrace() = 0;

/// Parse a `{` token if present.
virtual ParseResult parseOptionalLBrace() = 0;

/// Parse a `}` token.
virtual ParseResult parseRBrace() = 0;

/// Parse a `}` token if present.
virtual ParseResult parseOptionalRBrace() = 0;

/// Parse a `:` token.
virtual ParseResult parseColon() = 0;

/// Parse a `:` token if present.
virtual ParseResult parseOptionalColon() = 0;

/// Parse a `,` token.
virtual ParseResult parseComma() = 0;

/// Parse a `,` token if present.
virtual ParseResult parseOptionalComma() = 0;

/// Parse a `=` token.
virtual ParseResult parseEqual() = 0;

/// Parse a `=` token if present.
virtual ParseResult parseOptionalEqual() = 0;

/// Parse a '<' token.
virtual ParseResult parseLess() = 0;

/// Parse a '<' token if present.
virtual ParseResult parseOptionalLess() = 0;

/// Parse a '>' token.
virtual ParseResult parseGreater() = 0;

/// Parse a '>' token if present.
virtual ParseResult parseOptionalGreater() = 0;

/// Parse a '?' token.
virtual ParseResult parseQuestion() = 0;

/// Parse a '?' token if present.
virtual ParseResult parseOptionalQuestion() = 0;

/// Parse a '+' token.
virtual ParseResult parsePlus() = 0;

/// Parse a '+' token if present.
virtual ParseResult parseOptionalPlus() = 0;

/// Parse a '*' token.
virtual ParseResult parseStar() = 0;

/// Parse a '*' token if present.
virtual ParseResult parseOptionalStar() = 0;

/// Parse a '|' token.
virtual ParseResult parseVerticalBar() = 0;

/// Parse a '|' token if present.
virtual ParseResult parseOptionalVerticalBar() = 0;
568 | |
569 | /// Parse a quoted string token. |
570 | ParseResult parseString(std::string *string) { |
571 | auto loc = getCurrentLocation(); |
572 | if (parseOptionalString(string)) |
573 | return emitError(loc, "expected string"); |
574 | return success(); |
575 | } |
576 | |
577 | /// Parse a quoted string token if present. |
578 | virtual ParseResult parseOptionalString(std::string *string) = 0; |
579 | |
580 | /// Parses a Base64 encoded string of bytes. |
581 | virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0; |
582 | |
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;

/// Parse a `(` token if present.
virtual ParseResult parseOptionalLParen() = 0;

/// Parse a `)` token.
virtual ParseResult parseRParen() = 0;

/// Parse a `)` token if present.
virtual ParseResult parseOptionalRParen() = 0;

/// Parse a `[` token.
virtual ParseResult parseLSquare() = 0;

/// Parse a `[` token if present.
virtual ParseResult parseOptionalLSquare() = 0;

/// Parse a `]` token.
virtual ParseResult parseRSquare() = 0;

/// Parse a `]` token if present.
virtual ParseResult parseOptionalRSquare() = 0;

/// Parse a `...` token.
virtual ParseResult parseEllipsis() = 0;

/// Parse a `...` token if present.
virtual ParseResult parseOptionalEllipsis() = 0;

/// Parse a floating point value from the stream into `result`.
virtual ParseResult parseFloat(double &result) = 0;
615 | |
616 | /// Parse an integer value from the stream. |
617 | template <typename IntT> |
618 | ParseResult parseInteger(IntT &result) { |
619 | auto loc = getCurrentLocation(); |
620 | OptionalParseResult parseResult = parseOptionalInteger(result); |
621 | if (!parseResult.has_value()) |
622 | return emitError(loc, "expected integer value"); |
623 | return *parseResult; |
624 | } |
625 | |
626 | /// Parse an optional integer value from the stream. |
627 | virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0; |
628 | |
629 | template <typename IntT> |
630 | OptionalParseResult parseOptionalInteger(IntT &result) { |
631 | auto loc = getCurrentLocation(); |
632 | |
633 | // Parse the unsigned variant. |
634 | APInt uintResult; |
635 | OptionalParseResult parseResult = parseOptionalInteger(uintResult); |
636 | if (!parseResult.has_value() || failed(*parseResult)) |
637 | return parseResult; |
638 | |
639 | // Try to convert to the provided integer type. sextOrTrunc is correct even |
640 | // for unsigned types because parseOptionalInteger ensures the sign bit is |
641 | // zero for non-negated integers. |
642 | result = |
643 | (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue(); |
644 | if (APInt(uintResult.getBitWidth(), result) != uintResult) |
645 | return emitError(loc, "integer value too large"); |
646 | return success(); |
647 | } |
648 | |
/// These are the supported delimiters around comma-separated lists such as
/// operand lists and region argument lists. They are used by
/// parseCommaSeparatedList and parseOperandList.
enum class Delimiter {
  /// Zero or more operands with no delimiters.
  None,
  /// Parens surrounding zero or more operands.
  Paren,
  /// Square brackets surrounding zero or more operands.
  Square,
  /// <> brackets surrounding zero or more operands.
  LessGreater,
  /// {} brackets surrounding zero or more operands.
  Braces,
  /// Parens supporting zero or more operands, or nothing.
  OptionalParen,
  /// Square brackets supporting zero or more operands, or nothing.
  OptionalSquare,
  /// <> brackets supporting zero or more operands, or nothing.
  OptionalLessGreater,
  /// {} brackets surrounding zero or more operands, or nothing.
  OptionalBraces,
};
671 | |
672 | /// Parse a list of comma-separated items with an optional delimiter. If a |
673 | /// delimiter is provided, then an empty list is allowed. If not, then at |
674 | /// least one element will be parsed. |
675 | /// |
676 | /// contextMessage is an optional message appended to "expected '('" sorts of |
677 | /// diagnostics when parsing the delimeters. |
678 | virtual ParseResult |
679 | parseCommaSeparatedList(Delimiter delimiter, |
680 | function_ref<ParseResult()> parseElementFn, |
681 | StringRef contextMessage = StringRef()) = 0; |
682 | |
683 | /// Parse a comma separated list of elements that must have at least one entry |
684 | /// in it. |
685 | ParseResult |
686 | parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) { |
687 | return parseCommaSeparatedList(Delimiter::None, parseElementFn); |
688 | } |
689 | |
690 | //===--------------------------------------------------------------------===// |
691 | // Keyword Parsing |
692 | //===--------------------------------------------------------------------===// |
693 | |
/// This class represents a StringSwitch-like class that is useful for parsing
/// expected keywords. On construction, it invokes `parseKeywordOrCompletion`
/// and then processes each of the provided case statements until a match is
/// hit. The provided `ResultT` must be assignable from `failure()`.
///
/// If the keyword cannot be parsed, the switch is seeded with `failure()` and
/// no later case or default fires. If no case or default produces a value,
/// converting the switch to `ResultT` emits an "unexpected keyword" error.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
  /// Record the current location, then parse the keyword from `parser`. A
  /// code-completion request is signalled by an empty keyword.
  KeywordSwitch(AsmParser &parser)
      : parser(parser), loc(parser.getCurrentLocation()) {
    if (failed(parser.parseKeywordOrCompletion(&keyword)))
      result = failure();
  }

  /// Case that uses the provided value when the keyword equals `str`.
  KeywordSwitch &Case(StringLiteral str, ResultT value) {
    // The functor overload invokes this lambda immediately, so capturing
    // `value` by reference cannot dangle.
    return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Default that uses the provided value if no result has been set yet.
  KeywordSwitch &Default(ResultT value) {
    return Default([&](StringRef, SMLoc) { return std::move(value); });
  }
  /// Case that invokes the provided functor when the keyword equals `str`.
  /// The functor receives the keyword and its location (in case any errors
  /// need to be emitted). The enable_if routes arguments that are
  /// convertible to `ResultT` to the value overload above.
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Case(StringLiteral str, FnT &&fn) {
    // Either an earlier case already matched or keyword parsing failed.
    if (result)
      return *this;

    // If the word was empty, record this as a completion: every case string
    // is reported as an expected token instead of being matched.
    if (keyword.empty())
      parser.codeCompleteExpectedTokens(str);
    else if (keyword == str)
      result.emplace(std::move(fn(keyword, loc)));
    return *this;
  }
  /// Default that invokes the provided functor if no result has been set yet.
  /// Note that this also fires for an empty (code-completion) keyword.
  template <typename FnT>
  std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
  Default(FnT &&fn) {
    if (!result)
      result.emplace(fn(keyword, loc));
    return *this;
  }

  /// Returns true if this switch has a value yet. This includes the
  /// `failure()` seeded when keyword parsing failed.
  bool hasValue() const { return result.has_value(); }

  /// Return the result of the switch. If no case or default produced one,
  /// emit "unexpected keyword: <keyword>" at the keyword location.
  [[nodiscard]] operator ResultT() {
    if (!result)
      return parser.emitError(loc, "unexpected keyword: ") << keyword;
    return std::move(*result);
  }

private:
  /// The parser used to construct this switch.
  AsmParser &parser;

  /// The location of the keyword, used to emit errors as necessary.
  SMLoc loc;

  /// The parsed keyword itself.
  StringRef keyword;

  /// The result of the switch statement or none if currently unknown.
  Optional<ResultT> result;
};
761 | |
762 | /// Parse a given keyword. |
763 | ParseResult parseKeyword(StringRef keyword) { |
764 | return parseKeyword(keyword, ""); |
765 | } |
766 | virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0; |
767 | |
768 | /// Parse a keyword into 'keyword'. |
769 | ParseResult parseKeyword(StringRef *keyword) { |
770 | auto loc = getCurrentLocation(); |
771 | if (parseOptionalKeyword(keyword)) |
772 | return emitError(loc, "expected valid keyword"); |
773 | return success(); |
774 | } |
775 | |
776 | /// Parse the given keyword if present. |
777 | virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0; |
778 | |
779 | /// Parse a keyword, if present, into 'keyword'. |
780 | virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0; |
781 | |
782 | /// Parse a keyword, if present, and if one of the 'allowedValues', |
783 | /// into 'keyword' |
784 | virtual ParseResult |
785 | parseOptionalKeyword(StringRef *keyword, |
786 | ArrayRef<StringRef> allowedValues) = 0; |
787 | |
788 | /// Parse a keyword or a quoted string. |
789 | ParseResult parseKeywordOrString(std::string *result) { |
790 | if (failed(parseOptionalKeywordOrString(result))) |
791 | return emitError(getCurrentLocation()) |
792 | << "expected valid keyword or string"; |
793 | return success(); |
794 | } |
795 | |
796 | /// Parse an optional keyword or string. |
797 | virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0; |
798 | |
799 | //===--------------------------------------------------------------------===// |
800 | // Attribute/Type Parsing |
801 | //===--------------------------------------------------------------------===// |
802 | |
803 | /// Invoke the `getChecked` method of the given Attribute or Type class, using |
804 | /// the provided location to emit errors in the case of failure. Note that |
805 | /// unlike `OpBuilder::getType`, this method does not implicitly insert a |
806 | /// context parameter. |
807 | template <typename T, typename... ParamsT> |
808 | auto getChecked(SMLoc loc, ParamsT &&...params) { |
809 | return T::getChecked([&] { return emitError(loc); }, |
810 | std::forward<ParamsT>(params)...); |
811 | } |
812 | /// A variant of `getChecked` that uses the result of `getNameLoc` to emit |
813 | /// errors. |
814 | template <typename T, typename... ParamsT> |
815 | auto getChecked(ParamsT &&...params) { |
816 | return T::getChecked([&] { return emitError(getNameLoc()); }, |
817 | std::forward<ParamsT>(params)...); |
818 | } |
819 | |
820 | //===--------------------------------------------------------------------===// |
821 | // Attribute Parsing |
822 | //===--------------------------------------------------------------------===// |
823 | |
/// Parse an arbitrary attribute of a given type and return it in result.
/// `type` defaults to a null `Type`.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;

/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
/// The callback receives the attribute to fill in and the expected `type`.
virtual ParseResult parseCustomAttributeWithFallback(
    Attribute &result, Type type,
    function_ref<ParseResult(Attribute &result, Type type)>
        parseAttribute) = 0;
833 | |
834 | /// Parse an attribute of a specific kind and type. |
835 | template <typename AttrType> |
836 | ParseResult parseAttribute(AttrType &result, Type type = {}) { |
837 | SMLoc loc = getCurrentLocation(); |
838 | |
839 | // Parse any kind of attribute. |
840 | Attribute attr; |
841 | if (parseAttribute(attr, type)) |
842 | return failure(); |
843 | |
844 | // Check for the right kind of attribute. |
845 | if (!(result = attr.dyn_cast<AttrType>())) |
846 | return emitError(loc, "invalid kind of attribute specified"); |
847 | |
848 | return success(); |
849 | } |
850 | |
851 | /// Parse an arbitrary attribute and return it in result. This also adds the |
852 | /// attribute to the specified attribute list with the specified name. |
853 | ParseResult parseAttribute(Attribute &result, StringRef attrName, |
854 | NamedAttrList &attrs) { |
855 | return parseAttribute(result, Type(), attrName, attrs); |
856 | } |
857 | |
858 | /// Parse an attribute of a specific kind and type. |
859 | template <typename AttrType> |
860 | ParseResult parseAttribute(AttrType &result, StringRef attrName, |
861 | NamedAttrList &attrs) { |
862 | return parseAttribute(result, Type(), attrName, attrs); |
863 | } |
864 | |
865 | /// Parse an arbitrary attribute of a given type and populate it in `result`. |
866 | /// This also adds the attribute to the specified attribute list with the |
867 | /// specified name. |
868 | template <typename AttrType> |
869 | ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName, |
870 | NamedAttrList &attrs) { |
871 | SMLoc loc = getCurrentLocation(); |
872 | |
873 | // Parse any kind of attribute. |
874 | Attribute attr; |
875 | if (parseAttribute(attr, type)) |
876 | return failure(); |
877 | |
878 | // Check for the right kind of attribute. |
879 | result = attr.dyn_cast<AttrType>(); |
880 | if (!result) |
881 | return emitError(loc, "invalid kind of attribute specified"); |
882 | |
883 | attrs.append(attrName, result); |
884 | return success(); |
885 | } |
886 | |
/// Trait to check if `AttrType` provides a `parse` method.
/// `has_parse_method` names the type of `AttrType::parse(AsmParser &, Type)`.
/// `detect_has_parse_method<AttrType>::value` is true when that expression is
/// well-formed. It is used below to select the parseCustomAttributeWithFallback
/// overloads.
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
                                                  std::declval<Type>()));
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
893 | |
894 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
895 | /// which case the generic parser is invoked. The parsed attribute is |
896 | /// populated in `result` and also added to the specified attribute list with |
897 | /// the specified name. |
898 | template <typename AttrType> |
899 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
900 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
901 | StringRef attrName, NamedAttrList &attrs) { |
902 | SMLoc loc = getCurrentLocation(); |
903 | |
904 | // Parse any kind of attribute. |
905 | Attribute attr; |
906 | if (parseCustomAttributeWithFallback( |
907 | attr, type, [&](Attribute &result, Type type) -> ParseResult { |
908 | result = AttrType::parse(*this, type); |
909 | if (!result) |
910 | return failure(); |
911 | return success(); |
912 | })) |
913 | return failure(); |
914 | |
915 | // Check for the right kind of attribute. |
916 | result = attr.dyn_cast<AttrType>(); |
917 | if (!result) |
918 | return emitError(loc, "invalid kind of attribute specified"); |
919 | |
920 | attrs.append(attrName, result); |
921 | return success(); |
922 | } |
923 | |
924 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
925 | template <typename AttrType> |
926 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
927 | parseCustomAttributeWithFallback(AttrType &result, Type type, |
928 | StringRef attrName, NamedAttrList &attrs) { |
929 | return parseAttribute(result, type, attrName, attrs); |
930 | } |
931 | |
932 | /// Parse a custom attribute of a given type unless the next token is `#`, in |
933 | /// which case the generic parser is invoked. The parsed attribute is |
934 | /// populated in `result`. |
935 | template <typename AttrType> |
936 | std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult> |
937 | parseCustomAttributeWithFallback(AttrType &result) { |
938 | SMLoc loc = getCurrentLocation(); |
939 | |
940 | // Parse any kind of attribute. |
941 | Attribute attr; |
942 | if (parseCustomAttributeWithFallback( |
943 | attr, {}, [&](Attribute &result, Type type) -> ParseResult { |
944 | result = AttrType::parse(*this, type); |
945 | return success(!!result); |
946 | })) |
947 | return failure(); |
948 | |
949 | // Check for the right kind of attribute. |
950 | result = attr.dyn_cast<AttrType>(); |
951 | if (!result) |
952 | return emitError(loc, "invalid kind of attribute specified"); |
953 | return success(); |
954 | } |
955 | |
956 | /// SFINAE parsing method for Attribute that don't implement a parse method. |
957 | template <typename AttrType> |
958 | std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult> |
959 | parseCustomAttributeWithFallback(AttrType &result) { |
960 | return parseAttribute(result); |
961 | } |
962 | |
/// Parse an arbitrary optional attribute of a given type and return it in
/// result. Following the OptionalParseResult convention used in this class,
/// an empty result means no attribute was present.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
                                                   Type type = {}) = 0;

/// Parse an optional array attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
                                                   Type type = {}) = 0;

/// Parse an optional string attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
                                                   Type type = {}) = 0;

/// Parse an optional symbol ref attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
                                                   Type type = {}) = 0;
979 | |
980 | /// Parse an optional attribute of a specific type and add it to the list with |
981 | /// the specified name. |
982 | template <typename AttrType> |
983 | OptionalParseResult parseOptionalAttribute(AttrType &result, |
984 | StringRef attrName, |
985 | NamedAttrList &attrs) { |
986 | return parseOptionalAttribute(result, Type(), attrName, attrs); |
987 | } |
988 | |
989 | /// Parse an optional attribute of a specific type and add it to the list with |
990 | /// the specified name. |
991 | template <typename AttrType> |
992 | OptionalParseResult parseOptionalAttribute(AttrType &result, Type type, |
993 | StringRef attrName, |
994 | NamedAttrList &attrs) { |
995 | OptionalParseResult parseResult = parseOptionalAttribute(result, type); |
996 | if (parseResult.has_value() && succeeded(*parseResult)) |
997 | attrs.append(attrName, result); |
998 | return parseResult; |
999 | } |
1000 | |
1001 | /// Parse a named dictionary into 'result' if it is present. |
1002 | virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0; |
1003 | |
1004 | /// Parse a named dictionary into 'result' if the `attributes` keyword is |
1005 | /// present. |
1006 | virtual ParseResult |
1007 | parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0; |
1008 | |
1009 | /// Parse an affine map instance into 'map'. |
1010 | virtual ParseResult parseAffineMap(AffineMap &map) = 0; |
1011 | |
1012 | /// Parse an integer set instance into 'set'. |
1013 | virtual ParseResult printIntegerSet(IntegerSet &set) = 0; |
1014 | |
1015 | //===--------------------------------------------------------------------===// |
1016 | // Identifier Parsing |
1017 | //===--------------------------------------------------------------------===// |
1018 | |
1019 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1020 | /// attribute. |
1021 | ParseResult parseSymbolName(StringAttr &result) { |
1022 | if (failed(parseOptionalSymbolName(result))) |
1023 | return emitError(getCurrentLocation()) |
1024 | << "expected valid '@'-identifier for symbol name"; |
1025 | return success(); |
1026 | } |
1027 | |
1028 | /// Parse an @-identifier and store it (without the '@' symbol) in a string |
1029 | /// attribute named 'attrName'. |
1030 | ParseResult parseSymbolName(StringAttr &result, StringRef attrName, |
1031 | NamedAttrList &attrs) { |
1032 | if (parseSymbolName(result)) |
1033 | return failure(); |
1034 | attrs.append(attrName, result); |
1035 | return success(); |
1036 | } |
1037 | |
1038 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1039 | /// string attribute. |
1040 | virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0; |
1041 | |
1042 | /// Parse an optional @-identifier and store it (without the '@' symbol) in a |
1043 | /// string attribute named 'attrName'. |
1044 | ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName, |
1045 | NamedAttrList &attrs) { |
1046 | if (succeeded(parseOptionalSymbolName(result))) { |
1047 | attrs.append(attrName, result); |
1048 | return success(); |
1049 | } |
1050 | return failure(); |
1051 | } |
1052 | |
1053 | //===--------------------------------------------------------------------===// |
1054 | // Resource Parsing |
1055 | //===--------------------------------------------------------------------===// |
1056 | |
1057 | /// Parse a handle to a resource within the assembly format. |
1058 | template <typename ResourceT> |
1059 | FailureOr<ResourceT> parseResourceHandle() { |
1060 | SMLoc handleLoc = getCurrentLocation(); |
1061 | |
1062 | // Try to load the dialect that owns the handle. |
1063 | auto *dialect = |
1064 | getContext()->getOrLoadDialect<typename ResourceT::Dialect>(); |
1065 | if (!dialect) { |
1066 | return emitError(handleLoc) |
1067 | << "dialect '" << ResourceT::Dialect::getDialectNamespace() |
1068 | << "' is unknown"; |
1069 | } |
1070 | |
1071 | FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect); |
1072 | if (failed(handle)) |
1073 | return failure(); |
1074 | if (auto *result = dyn_cast<ResourceT>(&*handle)) |
1075 | return std::move(*result); |
1076 | return emitError(handleLoc) << "provided resource handle differs from the " |
1077 | "expected resource type"; |
1078 | } |
1079 | |
1080 | //===--------------------------------------------------------------------===// |
1081 | // Type Parsing |
1082 | //===--------------------------------------------------------------------===// |
1083 | |
/// Parse a type into `result`.
virtual ParseResult parseType(Type &result) = 0;

/// Parse a custom type with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked.
virtual ParseResult parseCustomTypeWithFallback(
    Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;

/// Parse an optional type into `result`. Following the OptionalParseResult
/// convention used in this class, an empty result means no type was present.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1094 | |
1095 | /// Parse a type of a specific type. |
1096 | template <typename TypeT> |
1097 | ParseResult parseType(TypeT &result) { |
1098 | SMLoc loc = getCurrentLocation(); |
1099 | |
1100 | // Parse any kind of type. |
1101 | Type type; |
1102 | if (parseType(type)) |
1103 | return failure(); |
1104 | |
1105 | // Check for the right kind of type. |
1106 | result = type.dyn_cast<TypeT>(); |
1107 | if (!result) |
1108 | return emitError(loc, "invalid kind of type specified"); |
1109 | |
1110 | return success(); |
1111 | } |
1112 | |
/// Trait to check if `TypeT` provides a `parse` method.
/// `type_has_parse_method` names the type of `TypeT::parse(AsmParser &)`.
/// `detect_type_has_parse_method<TypeT>::value` is true when that expression
/// is well-formed. It is used below to select the parseCustomTypeWithFallback
/// overloads.
template <typename TypeT>
using type_has_parse_method =
    decltype(TypeT::parse(std::declval<AsmParser &>()));
template <typename TypeT>
using detect_type_has_parse_method =
    llvm::is_detected<type_has_parse_method, TypeT>;
1120 | |
1121 | /// Parse a custom Type of a given type unless the next token is `#`, in |
1122 | /// which case the generic parser is invoked. The parsed Type is |
1123 | /// populated in `result`. |
1124 | template <typename TypeT> |
1125 | std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult> |
1126 | parseCustomTypeWithFallback(TypeT &result) { |
1127 | SMLoc loc = getCurrentLocation(); |
1128 | |
1129 | // Parse any kind of Type. |
1130 | Type type; |
1131 | if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult { |
1132 | result = TypeT::parse(*this); |
1133 | return success(!!result); |
1134 | })) |
1135 | return failure(); |
1136 | |
1137 | // Check for the right kind of Type. |
1138 | result = type.dyn_cast<TypeT>(); |
1139 | if (!result) |
1140 | return emitError(loc, "invalid kind of Type specified"); |
1141 | return success(); |
1142 | } |
1143 | |
1144 | /// SFINAE parsing method for Type that don't implement a parse method. |
1145 | template <typename TypeT> |
1146 | std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult> |
1147 | parseCustomTypeWithFallback(TypeT &result) { |
1148 | return parseType(result); |
1149 | } |
1150 | |
1151 | /// Parse a type list. |
1152 | ParseResult parseTypeList(SmallVectorImpl<Type> &result) { |
1153 | return parseCommaSeparatedList( |
1154 | [&]() { return parseType(result.emplace_back()); }); |
1155 | } |
1156 | |
1157 | /// Parse an arrow followed by a type list. |
1158 | virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1159 | |
1160 | /// Parse an optional arrow followed by a type list. |
1161 | virtual ParseResult |
1162 | parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0; |
1163 | |
1164 | /// Parse a colon followed by a type. |
1165 | virtual ParseResult parseColonType(Type &result) = 0; |
1166 | |
1167 | /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType. |
1168 | template <typename TypeType> |
1169 | ParseResult parseColonType(TypeType &result) { |
1170 | SMLoc loc = getCurrentLocation(); |
1171 | |
1172 | // Parse any kind of type. |
1173 | Type type; |
1174 | if (parseColonType(type)) |
1175 | return failure(); |
1176 | |
1177 | // Check for the right kind of type. |
1178 | result = type.dyn_cast<TypeType>(); |
1179 | if (!result) |
1180 | return emitError(loc, "invalid kind of type specified"); |
1181 | |
1182 | return success(); |
1183 | } |
1184 | |
1185 | /// Parse a colon followed by a type list, which must have at least one type. |
1186 | virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1187 | |
1188 | /// Parse an optional colon followed by a type list, which if present must |
1189 | /// have at least one type. |
1190 | virtual ParseResult |
1191 | parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0; |
1192 | |
1193 | /// Parse a keyword followed by a type. |
1194 | ParseResult parseKeywordType(const char *keyword, Type &result) { |
1195 | return failure(parseKeyword(keyword) || parseType(result)); |
1196 | } |
1197 | |
1198 | /// Add the specified type to the end of the specified type list and return |
1199 | /// success. This is a helper designed to allow parse methods to be simple |
1200 | /// and chain through || operators. |
1201 | ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) { |
1202 | result.push_back(type); |
1203 | return success(); |
1204 | } |
1205 | |
1206 | /// Add the specified types to the end of the specified type list and return |
1207 | /// success. This is a helper designed to allow parse methods to be simple |
1208 | /// and chain through || operators. |
1209 | ParseResult addTypesToList(ArrayRef<Type> types, |
1210 | SmallVectorImpl<Type> &result) { |
1211 | result.append(types.begin(), types.end()); |
1212 | return success(); |
1213 | } |
1214 | |
/// Parse a dimension list of a tensor or memref type. This populates the
/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
/// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
///
///   dimension-list ::= eps | dimension (`x` dimension)*
///   dimension-list-with-trailing-x ::= (dimension `x`)*
///   dimension ::= `?` | decimal-literal
///
/// When `allowDynamic` is not set, this is used to parse:
///
///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
///   static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true,
                                       bool withTrailingX = true) = 0;

/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
/// next token.
virtual ParseResult parseXInDimensionList() = 0;
1235 | |
protected:
/// Parse a handle to a resource within the assembly format for the given
/// dialect. This is the non-template hook used by the typed
/// `parseResourceHandle<ResourceT>()` helper.
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;

//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//

/// Parse a keyword, or an empty string if the current location signals a code
/// completion. KeywordSwitch relies on the empty-string convention.
virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;

/// Signal the code completion of a set of expected tokens.
virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1252 | |
private:
/// Parsers are neither copy-constructible nor copy-assignable.
AsmParser(const AsmParser &) = delete;
void operator=(const AsmParser &) = delete;
1256 | }; |
1257 | |
1258 | //===----------------------------------------------------------------------===// |
1259 | // OpAsmParser |
1260 | //===----------------------------------------------------------------------===// |
1261 | |
1262 | /// The OpAsmParser has methods for interacting with the asm parser: parsing |
1263 | /// things from it, emitting errors etc. It has an intentionally high-level API |
1264 | /// that is designed to reduce/constrain syntax innovation in individual |
1265 | /// operations. |
1266 | /// |
1267 | /// For example, consider an op like this: |
1268 | /// |
1269 | /// %x = load %p[%1, %2] : memref<...> |
1270 | /// |
1271 | /// The "%x = load" tokens are already parsed and therefore invisible to the |
/// custom op parser. This can be supported by calling `parseOperandList` to
/// parse the %p, then calling `parseOperandList` with `Delimiter::Square` to
/// parse the indices, then calling `parseColonTypeList` to parse the result
1275 | /// type. |
1276 | /// |
1277 | class OpAsmParser : public AsmParser { |
1278 | public: |
1279 | using AsmParser::AsmParser; |
1280 | ~OpAsmParser() override; |
1281 | |
1282 | /// Parse a loc(...) specifier if present, filling in result if so. |
1283 | /// Location for BlockArgument and Operation may be deferred with an alias, in |
1284 | /// which case an OpaqueLoc is set and will be resolved when parsing |
1285 | /// completes. |
1286 | virtual ParseResult |
1287 | parseOptionalLocationSpecifier(Optional<Location> &result) = 0; |
1288 | |
1289 | /// Return the name of the specified result in the specified syntax, as well |
1290 | /// as the sub-element in the name. It returns an empty string and ~0U for |
1291 | /// invalid result numbers. For example, in this operation: |
1292 | /// |
1293 | /// %x, %y:2, %z = foo.op |
1294 | /// |
1295 | /// getResultName(0) == {"x", 0 } |
1296 | /// getResultName(1) == {"y", 0 } |
1297 | /// getResultName(2) == {"y", 1 } |
1298 | /// getResultName(3) == {"z", 0 } |
1299 | /// getResultName(4) == {"", ~0U } |
1300 | virtual std::pair<StringRef, unsigned> |
1301 | getResultName(unsigned resultNo) const = 0; |
1302 | |
1303 | /// Return the number of declared SSA results. This returns 4 for the foo.op |
1304 | /// example in the comment for `getResultName`. |
1305 | virtual size_t getNumResults() const = 0; |
1306 | |
1307 | // These methods emit an error and return failure or success. This allows |
1308 | // these to be chained together into a linear sequence of || expressions in |
1309 | // many cases. |
1310 | |
1311 | /// Parse an operation in its generic form. |
1312 | /// The parsed operation is parsed in the current context and inserted in the |
1313 | /// provided block and insertion point. The results produced by this operation |
1314 | /// aren't mapped to any named value in the parser. Returns nullptr on |
1315 | /// failure. |
1316 | virtual Operation *parseGenericOperation(Block *insertBlock, |
1317 | Block::iterator insertPt) = 0; |
1318 | |
1319 | /// Parse the name of an operation, in the custom form. On success, return a |
1320 | /// an object of type 'OperationName'. Otherwise, failure is returned. |
1321 | virtual FailureOr<OperationName> parseCustomOperationName() = 0; |
1322 | |
1323 | //===--------------------------------------------------------------------===// |
1324 | // Operand Parsing |
1325 | //===--------------------------------------------------------------------===// |
1326 | |
1327 | /// This is the representation of an operand reference. |
1328 | struct UnresolvedOperand { |
1329 | SMLoc location; // Location of the token. |
1330 | StringRef name; // Value name, e.g. %42 or %abc |
1331 | unsigned number; // Number, e.g. 12 for an operand like %xyz#12 |
1332 | }; |
1333 | |
1334 | /// Parse different components, viz., use-info of operand(s), successor(s), |
1335 | /// region(s), attribute(s) and function-type, of the generic form of an |
1336 | /// operation instance and populate the input operation-state 'result' with |
1337 | /// those components. If any of the components is explicitly provided, then |
1338 | /// skip parsing that component. |
1339 | virtual ParseResult parseGenericOperationAfterOpName( |
1340 | OperationState &result, |
1341 | Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt, |
1342 | Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt, |
1343 | Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions = |
1344 | std::nullopt, |
1345 | Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt, |
1346 | Optional<FunctionType> parsedFnType = std::nullopt) = 0; |
1347 | |
1348 | /// Parse a single SSA value operand name along with a result number if |
1349 | /// `allowResultNumber` is true. |
1350 | virtual ParseResult parseOperand(UnresolvedOperand &result, |
1351 | bool allowResultNumber = true) = 0; |
1352 | |
1353 | /// Parse a single operand if present. |
1354 | virtual OptionalParseResult |
1355 | parseOptionalOperand(UnresolvedOperand &result, |
1356 | bool allowResultNumber = true) = 0; |
1357 | |
1358 | /// Parse zero or more SSA comma-separated operand references with a specified |
1359 | /// surrounding delimiter, and an optional required operand count. |
1360 | virtual ParseResult |
1361 | parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1362 | Delimiter delimiter = Delimiter::None, |
1363 | bool allowResultNumber = true, |
1364 | int requiredOperandCount = -1) = 0; |
1365 | |
1366 | /// Parse a specified number of comma separated operands. |
1367 | ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1368 | int requiredOperandCount, |
1369 | Delimiter delimiter = Delimiter::None) { |
1370 | return parseOperandList(result, delimiter, |
1371 | /*allowResultNumber=*/true, requiredOperandCount); |
1372 | } |
1373 | |
1374 | /// Parse zero or more trailing SSA comma-separated trailing operand |
1375 | /// references with a specified surrounding delimiter, and an optional |
1376 | /// required operand count. A leading comma is expected before the |
1377 | /// operands. |
1378 | ParseResult |
1379 | parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result, |
1380 | Delimiter delimiter = Delimiter::None) { |
1381 | if (failed(parseOptionalComma())) |
1382 | return success(); // The comma is optional. |
1383 | return parseOperandList(result, delimiter); |
1384 | } |
1385 | |
1386 | /// Resolve an operand to an SSA value, emitting an error on failure. |
1387 | virtual ParseResult resolveOperand(const UnresolvedOperand &operand, |
1388 | Type type, |
1389 | SmallVectorImpl<Value> &result) = 0; |
1390 | |
1391 | /// Resolve a list of operands to SSA values, emitting an error on failure, or |
1392 | /// appending the results to the list on success. This method should be used |
1393 | /// when all operands have the same type. |
1394 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1395 | ParseResult resolveOperands(Operands &&operands, Type type, |
1396 | SmallVectorImpl<Value> &result) { |
1397 | for (const UnresolvedOperand &operand : operands) |
1398 | if (resolveOperand(operand, type, result)) |
1399 | return failure(); |
1400 | return success(); |
1401 | } |
1402 | template <typename Operands = ArrayRef<UnresolvedOperand>> |
1403 | ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc, |
1404 | SmallVectorImpl<Value> &result) { |
1405 | return resolveOperands(std::forward<Operands>(operands), type, result); |
1406 | } |
1407 | |
1408 | /// Resolve a list of operands and a list of operand types to SSA values, |
1409 | /// emitting an error and returning failure, or appending the results |
1410 | /// to the list on success. |
1411 | template <typename Operands = ArrayRef<UnresolvedOperand>, |
1412 | typename Types = ArrayRef<Type>> |
1413 | std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult> |
1414 | resolveOperands(Operands &&operands, Types &&types, SMLoc loc, |
1415 | SmallVectorImpl<Value> &result) { |
1416 | size_t operandSize = std::distance(operands.begin(), operands.end()); |
1417 | size_t typeSize = std::distance(types.begin(), types.end()); |
1418 | if (operandSize != typeSize) |
1419 | return emitError(loc) |
1420 | << operandSize << " operands present, but expected " << typeSize; |
1421 | |
1422 | for (auto [operand, type] : llvm::zip(operands, types)) |
1423 | if (resolveOperand(operand, type, result)) |
1424 | return failure(); |
1425 | return success(); |
1426 | } |
1427 | |
1428 | /// Parses an affine map attribute where dims and symbols are SSA operands. |
1429 | /// Operand values must come from single-result sources, and be valid |
1430 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1431 | virtual ParseResult |
1432 | parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands, |
1433 | Attribute &map, StringRef attrName, |
1434 | NamedAttrList &attrs, |
1435 | Delimiter delimiter = Delimiter::Square) = 0; |
1436 | |
1437 | /// Parses an affine expression where dims and symbols are SSA operands. |
1438 | /// Operand values must come from single-result sources, and be valid |
1439 | /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. |
1440 | virtual ParseResult |
1441 | parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands, |
1442 | SmallVectorImpl<UnresolvedOperand> &symbOperands, |
1443 | AffineExpr &expr) = 0; |
1444 | |
1445 | //===--------------------------------------------------------------------===// |
1446 | // Argument Parsing |
1447 | //===--------------------------------------------------------------------===// |
1448 | |
1449 | struct Argument { |
1450 | UnresolvedOperand ssaName; // SourceLoc, SSA name, result #. |
1451 | Type type; // Type. |
1452 | DictionaryAttr attrs; // Attributes if present. |
1453 | Optional<Location> sourceLoc; // Source location specifier if present. |
1454 | }; |
1455 | |
1456 | /// Parse a single argument with the following syntax: |
1457 | /// |
1458 | /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)` |
1459 | /// |
1460 | /// If `allowType` is false or `allowAttrs` are false then the respective |
1461 | /// parts of the grammar are not parsed. |
1462 | virtual ParseResult parseArgument(Argument &result, bool allowType = false, |
1463 | bool allowAttrs = false) = 0; |
1464 | |
1465 | /// Parse a single argument if present. |
1466 | virtual OptionalParseResult |
1467 | parseOptionalArgument(Argument &result, bool allowType = false, |
1468 | bool allowAttrs = false) = 0; |
1469 | |
1470 | /// Parse zero or more arguments with a specified surrounding delimiter. |
1471 | virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result, |
1472 | Delimiter delimiter = Delimiter::None, |
1473 | bool allowType = false, |
1474 | bool allowAttrs = false) = 0; |
1475 | |
1476 | //===--------------------------------------------------------------------===// |
1477 | // Region Parsing |
1478 | //===--------------------------------------------------------------------===// |
1479 | |
1480 | /// Parses a region. Any parsed blocks are appended to 'region' and must be |
1481 | /// moved to the op regions after the op is created. The first block of the |
1482 | /// region takes 'arguments'. |
1483 | /// |
1484 | /// If 'enableNameShadowing' is set to true, the argument names are allowed to |
1485 | /// shadow the names of other existing SSA values defined above the region |
1486 | /// scope. 'enableNameShadowing' can only be set to true for regions attached |
1487 | /// to operations that are 'IsolatedFromAbove'. |
1488 | virtual ParseResult parseRegion(Region ®ion, |
1489 | ArrayRef<Argument> arguments = {}, |
1490 | bool enableNameShadowing = false) = 0; |
1491 | |
1492 | /// Parses a region if present. |
1493 | virtual OptionalParseResult |
1494 | parseOptionalRegion(Region ®ion, ArrayRef<Argument> arguments = {}, |
1495 | bool enableNameShadowing = false) = 0; |
1496 | |
1497 | /// Parses a region if present. If the region is present, a new region is |
1498 | /// allocated and placed in `region`. If no region is present or on failure, |
1499 | /// `region` remains untouched. |
1500 | virtual OptionalParseResult |
1501 | parseOptionalRegion(std::unique_ptr<Region> ®ion, |
1502 | ArrayRef<Argument> arguments = {}, |
1503 | bool enableNameShadowing = false) = 0; |
1504 | |
1505 | //===--------------------------------------------------------------------===// |
1506 | // Successor Parsing |
1507 | //===--------------------------------------------------------------------===// |
1508 | |
1509 | /// Parse a single operation successor. |
1510 | virtual ParseResult parseSuccessor(Block *&dest) = 0; |
1511 | |
1512 | /// Parse an optional operation successor. |
1513 | virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0; |
1514 | |
1515 | /// Parse a single operation successor and its operand list. |
1516 | virtual ParseResult |
1517 | parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0; |
1518 | |
1519 | //===--------------------------------------------------------------------===// |
1520 | // Type Parsing |
1521 | //===--------------------------------------------------------------------===// |
1522 | |
1523 | /// Parse a list of assignments of the form |
1524 | /// (%x1 = %y1, %x2 = %y2, ...) |
1525 | ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs, |
1526 | SmallVectorImpl<UnresolvedOperand> &rhs) { |
1527 | OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs); |
1528 | if (!result.has_value()) |
1529 | return emitError(getCurrentLocation(), "expected '('"); |
1530 | return result.value(); |
1531 | } |
1532 | |
1533 | virtual OptionalParseResult |
1534 | parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs, |
1535 | SmallVectorImpl<UnresolvedOperand> &rhs) = 0; |
1536 | }; |
1537 | |
1538 | //===--------------------------------------------------------------------===// |
1539 | // Dialect OpAsm interface. |
1540 | //===--------------------------------------------------------------------===// |
1541 | |
1542 | /// A functor used to set the name of the start of a result group of an |
1543 | /// operation. See 'getAsmResultNames' below for more details. |
1544 | using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>; |
1545 | |
1546 | /// A functor used to set the name of blocks in regions directly nested under |
1547 | /// an operation. |
1548 | using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>; |
1549 | |
1550 | class OpAsmDialectInterface |
1551 | : public DialectInterface::Base<OpAsmDialectInterface> { |
1552 | public: |
1553 | OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {} |
1554 | |
1555 | //===------------------------------------------------------------------===// |
1556 | // Aliases |
1557 | //===------------------------------------------------------------------===// |
1558 | |
1559 | /// Holds the result of `getAlias` hook call. |
1560 | enum class AliasResult { |
1561 | /// The object (type or attribute) is not supported by the hook |
1562 | /// and an alias was not provided. |
1563 | NoAlias, |
1564 | /// An alias was provided, but it might be overriden by other hook. |
1565 | OverridableAlias, |
1566 | /// An alias was provided and it should be used |
1567 | /// (no other hooks will be checked). |
1568 | FinalAlias |
1569 | }; |
1570 | |
1571 | /// Hooks for getting an alias identifier alias for a given symbol, that is |
1572 | /// not necessarily a part of this dialect. The identifier is used in place of |
1573 | /// the symbol when printing textual IR. These aliases must not contain `.` or |
1574 | /// end with a numeric digit([0-9]+). |
1575 | virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const { |
1576 | return AliasResult::NoAlias; |
1577 | } |
1578 | virtual AliasResult getAlias(Type type, raw_ostream &os) const { |
1579 | return AliasResult::NoAlias; |
1580 | } |
1581 | |
1582 | //===--------------------------------------------------------------------===// |
1583 | // Resources |
1584 | //===--------------------------------------------------------------------===// |
1585 | |
1586 | /// Declare a resource with the given key, returning a handle to use for any |
1587 | /// references of this resource key within the IR during parsing. The result |
1588 | /// of `getResourceKey` on the returned handle is permitted to be different |
1589 | /// than `key`. |
1590 | virtual FailureOr<AsmDialectResourceHandle> |
1591 | declareResource(StringRef key) const { |
1592 | return failure(); |
1593 | } |
1594 | |
1595 | /// Return a key to use for the given resource. This key should uniquely |
1596 | /// identify this resource within the dialect. |
1597 | virtual std::string |
1598 | getResourceKey(const AsmDialectResourceHandle &handle) const { |
1599 | llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1600) |
1600 | "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources" , "mlir/include/mlir/IR/OpImplementation.h", 1600); |
1601 | } |
1602 | |
1603 | /// Hook for parsing resource entries. Returns failure if the entry was not |
1604 | /// valid, or could otherwise not be processed correctly. Any necessary errors |
1605 | /// can be emitted via the provided entry. |
1606 | virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const; |
1607 | |
1608 | /// Hook for building resources to use during printing. The given `op` may be |
1609 | /// inspected to help determine what information to include. |
1610 | /// `referencedResources` contains all of the resources detected when printing |
1611 | /// 'op'. |
1612 | virtual void |
1613 | buildResources(Operation *op, |
1614 | const SetVector<AsmDialectResourceHandle> &referencedResources, |
1615 | AsmResourceBuilder &builder) const {} |
1616 | }; |
1617 | } // namespace mlir |
1618 | |
1619 | //===--------------------------------------------------------------------===// |
1620 | // Operation OpAsm interface. |
1621 | //===--------------------------------------------------------------------===// |
1622 | |
1623 | /// The OpAsmOpInterface, see OpAsmInterface.td for more details. |
1624 | #include "mlir/IR/OpAsmInterface.h.inc" |
1625 | |
1626 | namespace llvm { |
1627 | template <> |
1628 | struct DenseMapInfo<mlir::AsmDialectResourceHandle> { |
1629 | static inline mlir::AsmDialectResourceHandle getEmptyKey() { |
1630 | return {DenseMapInfo<void *>::getEmptyKey(), |
1631 | DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr}; |
1632 | } |
1633 | static inline mlir::AsmDialectResourceHandle getTombstoneKey() { |
1634 | return {DenseMapInfo<void *>::getTombstoneKey(), |
1635 | DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr}; |
1636 | } |
1637 | static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) { |
1638 | return DenseMapInfo<void *>::getHashValue(handle.getResource()); |
1639 | } |
1640 | static bool isEqual(const mlir::AsmDialectResourceHandle &lhs, |
1641 | const mlir::AsmDialectResourceHandle &rhs) { |
1642 | return lhs.getResource() == rhs.getResource(); |
1643 | } |
1644 | }; |
1645 | } // namespace llvm |
1646 | |
1647 | #endif |