Bug Summary

File:build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/IR/BuiltinAttributes.cpp
Warning: line 845, column 9
1st function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name BuiltinAttributes.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/IR -I /build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/IR -I include -I /build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/llvm/include -I /build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem 
/usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-08-29-020613-35344-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/IR/BuiltinAttributes.cpp

/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/IR/BuiltinAttributes.cpp

1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Endian.h"
24
25using namespace mlir;
26using namespace mlir::detail;
27
28//===----------------------------------------------------------------------===//
29/// Tablegen Attribute Definitions
30//===----------------------------------------------------------------------===//
31
32#define GET_ATTRDEF_CLASSES
33#include "mlir/IR/BuiltinAttributes.cpp.inc"
34
35//===----------------------------------------------------------------------===//
36// BuiltinDialect
37//===----------------------------------------------------------------------===//
38
/// Register every builtin attribute class with the builtin dialect.
///
/// The attribute class list is not written by hand: defining GET_ATTRDEF_LIST
/// and including the TableGen-generated .inc file expands to the list of
/// attribute classes, which becomes the template argument list of
/// addAttributes<...>().
void BuiltinDialect::registerAttributes() {
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/IR/BuiltinAttributes.cpp.inc"
      >();
}
45
46//===----------------------------------------------------------------------===//
47// ArrayAttr
48//===----------------------------------------------------------------------===//
49
50void ArrayAttr::walkImmediateSubElements(
51 function_ref<void(Attribute)> walkAttrsFn,
52 function_ref<void(Type)> walkTypesFn) const {
53 for (Attribute attr : getValue())
54 walkAttrsFn(attr);
55}
56
57Attribute
58ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
59 ArrayRef<Type> replTypes) const {
60 return get(getContext(), replAttrs);
61}
62
63//===----------------------------------------------------------------------===//
64// DictionaryAttr
65//===----------------------------------------------------------------------===//
66
67/// Helper function that does either an in place sort or sorts from source array
68/// into destination. If inPlace then storage is both the source and the
69/// destination, else value is the source and storage destination. Returns
70/// whether source was sorted.
71template <bool inPlace>
72static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
73 SmallVectorImpl<NamedAttribute> &storage) {
74 // Specialize for the common case.
75 switch (value.size()) {
76 case 0:
77 // Zero already sorted.
78 if (!inPlace)
79 storage.clear();
80 break;
81 case 1:
82 // One already sorted but may need to be copied.
83 if (!inPlace)
84 storage.assign({value[0]});
85 break;
86 case 2: {
87 bool isSorted = value[0] < value[1];
88 if (inPlace) {
89 if (!isSorted)
90 std::swap(storage[0], storage[1]);
91 } else if (isSorted) {
92 storage.assign({value[0], value[1]});
93 } else {
94 storage.assign({value[1], value[0]});
95 }
96 return !isSorted;
97 }
98 default:
99 if (!inPlace)
100 storage.assign(value.begin(), value.end());
101 // Check to see they are sorted already.
102 bool isSorted = llvm::is_sorted(value);
103 // If not, do a general sort.
104 if (!isSorted)
105 llvm::array_pod_sort(storage.begin(), storage.end());
106 return !isSorted;
107 }
108 return false;
109}
110
111/// Returns an entry with a duplicate name from the given sorted array of named
112/// attributes. Returns llvm::None if all elements have unique names.
113static Optional<NamedAttribute>
114findDuplicateElement(ArrayRef<NamedAttribute> value) {
115 const Optional<NamedAttribute> none{llvm::None};
116 if (value.size() < 2)
117 return none;
118
119 if (value.size() == 2)
120 return value[0].getName() == value[1].getName() ? value[0] : none;
121
122 const auto *it = std::adjacent_find(value.begin(), value.end(),
123 [](NamedAttribute l, NamedAttribute r) {
124 return l.getName() == r.getName();
125 });
126 return it != value.end() ? *it : none;
127}
128
129bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
130 SmallVectorImpl<NamedAttribute> &storage) {
131 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
132 assert(!findDuplicateElement(storage) &&(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__
))
133 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__
))
;
134 return isSorted;
135}
136
137bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
138 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
139 assert(!findDuplicateElement(array) &&(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__
))
140 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__
))
;
141 return isSorted;
142}
143
144Optional<NamedAttribute>
145DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
146 bool isSorted) {
147 if (!isSorted)
148 dictionaryAttrSort</*inPlace=*/true>(array, array);
149 return findDuplicateElement(array);
150}
151
152DictionaryAttr DictionaryAttr::get(MLIRContext *context,
153 ArrayRef<NamedAttribute> value) {
154 if (value.empty())
155 return DictionaryAttr::getEmpty(context);
156
157 // We need to sort the element list to canonicalize it.
158 SmallVector<NamedAttribute, 8> storage;
159 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
160 value = storage;
161 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__
))
162 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__
))
;
163 return Base::get(context, value);
164}
165/// Construct a dictionary with an array of values that is known to already be
166/// sorted by name and uniqued.
167DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
168 ArrayRef<NamedAttribute> value) {
169 if (value.empty())
170 return DictionaryAttr::getEmpty(context);
171 // Ensure that the attribute elements are unique and sorted.
172 assert(llvm::is_sorted((static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
173 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
174 "expected attribute values to be sorted")(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
;
175 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__
))
176 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__
))
;
177 return Base::get(context, value);
178}
179
180/// Return the specified attribute if present, null otherwise.
181Attribute DictionaryAttr::get(StringRef name) const {
182 auto it = impl::findAttrSorted(begin(), end(), name);
183 return it.second ? it.first->getValue() : Attribute();
184}
185Attribute DictionaryAttr::get(StringAttr name) const {
186 auto it = impl::findAttrSorted(begin(), end(), name);
187 return it.second ? it.first->getValue() : Attribute();
188}
189
190/// Return the specified named attribute if present, None otherwise.
191Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
192 auto it = impl::findAttrSorted(begin(), end(), name);
193 return it.second ? *it.first : Optional<NamedAttribute>();
194}
195Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
196 auto it = impl::findAttrSorted(begin(), end(), name);
197 return it.second ? *it.first : Optional<NamedAttribute>();
198}
199
200/// Return whether the specified attribute is present.
201bool DictionaryAttr::contains(StringRef name) const {
202 return impl::findAttrSorted(begin(), end(), name).second;
203}
204bool DictionaryAttr::contains(StringAttr name) const {
205 return impl::findAttrSorted(begin(), end(), name).second;
206}
207
208DictionaryAttr::iterator DictionaryAttr::begin() const {
209 return getValue().begin();
210}
211DictionaryAttr::iterator DictionaryAttr::end() const {
212 return getValue().end();
213}
214size_t DictionaryAttr::size() const { return getValue().size(); }
215
216DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
217 return Base::get(context, ArrayRef<NamedAttribute>());
218}
219
220void DictionaryAttr::walkImmediateSubElements(
221 function_ref<void(Attribute)> walkAttrsFn,
222 function_ref<void(Type)> walkTypesFn) const {
223 for (const NamedAttribute &attr : getValue())
224 walkAttrsFn(attr.getValue());
225}
226
227Attribute
228DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
229 ArrayRef<Type> replTypes) const {
230 std::vector<NamedAttribute> vec = getValue().vec();
231 for (auto &it : llvm::enumerate(replAttrs))
232 vec[it.index()].setValue(it.value());
233
234 // The above only modifies the mapped value, but not the key, and therefore
235 // not the order of the elements. It remains sorted
236 return getWithSorted(getContext(), vec);
237}
238
239//===----------------------------------------------------------------------===//
240// StringAttr
241//===----------------------------------------------------------------------===//
242
243StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
244 return Base::get(context, "", NoneType::get(context));
245}
246
247/// Twine support for StringAttr.
248StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
249 // Fast-path empty twine.
250 if (twine.isTriviallyEmpty())
251 return get(context);
252 SmallVector<char, 32> tempStr;
253 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
254}
255
256/// Twine support for StringAttr.
257StringAttr StringAttr::get(const Twine &twine, Type type) {
258 SmallVector<char, 32> tempStr;
259 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
260}
261
262StringRef StringAttr::getValue() const { return getImpl()->value; }
263
264Type StringAttr::getType() const { return getImpl()->type; }
265
266Dialect *StringAttr::getReferencedDialect() const {
267 return getImpl()->referencedDialect;
268}
269
270//===----------------------------------------------------------------------===//
271// FloatAttr
272//===----------------------------------------------------------------------===//
273
274double FloatAttr::getValueAsDouble() const {
275 return getValueAsDouble(getValue());
276}
277double FloatAttr::getValueAsDouble(APFloat value) {
278 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
279 bool losesInfo = false;
280 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
281 &losesInfo);
282 }
283 return value.convertToDouble();
284}
285
286LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
287 Type type, APFloat value) {
288 // Verify that the type is correct.
289 if (!type.isa<FloatType>())
290 return emitError() << "expected floating point type";
291
292 // Verify that the type semantics match that of the value.
293 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
294 return emitError()
295 << "FloatAttr type doesn't match the type implied by its value";
296 }
297 return success();
298}
299
300//===----------------------------------------------------------------------===//
301// SymbolRefAttr
302//===----------------------------------------------------------------------===//
303
304SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
305 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
306 return get(StringAttr::get(ctx, value), nestedRefs);
307}
308
309FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
310 return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
311}
312
313FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
314 return get(value, {}).cast<FlatSymbolRefAttr>();
315}
316
317FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
318 auto symName =
319 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
320 assert(symName && "value does not have a valid symbol name")(static_cast <bool> (symName && "value does not have a valid symbol name"
) ? void (0) : __assert_fail ("symName && \"value does not have a valid symbol name\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 320, __extension__ __PRETTY_FUNCTION__
))
;
321 return SymbolRefAttr::get(symName);
322}
323
324StringAttr SymbolRefAttr::getLeafReference() const {
325 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
326 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
327}
328
329void SymbolRefAttr::walkImmediateSubElements(
330 function_ref<void(Attribute)> walkAttrsFn,
331 function_ref<void(Type)> walkTypesFn) const {
332 walkAttrsFn(getRootReference());
333 for (FlatSymbolRefAttr ref : getNestedReferences())
334 walkAttrsFn(ref);
335}
336
337Attribute
338SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
339 ArrayRef<Type> replTypes) const {
340 ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
341 ArrayRef<FlatSymbolRefAttr> nestedRefs(
342 static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
343 rawNestedRefs.size());
344 return get(replAttrs[0].cast<StringAttr>(), nestedRefs);
345}
346
347//===----------------------------------------------------------------------===//
348// IntegerAttr
349//===----------------------------------------------------------------------===//
350
351int64_t IntegerAttr::getInt() const {
352 assert((getType().isIndex() || getType().isSignlessInteger()) &&(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__
))
353 "must be signless integer")(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__
))
;
354 return getValue().getSExtValue();
355}
356
357int64_t IntegerAttr::getSInt() const {
358 assert(getType().isSignedInteger() && "must be signed integer")(static_cast <bool> (getType().isSignedInteger() &&
"must be signed integer") ? void (0) : __assert_fail ("getType().isSignedInteger() && \"must be signed integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 358, __extension__ __PRETTY_FUNCTION__
))
;
359 return getValue().getSExtValue();
360}
361
362uint64_t IntegerAttr::getUInt() const {
363 assert(getType().isUnsignedInteger() && "must be unsigned integer")(static_cast <bool> (getType().isUnsignedInteger() &&
"must be unsigned integer") ? void (0) : __assert_fail ("getType().isUnsignedInteger() && \"must be unsigned integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 363, __extension__ __PRETTY_FUNCTION__
))
;
364 return getValue().getZExtValue();
365}
366
367/// Return the value as an APSInt which carries the signed from the type of
368/// the attribute. This traps on signless integers types!
369APSInt IntegerAttr::getAPSInt() const {
370 assert(!getType().isSignlessInteger() &&(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 371, __extension__ __PRETTY_FUNCTION__
))
371 "Signless integers don't carry a sign for APSInt")(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 371, __extension__ __PRETTY_FUNCTION__
))
;
372 return APSInt(getValue(), getType().isUnsignedInteger());
373}
374
375LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
376 Type type, APInt value) {
377 if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
378 if (integerType.getWidth() != value.getBitWidth())
379 return emitError() << "integer type bit width (" << integerType.getWidth()
380 << ") doesn't match value bit width ("
381 << value.getBitWidth() << ")";
382 return success();
383 }
384 if (type.isa<IndexType>())
385 return success();
386 return emitError() << "expected integer or index type";
387}
388
389BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
390 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
391 return attr.cast<BoolAttr>();
392}
393
394//===----------------------------------------------------------------------===//
395// BoolAttr
396//===----------------------------------------------------------------------===//
397
398bool BoolAttr::getValue() const {
399 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
400 return storage->value.getBoolValue();
401}
402
403bool BoolAttr::classof(Attribute attr) {
404 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
405 return intAttr && intAttr.getType().isSignlessInteger(1);
406}
407
408//===----------------------------------------------------------------------===//
409// OpaqueAttr
410//===----------------------------------------------------------------------===//
411
412LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
413 StringAttr dialect, StringRef attrData,
414 Type type) {
415 if (!Dialect::isValidNamespace(dialect.strref()))
416 return emitError() << "invalid dialect namespace '" << dialect << "'";
417
418 // Check that the dialect is actually registered.
419 MLIRContext *context = dialect.getContext();
420 if (!context->allowsUnregisteredDialects() &&
421 !context->getLoadedDialect(dialect.strref())) {
422 return emitError()
423 << "#" << dialect << "<\"" << attrData << "\"> : " << type
424 << " attribute created with unregistered dialect. If this is "
425 "intended, please call allowUnregisteredDialects() on the "
426 "MLIRContext, or use -allow-unregistered-dialect with "
427 "the MLIR opt tool used";
428 }
429
430 return success();
431}
432
433//===----------------------------------------------------------------------===//
434// DenseElementsAttr Utilities
435//===----------------------------------------------------------------------===//
436
437/// Get the bitwidth of a dense element type within the buffer.
438/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
439static size_t getDenseElementStorageWidth(size_t origWidth) {
440 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
441}
442static size_t getDenseElementStorageWidth(Type elementType) {
443 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
444}
445
/// Set (when `value` is true) or clear the bit at bit offset `bitPos` of
/// `rawData`. Bit 0 is the least significant bit of byte 0.
static void setBit(char *rawData, size_t bitPos, bool value) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  char &byte = rawData[bitPos / CHAR_BIT];
  byte = static_cast<char>(value ? (byte | mask) : (byte & ~mask));
}

/// Return whether the bit at bit offset `bitPos` of `rawData` is set.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  return (rawData[bitPos / CHAR_BIT] & mask) != 0;
}
458
/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
/// BE format.
///
/// Only valid on big-endian hosts (asserted). All words of `value` except the
/// last are copied verbatim. The last word is handled in two steps:
///   1. it is byte-reversed into a word-sized scratch buffer (`valueLE`);
///   2. the leading (numBytes - lastWordPos) bytes of that buffer are
///      byte-reversed back into `result`.
/// Together these copy only the meaningful bytes of the partially-filled last
/// word (see the |------ij| ==> |ij| examples below).
///
/// NOTE(review): this assumes numBytes > (getNumWords() - 1) * APINT_WORD_SIZE;
/// otherwise `numBytes - lastWordPos` below wraps around. That holds when
/// numBytes == divideCeil(bitWidth, CHAR_BIT), as in writeBits.
static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
                                         char *result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the words filled with data.
  // For example, when `value` has 2 words, the first word is filled with data.
  // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
  // Note: despite its name, `numFilledWords` is a byte count (words times
  // APINT_WORD_SIZE); it is used as the std::copy_n length.
  size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
              numFilledWords, result);
  // Convert last word of APInt to LE format and store it in char
  // array(`valueLE`).
  // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
  // `lastWordPos` is the byte offset of the last word in both `value`'s raw
  // data and `result`.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
      valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
  // Extract actual APInt data from `valueLE`, convert endianness to BE format,
  // and store it in `result`.
  // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      valueLE.begin(), result + lastWordPos,
      (numBytes - lastWordPos) * CHAR_BIT, 1);
}
488
/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
/// format.
///
/// This is the inverse of copyAPIntToArrayForBEmachine, and is only valid on
/// big-endian hosts (asserted). `result` must already have enough words to
/// hold `numBytes` (asserted). It writes into `result`'s word storage through
/// a const_cast of getRawData(). All words except the last are copied
/// verbatim. The trailing (numBytes - lastWordPos) input bytes are
/// byte-reversed into a word-sized scratch buffer (`inArrayLE`), and the whole
/// scratch word is then byte-reversed into the last word of `result`.
///
/// NOTE(review): this assumes numBytes > (getNumWords() - 1) * APINT_WORD_SIZE;
/// otherwise `numBytes - lastWordPos` below wraps around.
static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
                                         APInt &result) {
  assert(llvm::support::endian::system_endianness() == // NOLINT
         llvm::support::endianness::big);              // NOLINT
  assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes);

  // Copy the data that fills the word of `result` from `inArray`.
  // For example, when `result` has 2 words, the first word will be filled with
  // data. So, the first 8 bytes are copied from `inArray` here.
  // `inArray` (10 bytes, BE): |abcdefgh|ij|
  //   ==> `result` (2 words, BE): |abcdefgh|--------|
  // Note: despite its name, `numFilledWords` is a byte count.
  size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
  std::copy_n(
      inArray, numFilledWords,
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));

  // Convert array data which will be last word of `result` to LE format, and
  // store it in char array(`inArrayLE`).
  // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
  // `inArrayLE` is value-initialized (zeroed), so bytes beyond the copied
  // input stay zero.
  size_t lastWordPos = numFilledWords;
  SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArray + lastWordPos, inArrayLE.begin(),
      (numBytes - lastWordPos) * CHAR_BIT, 1);

  // Convert `inArrayLE` to BE format, and store it in last word of `result`.
  // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
  DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
      inArrayLE.begin(),
      const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
          lastWordPos,
      APInt::APINT_BITS_PER_WORD, 1);
}
524
525/// Writes value to the bit position `bitPos` in array `rawData`.
526static void writeBits(char *rawData, size_t bitPos, APInt value) {
527 size_t bitWidth = value.getBitWidth();
528
529 // If the bitwidth is 1 we just toggle the specific bit.
530 if (bitWidth == 1)
531 return setBit(rawData, bitPos, value.isOneValue());
532
533 // Otherwise, the bit position is guaranteed to be byte aligned.
534 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 534, __extension__ __PRETTY_FUNCTION__
))
;
535 if (llvm::support::endian::system_endianness() ==
536 llvm::support::endianness::big) {
537 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
538 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
539 // work correctly in BE format.
540 // ex. `value` (2 words including 10 bytes)
541 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
542 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT8),
543 rawData + (bitPos / CHAR_BIT8));
544 } else {
545 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
546 llvm::divideCeil(bitWidth, CHAR_BIT8),
547 rawData + (bitPos / CHAR_BIT8));
548 }
549}
550
551/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
552/// `rawData`.
553static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
554 // Handle a boolean bit position.
555 if (bitWidth == 1)
556 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
557
558 // Otherwise, the bit position must be 8-bit aligned.
559 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 559, __extension__ __PRETTY_FUNCTION__
))
;
560 APInt result(bitWidth, 0);
561 if (llvm::support::endian::system_endianness() ==
562 llvm::support::endianness::big) {
563 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
564 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
565 // work correctly in BE format.
566 // ex. `result` (2 words including 10 bytes)
567 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
568 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT8),
569 llvm::divideCeil(bitWidth, CHAR_BIT8), result);
570 } else {
571 std::copy_n(rawData + (bitPos / CHAR_BIT8),
572 llvm::divideCeil(bitWidth, CHAR_BIT8),
573 const_cast<char *>(
574 reinterpret_cast<const char *>(result.getRawData())));
575 }
576 return result;
577}
578
579/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
580/// the same element count as 'type'.
581template <typename Values>
582static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
583 return (values.size() == 1) ||
584 (type.getNumElements() == static_cast<int64_t>(values.size()));
585}
586
587//===----------------------------------------------------------------------===//
588// DenseElementsAttr Iterators
589//===----------------------------------------------------------------------===//
590
591//===----------------------------------------------------------------------===//
592// AttributeElementIterator
593
594DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
595 DenseElementsAttr attr, size_t index)
596 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
597 Attribute, Attribute, Attribute>(
598 attr.getAsOpaquePointer(), index) {}
599
600Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
601 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
602 Type eltTy = owner.getElementType();
603 if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
604 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
605 if (eltTy.isa<IndexType>())
606 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
607 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
608 IntElementIterator intIt(owner, index);
609 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
610 return FloatAttr::get(eltTy, *floatIt);
611 }
612 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
613 auto complexEltTy = complexTy.getElementType();
614 ComplexIntElementIterator complexIntIt(owner, index);
615 if (complexEltTy.isa<IntegerType>()) {
616 auto value = *complexIntIt;
617 auto real = IntegerAttr::get(complexEltTy, value.real());
618 auto imag = IntegerAttr::get(complexEltTy, value.imag());
619 return ArrayAttr::get(complexTy.getContext(),
620 ArrayRef<Attribute>{real, imag});
621 }
622
623 ComplexFloatElementIterator complexFloatIt(
624 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
625 auto value = *complexFloatIt;
626 auto real = FloatAttr::get(complexEltTy, value.real());
627 auto imag = FloatAttr::get(complexEltTy, value.imag());
628 return ArrayAttr::get(complexTy.getContext(),
629 ArrayRef<Attribute>{real, imag});
630 }
631 if (owner.isa<DenseStringElementsAttr>()) {
632 ArrayRef<StringRef> vals = owner.getRawStringData();
633 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
634 }
635 llvm_unreachable("unexpected element type")::llvm::llvm_unreachable_internal("unexpected element type", "mlir/lib/IR/BuiltinAttributes.cpp"
, 635)
;
636}
637
638//===----------------------------------------------------------------------===//
639// BoolElementIterator
640
641DenseElementsAttr::BoolElementIterator::BoolElementIterator(
642 DenseElementsAttr attr, size_t dataIndex)
643 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
644 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
645
646bool DenseElementsAttr::BoolElementIterator::operator*() const {
647 return getBit(getData(), getDataIndex());
648}
649
650//===----------------------------------------------------------------------===//
651// IntElementIterator
652
653DenseElementsAttr::IntElementIterator::IntElementIterator(
654 DenseElementsAttr attr, size_t dataIndex)
655 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
656 attr.getRawData().data(), attr.isSplat(), dataIndex),
657 bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
658
659APInt DenseElementsAttr::IntElementIterator::operator*() const {
660 return readBits(getData(),
661 getDataIndex() * getDenseElementStorageWidth(bitWidth),
662 bitWidth);
663}
664
665//===----------------------------------------------------------------------===//
666// ComplexIntElementIterator
667
668DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
669 DenseElementsAttr attr, size_t dataIndex)
670 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
671 std::complex<APInt>, std::complex<APInt>,
672 std::complex<APInt>>(
673 attr.getRawData().data(), attr.isSplat(), dataIndex) {
674 auto complexType = attr.getElementType().cast<ComplexType>();
675 bitWidth = getDenseElementBitWidth(complexType.getElementType());
676}
677
678std::complex<APInt>
679DenseElementsAttr::ComplexIntElementIterator::operator*() const {
680 size_t storageWidth = getDenseElementStorageWidth(bitWidth);
681 size_t offset = getDataIndex() * storageWidth * 2;
682 return {readBits(getData(), offset, bitWidth),
683 readBits(getData(), offset + storageWidth, bitWidth)};
684}
685
686//===----------------------------------------------------------------------===//
687// DenseArrayAttr
688//===----------------------------------------------------------------------===//
689
690const bool *DenseArrayBaseAttr::value_begin_impl(OverloadToken<bool>) const {
691 return cast<DenseBoolArrayAttr>().asArrayRef().begin();
692}
693const int8_t *
694DenseArrayBaseAttr::value_begin_impl(OverloadToken<int8_t>) const {
695 return cast<DenseI8ArrayAttr>().asArrayRef().begin();
696}
697const int16_t *
698DenseArrayBaseAttr::value_begin_impl(OverloadToken<int16_t>) const {
699 return cast<DenseI16ArrayAttr>().asArrayRef().begin();
700}
701const int32_t *
702DenseArrayBaseAttr::value_begin_impl(OverloadToken<int32_t>) const {
703 return cast<DenseI32ArrayAttr>().asArrayRef().begin();
704}
705const int64_t *
706DenseArrayBaseAttr::value_begin_impl(OverloadToken<int64_t>) const {
707 return cast<DenseI64ArrayAttr>().asArrayRef().begin();
708}
709const float *DenseArrayBaseAttr::value_begin_impl(OverloadToken<float>) const {
710 return cast<DenseF32ArrayAttr>().asArrayRef().begin();
711}
712const double *
713DenseArrayBaseAttr::value_begin_impl(OverloadToken<double>) const {
714 return cast<DenseF64ArrayAttr>().asArrayRef().begin();
715}
716
717void DenseArrayBaseAttr::print(AsmPrinter &printer) const {
718 print(printer.getStream());
719}
720
721void DenseArrayBaseAttr::printWithoutBraces(raw_ostream &os) const {
722 llvm::TypeSwitch<DenseArrayBaseAttr>(*this)
723 .Case<DenseBoolArrayAttr, DenseI8ArrayAttr, DenseI16ArrayAttr,
724 DenseI32ArrayAttr, DenseI64ArrayAttr, DenseF32ArrayAttr,
725 DenseF64ArrayAttr>([&](auto attr) { attr.printWithoutBraces(os); });
726}
727
728void DenseArrayBaseAttr::print(raw_ostream &os) const {
729 os << "[";
730 printWithoutBraces(os);
731 os << "]";
732}
733
734namespace {
735/// Instantiations of this class provide utilities for interacting with native
736/// data types in the context of DenseArrayAttr.
737template <size_t width,
738 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
739struct DenseArrayAttrIntUtil {
740 static bool checkElementType(Type eltType) {
741 auto type = eltType.dyn_cast<IntegerType>();
742 if (!type || type.getWidth() != width)
743 return false;
744 return type.getSignedness() == signedness;
745 }
746
747 static Type getElementType(MLIRContext *ctx) {
748 return IntegerType::get(ctx, width, signedness);
749 }
750
751 template <typename T>
752 static void printElement(raw_ostream &os, T value) {
753 os << value;
754 }
755
756 template <typename T>
757 static ParseResult parseElement(AsmParser &parser, T &value) {
758 return parser.parseInteger(value);
3
Calling 'AsmParser::parseInteger'
11
Returning from 'AsmParser::parseInteger'
12
Returning without writing to 'value'
759 }
760};
761template <typename T>
762struct DenseArrayAttrUtil;
763
764/// Specialization for boolean elements to print 'true' and 'false' literals for
765/// elements.
766template <>
767struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
768 static void printElement(raw_ostream &os, bool value) {
769 os << (value ? "true" : "false");
770 }
771};
772
773/// Specialization for 8-bit integers to ensure values are printed as integers
774/// and not characters.
775template <>
776struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
777 static void printElement(raw_ostream &os, int8_t value) {
778 os << static_cast<int>(value);
779 }
780};
781template <>
782struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
783template <>
784struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
785template <>
786struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
787
788/// Specialization for 32-bit floats.
789template <>
790struct DenseArrayAttrUtil<float> {
791 static bool checkElementType(Type eltType) { return eltType.isF32(); }
792 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
793 static void printElement(raw_ostream &os, float value) { os << value; }
794
795 /// Parse a double and cast it to a float.
796 static ParseResult parseElement(AsmParser &parser, float &value) {
797 double doubleVal;
798 if (parser.parseFloat(doubleVal))
799 return failure();
800 value = doubleVal;
801 return success();
802 }
803};
804
805/// Specialization for 64-bit floats.
806template <>
807struct DenseArrayAttrUtil<double> {
808 static bool checkElementType(Type eltType) { return eltType.isF64(); }
809 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
810 static void printElement(raw_ostream &os, float value) { os << value; }
811 static ParseResult parseElement(AsmParser &parser, double &value) {
812 return parser.parseFloat(value);
813 }
814};
815} // namespace
816
817template <typename T>
818void DenseArrayAttr<T>::print(AsmPrinter &printer) const {
819 print(printer.getStream());
820}
821
822template <typename T>
823void DenseArrayAttr<T>::printWithoutBraces(raw_ostream &os) const {
824 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
825 DenseArrayAttrUtil<T>::printElement(os, value);
826 });
827}
828
829template <typename T>
830void DenseArrayAttr<T>::print(raw_ostream &os) const {
831 os << "[";
832 printWithoutBraces(os);
833 os << "]";
834}
835
836/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
837template <typename T>
838Attribute DenseArrayAttr<T>::parseWithoutBraces(AsmParser &parser,
839 Type odsType) {
840 SmallVector<T> data;
841 if (failed(parser.parseCommaSeparatedList([&]() {
842 T value;
1
'value' declared without an initial value
843 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
2
Calling 'DenseArrayAttrIntUtil::parseElement'
13
Returning from 'DenseArrayAttrIntUtil::parseElement'
14
Taking false branch
844 return failure();
845 data.push_back(value);
15
1st function call argument is an uninitialized value
846 return success();
847 })))
848 return {};
849 return get(parser.getContext(), data);
850}
851
852/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
853template <typename T>
854Attribute DenseArrayAttr<T>::parse(AsmParser &parser, Type odsType) {
855 if (parser.parseLSquare())
856 return {};
857 // Handle empty list case.
858 if (succeeded(parser.parseOptionalRSquare()))
859 return get(parser.getContext(), {});
860 Attribute result = parseWithoutBraces(parser, odsType);
861 if (parser.parseRSquare())
862 return {};
863 return result;
864}
865
866/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
867template <typename T>
868DenseArrayAttr<T>::operator ArrayRef<T>() const {
869 ArrayRef<char> raw = getRawData();
870 assert((raw.size() % sizeof(T)) == 0)(static_cast <bool> ((raw.size() % sizeof(T)) == 0) ? void
(0) : __assert_fail ("(raw.size() % sizeof(T)) == 0", "mlir/lib/IR/BuiltinAttributes.cpp"
, 870, __extension__ __PRETTY_FUNCTION__))
;
871 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
872 raw.size() / sizeof(T));
873}
874
875/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
876template <typename T>
877DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
878 ArrayRef<T> content) {
879 auto shapedType = RankedTensorType::get(
880 content.size(), DenseArrayAttrUtil<T>::getElementType(context));
881 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
882 content.size() * sizeof(T));
883 return Base::get(context, shapedType, rawArray)
884 .template cast<DenseArrayAttr<T>>();
885}
886
887template <typename T>
888bool DenseArrayAttr<T>::classof(Attribute attr) {
889 if (auto denseArray = attr.dyn_cast<DenseArrayBaseAttr>())
890 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
891 return false;
892}
893
894namespace mlir {
895namespace detail {
896// Explicit instantiation for all the supported DenseArrayAttr.
897template class DenseArrayAttr<bool>;
898template class DenseArrayAttr<int8_t>;
899template class DenseArrayAttr<int16_t>;
900template class DenseArrayAttr<int32_t>;
901template class DenseArrayAttr<int64_t>;
902template class DenseArrayAttr<float>;
903template class DenseArrayAttr<double>;
904} // namespace detail
905} // namespace mlir
906
907//===----------------------------------------------------------------------===//
908// DenseElementsAttr
909//===----------------------------------------------------------------------===//
910
911/// Method for support type inquiry through isa, cast and dyn_cast.
912bool DenseElementsAttr::classof(Attribute attr) {
913 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
914}
915
916DenseElementsAttr DenseElementsAttr::get(ShapedType type,
917 ArrayRef<Attribute> values) {
918 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__
))
;
919
920 // If the element type is not based on int/float/index, assume it is a string
921 // type.
922 Type eltType = type.getElementType();
923 if (!eltType.isIntOrIndexOrFloat()) {
924 SmallVector<StringRef, 8> stringValues;
925 stringValues.reserve(values.size());
926 for (Attribute attr : values) {
927 assert(attr.isa<StringAttr>() &&(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__
))
928 "expected string value for non integer/index/float element")(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__
))
;
929 stringValues.push_back(attr.cast<StringAttr>().getValue());
930 }
931 return get(type, stringValues);
932 }
933
934 // Otherwise, get the raw storage width to use for the allocation.
935 size_t bitWidth = getDenseElementBitWidth(eltType);
936 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
937
938 // Compress the attribute values into a character buffer.
939 SmallVector<char, 8> data(
940 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT8));
941 APInt intVal;
942 for (unsigned i = 0, e = values.size(); i < e; ++i) {
943 if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
944 assert(floatAttr.getType() == eltType &&(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 945, __extension__ __PRETTY_FUNCTION__
))
945 "expected float attribute type to equal element type")(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 945, __extension__ __PRETTY_FUNCTION__
))
;
946 intVal = floatAttr.getValue().bitcastToAPInt();
947 } else {
948 auto intAttr = values[i].cast<IntegerAttr>();
949 assert(intAttr.getType() == eltType &&(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 950, __extension__ __PRETTY_FUNCTION__
))
950 "expected integer attribute type to equal element type")(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 950, __extension__ __PRETTY_FUNCTION__
))
;
951 intVal = intAttr.getValue();
952 }
953
954 assert(intVal.getBitWidth() == bitWidth &&(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 955, __extension__ __PRETTY_FUNCTION__
))
955 "expected value to have same bitwidth as element type")(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 955, __extension__ __PRETTY_FUNCTION__
))
;
956 writeBits(data.data(), i * storageBitWidth, intVal);
957 }
958
959 // Handle the special encoding of splat of bool.
960 if (values.size() == 1 && eltType.isInteger(1))
961 data[0] = data[0] ? -1 : 0;
962
963 return DenseIntOrFPElementsAttr::getRaw(type, data);
964}
965
966DenseElementsAttr DenseElementsAttr::get(ShapedType type,
967 ArrayRef<bool> values) {
968 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 968, __extension__ __PRETTY_FUNCTION__
))
;
969 assert(type.getElementType().isInteger(1))(static_cast <bool> (type.getElementType().isInteger(1)
) ? void (0) : __assert_fail ("type.getElementType().isInteger(1)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 969, __extension__ __PRETTY_FUNCTION__
))
;
970
971 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT8));
972
973 if (!values.empty()) {
974 bool isSplat = true;
975 bool firstValue = values[0];
976 for (int i = 0, e = values.size(); i != e; ++i) {
977 isSplat &= values[i] == firstValue;
978 setBit(buff.data(), i, values[i]);
979 }
980
981 // Splat of bool is encoded as a byte with all-ones in it.
982 if (isSplat) {
983 buff.resize(1);
984 buff[0] = values[0] ? -1 : 0;
985 }
986 }
987
988 return DenseIntOrFPElementsAttr::getRaw(type, buff);
989}
990
991DenseElementsAttr DenseElementsAttr::get(ShapedType type,
992 ArrayRef<StringRef> values) {
993 assert(!type.getElementType().isIntOrFloat())(static_cast <bool> (!type.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("!type.getElementType().isIntOrFloat()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 993, __extension__ __PRETTY_FUNCTION__
))
;
994 return DenseStringElementsAttr::get(type, values);
995}
996
997/// Constructs a dense integer elements attribute from an array of APInt
998/// values. Each APInt value is expected to have the same bitwidth as the
999/// element type of 'type'.
1000DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1001 ArrayRef<APInt> values) {
1002 assert(type.getElementType().isIntOrIndex())(static_cast <bool> (type.getElementType().isIntOrIndex
()) ? void (0) : __assert_fail ("type.getElementType().isIntOrIndex()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1002, __extension__ __PRETTY_FUNCTION__
))
;
1003 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1003, __extension__ __PRETTY_FUNCTION__
))
;
1004 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1005 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1006}
1007DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1008 ArrayRef<std::complex<APInt>> values) {
1009 ComplexType complex = type.getElementType().cast<ComplexType>();
1010 assert(complex.getElementType().isa<IntegerType>())(static_cast <bool> (complex.getElementType().isa<IntegerType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<IntegerType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1010, __extension__ __PRETTY_FUNCTION__
))
;
1011 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1011, __extension__ __PRETTY_FUNCTION__
))
;
1012 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1013 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1014 values.size() * 2);
1015 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1016}
1017
1018// Constructs a dense float elements attribute from an array of APFloat
1019// values. Each APFloat value is expected to have the same bitwidth as the
1020// element type of 'type'.
1021DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1022 ArrayRef<APFloat> values) {
1023 assert(type.getElementType().isa<FloatType>())(static_cast <bool> (type.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("type.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1023, __extension__ __PRETTY_FUNCTION__
))
;
1024 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1024, __extension__ __PRETTY_FUNCTION__
))
;
1025 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1026 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1027}
1028DenseElementsAttr
1029DenseElementsAttr::get(ShapedType type,
1030 ArrayRef<std::complex<APFloat>> values) {
1031 ComplexType complex = type.getElementType().cast<ComplexType>();
1032 assert(complex.getElementType().isa<FloatType>())(static_cast <bool> (complex.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1032, __extension__ __PRETTY_FUNCTION__
))
;
1033 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1033, __extension__ __PRETTY_FUNCTION__
))
;
1034 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1035 values.size() * 2);
1036 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1037 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1038}
1039
1040/// Construct a dense elements attribute from a raw buffer representing the
1041/// data for this attribute. Users should generally not use this methods as
1042/// the expected buffer format may not be a form the user expects.
1043DenseElementsAttr
1044DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1045 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1046}
1047
1048/// Returns true if the given buffer is a valid raw buffer for the given type.
1049bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1050 ArrayRef<char> rawBuffer,
1051 bool &detectedSplat) {
1052 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1053 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT8;
1054 int64_t numElements = type.getNumElements();
1055
1056 // The initializer is always a splat if the result type has a single element.
1057 detectedSplat = numElements == 1;
1058
1059 // Storage width of 1 is special as it is packed by the bit.
1060 if (storageWidth == 1) {
1061 // Check for a splat, or a buffer equal to the number of elements which
1062 // consists of either all 0's or all 1's.
1063 if (rawBuffer.size() == 1) {
1064 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1065 if (rawByte == 0 || rawByte == 0xff) {
1066 detectedSplat = true;
1067 return true;
1068 }
1069 }
1070
1071 // This is a valid non-splat buffer if it has the right size.
1072 return rawBufferWidth == llvm::alignTo<8>(numElements);
1073 }
1074
1075 // All other types are 8-bit aligned, so we can just check the buffer width
1076 // to know if only a single initializer element was passed in.
1077 if (rawBufferWidth == storageWidth) {
1078 detectedSplat = true;
1079 return true;
1080 }
1081
1082 // The raw buffer is valid if it has the right size.
1083 return rawBufferWidth == storageWidth * numElements;
1084}
1085
1086/// Check the information for a C++ data type, check if this type is valid for
1087/// the current attribute. This method is used to verify specific type
1088/// invariants that the templatized 'getValues' method cannot.
1089static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1090 bool isSigned) {
1091 // Make sure that the data element size is the same as the type element width.
1092 if (getDenseElementBitWidth(type) !=
1093 static_cast<size_t>(dataEltSize * CHAR_BIT8))
1094 return false;
1095
1096 // Check that the element type is either float or integer or index.
1097 if (!isInt)
1098 return type.isa<FloatType>();
1099 if (type.isIndex())
1100 return true;
1101
1102 auto intType = type.dyn_cast<IntegerType>();
1103 if (!intType)
1104 return false;
1105
1106 // Make sure signedness semantics is consistent.
1107 if (intType.isSignless())
1108 return true;
1109 return intType.isSigned() ? isSigned : !isSigned;
1110}
1111
1112/// Defaults down the subclass implementation.
1113DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1114 ArrayRef<char> data,
1115 int64_t dataEltSize,
1116 bool isInt, bool isSigned) {
1117 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1118 isSigned);
1119}
1120DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1121 ArrayRef<char> data,
1122 int64_t dataEltSize,
1123 bool isInt,
1124 bool isSigned) {
1125 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1126 isInt, isSigned);
1127}
1128
1129bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1130 bool isSigned) const {
1131 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1132}
1133bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1134 bool isSigned) const {
1135 return ::isValidIntOrFloat(
1136 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1137 isInt, isSigned);
1138}
1139
1140/// Returns true if this attribute corresponds to a splat, i.e. if all element
1141/// values are the same.
1142bool DenseElementsAttr::isSplat() const {
1143 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1144}
1145
1146/// Return if the given complex type has an integer element type.
1147LLVM_ATTRIBUTE_UNUSED__attribute__((__unused__)) static bool isComplexOfIntType(Type type) {
1148 return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1149}
1150
1151auto DenseElementsAttr::getComplexIntValues() const
1152 -> iterator_range_impl<ComplexIntElementIterator> {
1153 assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1154, __extension__ __PRETTY_FUNCTION__
))
1154 "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1154, __extension__ __PRETTY_FUNCTION__
))
;
1155 return {getType(), ComplexIntElementIterator(*this, 0),
1156 ComplexIntElementIterator(*this, getNumElements())};
1157}
1158auto DenseElementsAttr::complex_value_begin() const
1159 -> ComplexIntElementIterator {
1160 assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1161, __extension__ __PRETTY_FUNCTION__
))
1161 "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1161, __extension__ __PRETTY_FUNCTION__
))
;
1162 return ComplexIntElementIterator(*this, 0);
1163}
1164auto DenseElementsAttr::complex_value_end() const -> ComplexIntElementIterator {
1165 assert(isComplexOfIntType(getElementType()) &&(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1166, __extension__ __PRETTY_FUNCTION__
))
1166 "expected complex integral type")(static_cast <bool> (isComplexOfIntType(getElementType(
)) && "expected complex integral type") ? void (0) : __assert_fail
("isComplexOfIntType(getElementType()) && \"expected complex integral type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1166, __extension__ __PRETTY_FUNCTION__
))
;
1167 return ComplexIntElementIterator(*this, getNumElements());
1168}
1169
1170/// Return the held element values as a range of APFloat. The element type of
1171/// this attribute must be of float type.
1172auto DenseElementsAttr::getFloatValues() const
1173 -> iterator_range_impl<FloatElementIterator> {
1174 auto elementType = getElementType().cast<FloatType>();
1175 const auto &elementSemantics = elementType.getFloatSemantics();
1176 return {getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1177 FloatElementIterator(elementSemantics, raw_int_end())};
1178}
1179auto DenseElementsAttr::float_value_begin() const -> FloatElementIterator {
1180 auto elementType = getElementType().cast<FloatType>();
1181 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_begin());
1182}
1183auto DenseElementsAttr::float_value_end() const -> FloatElementIterator {
1184 auto elementType = getElementType().cast<FloatType>();
1185 return FloatElementIterator(elementType.getFloatSemantics(), raw_int_end());
1186}
1187
1188auto DenseElementsAttr::getComplexFloatValues() const
1189 -> iterator_range_impl<ComplexFloatElementIterator> {
1190 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1191 assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() &&
"expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1191, __extension__ __PRETTY_FUNCTION__
))
;
1192 const auto &semantics = eltTy.cast<FloatType>().getFloatSemantics();
1193 return {getType(),
1194 {semantics, {*this, 0}},
1195 {semantics, {*this, static_cast<size_t>(getNumElements())}}};
1196}
1197auto DenseElementsAttr::complex_float_value_begin() const
1198 -> ComplexFloatElementIterator {
1199 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1200 assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() &&
"expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1200, __extension__ __PRETTY_FUNCTION__
))
;
1201 return {eltTy.cast<FloatType>().getFloatSemantics(), {*this, 0}};
1202}
1203auto DenseElementsAttr::complex_float_value_end() const
1204 -> ComplexFloatElementIterator {
1205 Type eltTy = getElementType().cast<ComplexType>().getElementType();
1206 assert(eltTy.isa<FloatType>() && "expected complex float type")(static_cast <bool> (eltTy.isa<FloatType>() &&
"expected complex float type") ? void (0) : __assert_fail ("eltTy.isa<FloatType>() && \"expected complex float type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__
))
;
1207 return {eltTy.cast<FloatType>().getFloatSemantics(),
1208 {*this, static_cast<size_t>(getNumElements())}};
1209}
1210
1211/// Return the raw storage data held by this attribute.
1212ArrayRef<char> DenseElementsAttr::getRawData() const {
1213 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1214}
1215
1216ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1217 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1218}
1219
1220/// Return a new DenseElementsAttr that has the same data as the current
1221/// attribute, but has been reshaped to 'newType'. The new type must have the
1222/// same total number of elements as well as element type.
1223DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1224 ShapedType curType = getType();
1225 if (curType == newType)
1226 return *this;
1227
1228 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1229, __extension__ __PRETTY_FUNCTION__
))
1229 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1229, __extension__ __PRETTY_FUNCTION__
))
;
1230 assert(newType.getNumElements() == curType.getNumElements() &&(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1231, __extension__ __PRETTY_FUNCTION__
))
1231 "expected the same number of elements")(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1231, __extension__ __PRETTY_FUNCTION__
))
;
1232 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1233}
1234
1235DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1236 assert(isSplat() && "expected a splat type")(static_cast <bool> (isSplat() && "expected a splat type"
) ? void (0) : __assert_fail ("isSplat() && \"expected a splat type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1236, __extension__ __PRETTY_FUNCTION__
))
;
1237
1238 ShapedType curType = getType();
1239 if (curType == newType)
1240 return *this;
1241
1242 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1243, __extension__ __PRETTY_FUNCTION__
))
1243 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1243, __extension__ __PRETTY_FUNCTION__
))
;
1244 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1245}
1246
1247/// Return a new DenseElementsAttr that has the same data as the current
1248/// attribute, but has bitcast elements such that it is now 'newType'. The new
1249/// type must have the same shape and element types of the same bitwidth as the
1250/// current type.
1251DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1252 ShapedType curType = getType();
1253 Type curElType = curType.getElementType();
1254 if (curElType == newElType)
1255 return *this;
1256
1257 assert(getDenseElementBitWidth(newElType) ==(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__
))
1258 getDenseElementBitWidth(curElType) &&(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__
))
1259 "expected element types with the same bitwidth")(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1259, __extension__ __PRETTY_FUNCTION__
))
;
1260 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1261 getRawData());
1262}
1263
1264DenseElementsAttr
1265DenseElementsAttr::mapValues(Type newElementType,
1266 function_ref<APInt(const APInt &)> mapping) const {
1267 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1268}
1269
1270DenseElementsAttr DenseElementsAttr::mapValues(
1271 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1272 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1273}
1274
1275ShapedType DenseElementsAttr::getType() const {
1276 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1277}
1278
1279Type DenseElementsAttr::getElementType() const {
1280 return getType().getElementType();
1281}
1282
1283int64_t DenseElementsAttr::getNumElements() const {
1284 return getType().getNumElements();
1285}
1286
1287//===----------------------------------------------------------------------===//
1288// DenseIntOrFPElementsAttr
1289//===----------------------------------------------------------------------===//
1290
1291/// Utility method to write a range of APInt values to a buffer.
1292template <typename APRangeT>
1293static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1294 APRangeT &&values) {
1295 size_t numValues = llvm::size(values);
1296 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT8));
1297 size_t offset = 0;
1298 for (auto it = values.begin(), e = values.end(); it != e;
1299 ++it, offset += storageWidth) {
1300 assert((*it).getBitWidth() <= storageWidth)(static_cast <bool> ((*it).getBitWidth() <= storageWidth
) ? void (0) : __assert_fail ("(*it).getBitWidth() <= storageWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1300, __extension__ __PRETTY_FUNCTION__
))
;
1301 writeBits(data.data(), offset, *it);
1302 }
1303
1304 // Handle the special encoding of splat of a boolean.
1305 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1306 data[0] = data[0] ? -1 : 0;
1307}
1308
1309/// Constructs a dense elements attribute from an array of raw APFloat values.
1310/// Each APFloat value is expected to have the same bitwidth as the element
1311/// type of 'type'. 'type' must be a vector or tensor with static shape.
1312DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1313 size_t storageWidth,
1314 ArrayRef<APFloat> values) {
1315 std::vector<char> data;
1316 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1317 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1318 return DenseIntOrFPElementsAttr::getRaw(type, data);
1319}
1320
1321/// Constructs a dense elements attribute from an array of raw APInt values.
1322/// Each APInt value is expected to have the same bitwidth as the element type
1323/// of 'type'.
1324DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1325 size_t storageWidth,
1326 ArrayRef<APInt> values) {
1327 std::vector<char> data;
1328 writeAPIntsToBuffer(storageWidth, data, values);
1329 return DenseIntOrFPElementsAttr::getRaw(type, data);
1330}
1331
1332DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1333 ArrayRef<char> data) {
1334 assert((type.isa<RankedTensorType, VectorType>()) &&(static_cast <bool> ((type.isa<RankedTensorType, VectorType
>()) && "type must be ranked tensor or vector") ? void
(0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1335, __extension__ __PRETTY_FUNCTION__
))
1335 "type must be ranked tensor or vector")(static_cast <bool> ((type.isa<RankedTensorType, VectorType
>()) && "type must be ranked tensor or vector") ? void
(0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1335, __extension__ __PRETTY_FUNCTION__
))
;
1336 assert(type.hasStaticShape() && "type must have static shape")(static_cast <bool> (type.hasStaticShape() && "type must have static shape"
) ? void (0) : __assert_fail ("type.hasStaticShape() && \"type must have static shape\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1336, __extension__ __PRETTY_FUNCTION__
))
;
1337 bool isSplat = false;
1338 bool isValid = isValidRawBuffer(type, data, isSplat);
1339 assert(isValid)(static_cast <bool> (isValid) ? void (0) : __assert_fail
("isValid", "mlir/lib/IR/BuiltinAttributes.cpp", 1339, __extension__
__PRETTY_FUNCTION__))
;
1340 (void)isValid;
1341 return Base::get(type.getContext(), type, data, isSplat);
1342}
1343
1344/// Overload of the raw 'get' method that asserts that the given type is of
1345/// complex type. This method is used to verify type invariants that the
1346/// templatized 'get' method cannot.
1347DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1348 ArrayRef<char> data,
1349 int64_t dataEltSize,
1350 bool isInt,
1351 bool isSigned) {
1352 assert(::isValidIntOrFloat((static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__
))
1353 type.getElementType().cast<ComplexType>().getElementType(),(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__
))
1354 dataEltSize / 2, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1354, __extension__ __PRETTY_FUNCTION__
))
;
1355
1356 int64_t numElements = data.size() / dataEltSize;
1357 (void)numElements;
1358 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1358, __extension__ __PRETTY_FUNCTION__
))
;
1359 return getRaw(type, data);
1360}
1361
1362/// Overload of the 'getRaw' method that asserts that the given type is of
1363/// integer type. This method is used to verify type invariants that the
1364/// templatized 'get' method cannot.
1365DenseElementsAttr
1366DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1367 int64_t dataEltSize, bool isInt,
1368 bool isSigned) {
1369 assert((static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1370, __extension__ __PRETTY_FUNCTION__
))
1370 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1370, __extension__ __PRETTY_FUNCTION__
))
;
1371
1372 int64_t numElements = data.size() / dataEltSize;
1373 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1373, __extension__ __PRETTY_FUNCTION__
))
;
1374 (void)numElements;
1375 return getRaw(type, data);
1376}
1377
1378void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1379 const char *inRawData, char *outRawData, size_t elementBitWidth,
1380 size_t numElements) {
1381 using llvm::support::ulittle16_t;
1382 using llvm::support::ulittle32_t;
1383 using llvm::support::ulittle64_t;
1384
1385 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1386, __extension__ __PRETTY_FUNCTION__
))
1386 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1386, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
1387 // NOLINT to avoid warning message about replacing by static_assert()
1388
1389 // Following std::copy_n always converts endianness on BE machine.
1390 switch (elementBitWidth) {
1391 case 16: {
1392 const ulittle16_t *inRawDataPos =
1393 reinterpret_cast<const ulittle16_t *>(inRawData);
1394 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1395 std::copy_n(inRawDataPos, numElements, outDataPos);
1396 break;
1397 }
1398 case 32: {
1399 const ulittle32_t *inRawDataPos =
1400 reinterpret_cast<const ulittle32_t *>(inRawData);
1401 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1402 std::copy_n(inRawDataPos, numElements, outDataPos);
1403 break;
1404 }
1405 case 64: {
1406 const ulittle64_t *inRawDataPos =
1407 reinterpret_cast<const ulittle64_t *>(inRawData);
1408 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1409 std::copy_n(inRawDataPos, numElements, outDataPos);
1410 break;
1411 }
1412 default: {
1413 size_t nBytes = elementBitWidth / CHAR_BIT8;
1414 for (size_t i = 0; i < nBytes; i++)
1415 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1416 break;
1417 }
1418 }
1419}
1420
1421void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1422 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1423 ShapedType type) {
1424 size_t numElements = type.getNumElements();
1425 Type elementType = type.getElementType();
1426 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1427 elementType = complexTy.getElementType();
1428 numElements = numElements * 2;
1429 }
1430 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1431 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1432, __extension__ __PRETTY_FUNCTION__
))
1432 inRawData.size() <= outRawData.size())(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1432, __extension__ __PRETTY_FUNCTION__
))
;
1433 if (elementBitWidth <= CHAR_BIT8)
1434 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1435 else
1436 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1437 elementBitWidth, numElements);
1438}
1439
1440//===----------------------------------------------------------------------===//
1441// DenseFPElementsAttr
1442//===----------------------------------------------------------------------===//
1443
1444template <typename Fn, typename Attr>
1445static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1446 Type newElementType,
1447 llvm::SmallVectorImpl<char> &data) {
1448 size_t bitWidth = getDenseElementBitWidth(newElementType);
1449 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1450
1451 ShapedType newArrayType;
1452 if (inType.isa<RankedTensorType>())
1453 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1454 else if (inType.isa<UnrankedTensorType>())
1455 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1456 else if (auto vType = inType.dyn_cast<VectorType>())
1457 newArrayType = VectorType::get(vType.getShape(), newElementType,
1458 vType.getNumScalableDims());
1459 else
1460 assert(newArrayType && "Unhandled tensor type")(static_cast <bool> (newArrayType && "Unhandled tensor type"
) ? void (0) : __assert_fail ("newArrayType && \"Unhandled tensor type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1460, __extension__ __PRETTY_FUNCTION__
))
;
1461
1462 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1463 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT8));
1464
1465 // Functor used to process a single element value of the attribute.
1466 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1467 auto newInt = mapping(value);
1468 assert(newInt.getBitWidth() == bitWidth)(static_cast <bool> (newInt.getBitWidth() == bitWidth) ?
void (0) : __assert_fail ("newInt.getBitWidth() == bitWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1468, __extension__ __PRETTY_FUNCTION__
))
;
1469 writeBits(data.data(), index * storageBitWidth, newInt);
1470 };
1471
1472 // Check for the splat case.
1473 if (attr.isSplat()) {
1474 processElt(*attr.begin(), /*index=*/0);
1475 return newArrayType;
1476 }
1477
1478 // Otherwise, process all of the element values.
1479 uint64_t elementIdx = 0;
1480 for (auto value : attr)
1481 processElt(value, elementIdx++);
1482 return newArrayType;
1483}
1484
1485DenseElementsAttr DenseFPElementsAttr::mapValues(
1486 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1487 llvm::SmallVector<char, 8> elementData;
1488 auto newArrayType =
1489 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1490
1491 return getRaw(newArrayType, elementData);
1492}
1493
1494/// Method for supporting type inquiry through isa, cast and dyn_cast.
1495bool DenseFPElementsAttr::classof(Attribute attr) {
1496 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1497 return denseAttr.getType().getElementType().isa<FloatType>();
1498 return false;
1499}
1500
1501//===----------------------------------------------------------------------===//
1502// DenseIntElementsAttr
1503//===----------------------------------------------------------------------===//
1504
1505DenseElementsAttr DenseIntElementsAttr::mapValues(
1506 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1507 llvm::SmallVector<char, 8> elementData;
1508 auto newArrayType =
1509 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1510 return getRaw(newArrayType, elementData);
1511}
1512
1513/// Method for supporting type inquiry through isa, cast and dyn_cast.
1514bool DenseIntElementsAttr::classof(Attribute attr) {
1515 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1516 return denseAttr.getType().getElementType().isIntOrIndex();
1517 return false;
1518}
1519
1520//===----------------------------------------------------------------------===//
1521// DenseResourceElementsAttr
1522//===----------------------------------------------------------------------===//
1523
1524DenseResourceElementsAttr
1525DenseResourceElementsAttr::get(ShapedType type,
1526 DenseResourceElementsHandle handle) {
1527 return Base::get(type.getContext(), type, handle);
1528}
1529
1530DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1531 StringRef blobName,
1532 AsmResourceBlob blob) {
1533 // Extract the builtin dialect resource manager from context and construct a
1534 // handle by inserting a new resource using the provided blob.
1535 auto &manager =
1536 DenseResourceElementsHandle::getManagerInterface(type.getContext());
1537 return get(type, manager.insert(blobName, std::move(blob)));
1538}
1539
1540//===----------------------------------------------------------------------===//
1541// DenseResourceElementsAttrBase
1542
1543namespace {
1544/// Instantiations of this class provide utilities for interacting with native
1545/// data types in the context of DenseResourceElementsAttr.
1546template <typename T>
1547struct DenseResourceAttrUtil;
1548template <size_t width, bool isSigned>
1549struct DenseResourceElementsAttrIntUtil {
1550 static bool checkElementType(Type eltType) {
1551 IntegerType type = eltType.dyn_cast<IntegerType>();
1552 if (!type || type.getWidth() != width)
1553 return false;
1554 return isSigned ? !type.isUnsigned() : !type.isSigned();
1555 }
1556};
1557template <>
1558struct DenseResourceAttrUtil<bool> {
1559 static bool checkElementType(Type eltType) {
1560 return eltType.isSignlessInteger(1);
1561 }
1562};
1563template <>
1564struct DenseResourceAttrUtil<int8_t>
1565 : public DenseResourceElementsAttrIntUtil<8, true> {};
1566template <>
1567struct DenseResourceAttrUtil<uint8_t>
1568 : public DenseResourceElementsAttrIntUtil<8, false> {};
1569template <>
1570struct DenseResourceAttrUtil<int16_t>
1571 : public DenseResourceElementsAttrIntUtil<16, true> {};
1572template <>
1573struct DenseResourceAttrUtil<uint16_t>
1574 : public DenseResourceElementsAttrIntUtil<16, false> {};
1575template <>
1576struct DenseResourceAttrUtil<int32_t>
1577 : public DenseResourceElementsAttrIntUtil<32, true> {};
1578template <>
1579struct DenseResourceAttrUtil<uint32_t>
1580 : public DenseResourceElementsAttrIntUtil<32, false> {};
1581template <>
1582struct DenseResourceAttrUtil<int64_t>
1583 : public DenseResourceElementsAttrIntUtil<64, true> {};
1584template <>
1585struct DenseResourceAttrUtil<uint64_t>
1586 : public DenseResourceElementsAttrIntUtil<64, false> {};
1587template <>
1588struct DenseResourceAttrUtil<float> {
1589 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1590};
1591template <>
1592struct DenseResourceAttrUtil<double> {
1593 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1594};
1595} // namespace
1596
1597template <typename T>
1598DenseResourceElementsAttrBase<T>
1599DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1600 AsmResourceBlob blob) {
1601 // Check that the blob is in the form we were expecting.
1602 assert(blob.getDataAlignment() == alignof(T) &&(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1603, __extension__ __PRETTY_FUNCTION__
))
1603 "alignment mismatch between expected alignment and blob alignment")(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1603, __extension__ __PRETTY_FUNCTION__
))
;
1604 assert(((blob.getData().size() % sizeof(T)) == 0) &&(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1605, __extension__ __PRETTY_FUNCTION__
))
1605 "size mismatch between expected element width and blob size")(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1605, __extension__ __PRETTY_FUNCTION__
))
;
1606 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1607, __extension__ __PRETTY_FUNCTION__
))
1607 "invalid shape element type for provided type `T`")(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1607, __extension__ __PRETTY_FUNCTION__
))
;
1608 return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
1609 .template cast<DenseResourceElementsAttrBase<T>>();
1610}
1611
1612template <typename T>
1613Optional<ArrayRef<T>>
1614DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1615 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1616 return blob->template getDataAs<T>();
1617 return llvm::None;
1618}
1619
1620template <typename T>
1621bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1622 auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
1623 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1624 resourceAttr.getElementType());
1625}
1626
1627namespace mlir {
1628namespace detail {
1629// Explicit instantiation for all the supported DenseResourceElementsAttr.
1630template class DenseResourceElementsAttrBase<bool>;
1631template class DenseResourceElementsAttrBase<int8_t>;
1632template class DenseResourceElementsAttrBase<int16_t>;
1633template class DenseResourceElementsAttrBase<int32_t>;
1634template class DenseResourceElementsAttrBase<int64_t>;
1635template class DenseResourceElementsAttrBase<uint8_t>;
1636template class DenseResourceElementsAttrBase<uint16_t>;
1637template class DenseResourceElementsAttrBase<uint32_t>;
1638template class DenseResourceElementsAttrBase<uint64_t>;
1639template class DenseResourceElementsAttrBase<float>;
1640template class DenseResourceElementsAttrBase<double>;
1641} // namespace detail
1642} // namespace mlir
1643
1644//===----------------------------------------------------------------------===//
1645// SparseElementsAttr
1646//===----------------------------------------------------------------------===//
1647
1648/// Get a zero APFloat for the given sparse attribute.
1649APFloat SparseElementsAttr::getZeroAPFloat() const {
1650 auto eltType = getElementType().cast<FloatType>();
1651 return APFloat(eltType.getFloatSemantics());
1652}
1653
1654/// Get a zero APInt for the given sparse attribute.
1655APInt SparseElementsAttr::getZeroAPInt() const {
1656 auto eltType = getElementType().cast<IntegerType>();
1657 return APInt::getZero(eltType.getWidth());
1658}
1659
1660/// Get a zero attribute for the given attribute type.
1661Attribute SparseElementsAttr::getZeroAttr() const {
1662 auto eltType = getElementType();
1663
1664 // Handle floating point elements.
1665 if (eltType.isa<FloatType>())
1666 return FloatAttr::get(eltType, 0);
1667
1668 // Handle complex elements.
1669 if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1670 auto eltType = complexTy.getElementType();
1671 Attribute zero;
1672 if (eltType.isa<FloatType>())
1673 zero = FloatAttr::get(eltType, 0);
1674 else // must be integer
1675 zero = IntegerAttr::get(eltType, 0);
1676 return ArrayAttr::get(complexTy.getContext(),
1677 ArrayRef<Attribute>{zero, zero});
1678 }
1679
1680 // Handle string type.
1681 if (getValues().isa<DenseStringElementsAttr>())
1682 return StringAttr::get("", eltType);
1683
1684 // Otherwise, this is an integer.
1685 return IntegerAttr::get(eltType, 0);
1686}
1687
1688/// Flatten, and return, all of the sparse indices in this attribute in
1689/// row-major order.
1690std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1691 std::vector<ptrdiff_t> flatSparseIndices;
1692
1693 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1694 // as a 1-D index array.
1695 auto sparseIndices = getIndices();
1696 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1697 if (sparseIndices.isSplat()) {
1698 SmallVector<uint64_t, 8> indices(getType().getRank(),
1699 *sparseIndexValues.begin());
1700 flatSparseIndices.push_back(getFlattenedIndex(indices));
1701 return flatSparseIndices;
1702 }
1703
1704 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1705 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1706 size_t rank = getType().getRank();
1707 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1708 flatSparseIndices.push_back(getFlattenedIndex(
1709 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1710 return flatSparseIndices;
1711}
1712
1713LogicalResult
1714SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1715 ShapedType type, DenseIntElementsAttr sparseIndices,
1716 DenseElementsAttr values) {
1717 ShapedType valuesType = values.getType();
1718 if (valuesType.getRank() != 1)
1719 return emitError() << "expected 1-d tensor for sparse element values";
1720
1721 // Verify the indices and values shape.
1722 ShapedType indicesType = sparseIndices.getType();
1723 auto emitShapeError = [&]() {
1724 return emitError() << "expected shape ([" << type.getShape()
1725 << "]); inferred shape of indices literal (["
1726 << indicesType.getShape()
1727 << "]); inferred shape of values literal (["
1728 << valuesType.getShape() << "])";
1729 };
1730 // Verify indices shape.
1731 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1732 if (indicesRank == 2) {
1733 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1734 return emitShapeError();
1735 } else if (indicesRank != 1 || rank != 1) {
1736 return emitShapeError();
1737 }
1738 // Verify the values shape.
1739 int64_t numSparseIndices = indicesType.getDimSize(0);
1740 if (numSparseIndices != valuesType.getDimSize(0))
1741 return emitShapeError();
1742
1743 // Verify that the sparse indices are within the value shape.
1744 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1745 return emitError()
1746 << "sparse index #" << indexNum
1747 << " is not contained within the value shape, with index=[" << index
1748 << "], and type=" << type;
1749 };
1750
1751 // Handle the case where the index values are a splat.
1752 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1753 if (sparseIndices.isSplat()) {
1754 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1755 if (!ElementsAttr::isValidIndex(type, indices))
1756 return emitIndexError(0, indices);
1757 return success();
1758 }
1759
1760 // Otherwise, reinterpret each index as an ArrayRef.
1761 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1762 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1763 rank);
1764 if (!ElementsAttr::isValidIndex(type, index))
1765 return emitIndexError(i, index);
1766 }
1767
1768 return success();
1769}
1770
1771//===----------------------------------------------------------------------===//
1772// TypeAttr
1773//===----------------------------------------------------------------------===//
1774
1775void TypeAttr::walkImmediateSubElements(
1776 function_ref<void(Attribute)> walkAttrsFn,
1777 function_ref<void(Type)> walkTypesFn) const {
1778 walkTypesFn(getValue());
1779}
1780
1781Attribute
1782TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
1783 ArrayRef<Type> replTypes) const {
1784 return get(replTypes[0]);
1785}

/build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
218 AsmPrinter() {}
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <
276 typename AsmPrinterT, typename T,
277 typename std::enable_if<!std::is_convertible<T &, Value &>::value &&
278 !std::is_convertible<T &, Type &>::value &&
279 !std::is_convertible<T &, Attribute &>::value &&
280 !std::is_convertible<T &, ValueRange>::value &&
281 !std::is_convertible<T &, APFloat &>::value &&
282 !llvm::is_one_of<T, bool, float, double>::value,
283 T>::type * = nullptr>
284inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
285 AsmPrinterT &>
286operator<<(AsmPrinterT &p, const T &other) {
287 p.getStream() << other;
288 return p;
289}
290
291template <typename AsmPrinterT>
292inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
293 AsmPrinterT &>
294operator<<(AsmPrinterT &p, bool value) {
295 return p << (value ? StringRef("true") : "false");
296}
297
298template <typename AsmPrinterT, typename ValueRangeT>
299inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
300 AsmPrinterT &>
301operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
302 llvm::interleaveComma(types, p);
303 return p;
304}
305template <typename AsmPrinterT>
306inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
307 AsmPrinterT &>
308operator<<(AsmPrinterT &p, const TypeRange &types) {
309 llvm::interleaveComma(types, p);
310 return p;
311}
312template <typename AsmPrinterT, typename ElementT>
313inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
314 AsmPrinterT &>
315operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
316 llvm::interleaveComma(types, p);
317 return p;
318}
319
320//===----------------------------------------------------------------------===//
321// OpAsmPrinter
322//===----------------------------------------------------------------------===//
323
324/// This is a pure-virtual base class that exposes the asmprinter hooks
325/// necessary to implement a custom print() method.
326class OpAsmPrinter : public AsmPrinter {
327public:
328 using AsmPrinter::AsmPrinter;
329 ~OpAsmPrinter() override;
330
331 /// Print a newline and indent the printer to the start of the current
332 /// operation.
333 virtual void printNewline() = 0;
334
335 /// Print a block argument in the usual format of:
336 /// %ssaName : type {attr1=42} loc("here")
337 /// where location printing is controlled by the standard internal option.
338 /// You may pass omitType=true to not print a type, and pass an empty
339 /// attribute list if you don't care for attributes.
340 virtual void printRegionArgument(BlockArgument arg,
341 ArrayRef<NamedAttribute> argAttrs = {},
342 bool omitType = false) = 0;
343
344 /// Print implementations for various things an operation contains.
345 virtual void printOperand(Value value) = 0;
346 virtual void printOperand(Value value, raw_ostream &os) = 0;
347
348 /// Print a comma separated list of operands.
349 template <typename ContainerType>
350 void printOperands(const ContainerType &container) {
351 printOperands(container.begin(), container.end());
352 }
353
354 /// Print a comma separated list of operands.
355 template <typename IteratorType>
356 void printOperands(IteratorType it, IteratorType end) {
357 if (it == end)
358 return;
359 printOperand(*it);
360 for (++it; it != end; ++it) {
361 getStream() << ", ";
362 printOperand(*it);
363 }
364 }
365
366 /// Print the given successor.
367 virtual void printSuccessor(Block *successor) = 0;
368
369 /// Print the successor and its operands.
370 virtual void printSuccessorAndUseList(Block *successor,
371 ValueRange succOperands) = 0;
372
373 /// If the specified operation has attributes, print out an attribute
374 /// dictionary with their values. elidedAttrs allows the client to ignore
375 /// specific well known attributes, commonly used if the attribute value is
376 /// printed some other way (like as a fixed operand).
377 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
378 ArrayRef<StringRef> elidedAttrs = {}) = 0;
379
380 /// If the specified operation has attributes, print out an attribute
381 /// dictionary prefixed with 'attributes'.
382 virtual void
383 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
384 ArrayRef<StringRef> elidedAttrs = {}) = 0;
385
386 /// Print the entire operation with the default generic assembly form.
387 /// If `printOpName` is true, then the operation name is printed (the default)
388 /// otherwise it is omitted and the print will start with the operand list.
389 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
390
391 /// Prints a region.
392 /// If 'printEntryBlockArgs' is false, the arguments of the
393 /// block are not printed. If 'printBlockTerminator' is false, the terminator
394 /// operation of the block is not printed. If printEmptyBlock is true, then
395 /// the block header is printed even if the block is empty.
396 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
397 bool printBlockTerminators = true,
398 bool printEmptyBlock = false) = 0;
399
400 /// Renumber the arguments for the specified region to the same names as the
401 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
402 /// operations. If any entry in namesToUse is null, the corresponding
403 /// argument name is left alone.
404 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
405
406 /// Prints an affine map of SSA ids, where SSA id names are used in place
407 /// of dims/symbols.
408 /// Operand values must come from single-result sources, and be valid
409 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
410 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
411 ValueRange operands) = 0;
412
413 /// Prints an affine expression of SSA ids with SSA id names used instead of
414 /// dims and symbols.
415 /// Operand values must come from single-result sources, and be valid
416 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
417 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
418 ValueRange symOperands) = 0;
419
420 /// Print the complete type of an operation in functional form.
421 void printFunctionalType(Operation *op);
422 using AsmPrinter::printFunctionalType;
423};
424
425// Make the implementations convenient to use.
426inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
427 p.printOperand(value);
428 return p;
429}
430
431template <typename T,
432 typename std::enable_if<std::is_convertible<T &, ValueRange>::value &&
433 !std::is_convertible<T &, Value &>::value,
434 T>::type * = nullptr>
435inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
436 p.printOperands(values);
437 return p;
438}
439
440inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
441 p.printSuccessor(value);
442 return p;
443}
444
445//===----------------------------------------------------------------------===//
446// AsmParser
447//===----------------------------------------------------------------------===//
448
449/// This base class exposes generic asm parser hooks, usable across the various
450/// derived parsers.
451class AsmParser {
452public:
453 AsmParser() = default;
454 virtual ~AsmParser();
455
456 MLIRContext *getContext() const;
457
458 /// Return the location of the original name token.
459 virtual SMLoc getNameLoc() const = 0;
460
461 //===--------------------------------------------------------------------===//
462 // Utilities
463 //===--------------------------------------------------------------------===//
464
465 /// Emit a diagnostic at the specified location and return failure.
466 virtual InFlightDiagnostic emitError(SMLoc loc,
467 const Twine &message = {}) = 0;
468
469 /// Return a builder which provides useful access to MLIRContext, global
470 /// objects like types and attributes.
471 virtual Builder &getBuilder() const = 0;
472
473 /// Get the location of the next token and store it into the argument. This
474 /// always succeeds.
475 virtual SMLoc getCurrentLocation() = 0;
476 ParseResult getCurrentLocation(SMLoc *loc) {
477 *loc = getCurrentLocation();
478 return success();
479 }
480
481 /// Re-encode the given source location as an MLIR location and return it.
482 /// Note: This method should only be used when a `Location` is necessary, as
483 /// the encoding process is not efficient.
484 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
485
486 //===--------------------------------------------------------------------===//
487 // Token Parsing
488 //===--------------------------------------------------------------------===//
489
490 /// Parse a '->' token.
491 virtual ParseResult parseArrow() = 0;
492
493 /// Parse a '->' token if present
494 virtual ParseResult parseOptionalArrow() = 0;
495
496 /// Parse a `{` token.
497 virtual ParseResult parseLBrace() = 0;
498
499 /// Parse a `{` token if present.
500 virtual ParseResult parseOptionalLBrace() = 0;
501
502 /// Parse a `}` token.
503 virtual ParseResult parseRBrace() = 0;
504
505 /// Parse a `}` token if present.
506 virtual ParseResult parseOptionalRBrace() = 0;
507
508 /// Parse a `:` token.
509 virtual ParseResult parseColon() = 0;
510
511 /// Parse a `:` token if present.
512 virtual ParseResult parseOptionalColon() = 0;
513
514 /// Parse a `,` token.
515 virtual ParseResult parseComma() = 0;
516
517 /// Parse a `,` token if present.
518 virtual ParseResult parseOptionalComma() = 0;
519
520 /// Parse a `=` token.
521 virtual ParseResult parseEqual() = 0;
522
523 /// Parse a `=` token if present.
524 virtual ParseResult parseOptionalEqual() = 0;
525
526 /// Parse a '<' token.
527 virtual ParseResult parseLess() = 0;
528
529 /// Parse a '<' token if present.
530 virtual ParseResult parseOptionalLess() = 0;
531
532 /// Parse a '>' token.
533 virtual ParseResult parseGreater() = 0;
534
535 /// Parse a '>' token if present.
536 virtual ParseResult parseOptionalGreater() = 0;
537
538 /// Parse a '?' token.
539 virtual ParseResult parseQuestion() = 0;
540
541 /// Parse a '?' token if present.
542 virtual ParseResult parseOptionalQuestion() = 0;
543
544 /// Parse a '+' token.
545 virtual ParseResult parsePlus() = 0;
546
547 /// Parse a '+' token if present.
548 virtual ParseResult parseOptionalPlus() = 0;
549
550 /// Parse a '*' token.
551 virtual ParseResult parseStar() = 0;
552
553 /// Parse a '*' token if present.
554 virtual ParseResult parseOptionalStar() = 0;
555
556 /// Parse a '|' token.
557 virtual ParseResult parseVerticalBar() = 0;
558
559 /// Parse a '|' token if present.
560 virtual ParseResult parseOptionalVerticalBar() = 0;
561
562 /// Parse a quoted string token.
563 ParseResult parseString(std::string *string) {
564 auto loc = getCurrentLocation();
565 if (parseOptionalString(string))
566 return emitError(loc, "expected string");
567 return success();
568 }
569
570 /// Parse a quoted string token if present.
571 virtual ParseResult parseOptionalString(std::string *string) = 0;
572
573 /// Parse a `(` token.
574 virtual ParseResult parseLParen() = 0;
575
576 /// Parse a `(` token if present.
577 virtual ParseResult parseOptionalLParen() = 0;
578
579 /// Parse a `)` token.
580 virtual ParseResult parseRParen() = 0;
581
582 /// Parse a `)` token if present.
583 virtual ParseResult parseOptionalRParen() = 0;
584
585 /// Parse a `[` token.
586 virtual ParseResult parseLSquare() = 0;
587
588 /// Parse a `[` token if present.
589 virtual ParseResult parseOptionalLSquare() = 0;
590
591 /// Parse a `]` token.
592 virtual ParseResult parseRSquare() = 0;
593
594 /// Parse a `]` token if present.
595 virtual ParseResult parseOptionalRSquare() = 0;
596
597 /// Parse a `...` token if present;
598 virtual ParseResult parseOptionalEllipsis() = 0;
599
600 /// Parse a floating point value from the stream.
601 virtual ParseResult parseFloat(double &result) = 0;
602
603 /// Parse an integer value from the stream.
604 template <typename IntT>
605 ParseResult parseInteger(IntT &result) {
606 auto loc = getCurrentLocation();
607 OptionalParseResult parseResult = parseOptionalInteger(result);
4
Calling 'AsmParser::parseOptionalInteger'
8
Returning from 'AsmParser::parseOptionalInteger'
608 if (!parseResult.has_value())
9
Taking true branch
609 return emitError(loc, "expected integer value");
10
Returning without writing to 'result'
610 return *parseResult;
611 }
612
613 /// Parse an optional integer value from the stream.
614 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
615
616 template <typename IntT>
617 OptionalParseResult parseOptionalInteger(IntT &result) {
618 auto loc = getCurrentLocation();
619
620 // Parse the unsigned variant.
621 APInt uintResult;
622 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
623 if (!parseResult.has_value() || failed(*parseResult))
5
Assuming the condition is true
6
Taking true branch
624 return parseResult;
7
Returning without writing to 'result'
625
626 // Try to convert to the provided integer type. sextOrTrunc is correct even
627 // for unsigned types because parseOptionalInteger ensures the sign bit is
628 // zero for non-negated integers.
629 result =
630 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
631 if (APInt(uintResult.getBitWidth(), result) != uintResult)
632 return emitError(loc, "integer value too large");
633 return success();
634 }
635
636 /// These are the supported delimiters around operand lists and region
637 /// argument lists, used by parseOperandList.
638 enum class Delimiter {
639 /// Zero or more operands with no delimiters.
640 None,
641 /// Parens surrounding zero or more operands.
642 Paren,
643 /// Square brackets surrounding zero or more operands.
644 Square,
645 /// <> brackets surrounding zero or more operands.
646 LessGreater,
647 /// {} brackets surrounding zero or more operands.
648 Braces,
649 /// Parens supporting zero or more operands, or nothing.
650 OptionalParen,
651 /// Square brackets supporting zero or more ops, or nothing.
652 OptionalSquare,
653 /// <> brackets supporting zero or more ops, or nothing.
654 OptionalLessGreater,
655 /// {} brackets surrounding zero or more operands, or nothing.
656 OptionalBraces,
657 };
658
659 /// Parse a list of comma-separated items with an optional delimiter. If a
660 /// delimiter is provided, then an empty list is allowed. If not, then at
661 /// least one element will be parsed.
662 ///
663 /// contextMessage is an optional message appended to "expected '('" sorts of
664 /// diagnostics when parsing the delimeters.
665 virtual ParseResult
666 parseCommaSeparatedList(Delimiter delimiter,
667 function_ref<ParseResult()> parseElementFn,
668 StringRef contextMessage = StringRef()) = 0;
669
670 /// Parse a comma separated list of elements that must have at least one entry
671 /// in it.
672 ParseResult
673 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
674 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
675 }
676
677 //===--------------------------------------------------------------------===//
678 // Keyword Parsing
679 //===--------------------------------------------------------------------===//
680
681 /// This class represents a StringSwitch like class that is useful for parsing
682 /// expected keywords. On construction, it invokes `parseKeyword` and
683 /// processes each of the provided cases statements until a match is hit. The
684 /// provided `ResultT` must be assignable from `failure()`.
685 template <typename ResultT = ParseResult>
686 class KeywordSwitch {
687 public:
688 KeywordSwitch(AsmParser &parser)
689 : parser(parser), loc(parser.getCurrentLocation()) {
690 if (failed(parser.parseKeywordOrCompletion(&keyword)))
691 result = failure();
692 }
693
694 /// Case that uses the provided value when true.
695 KeywordSwitch &Case(StringLiteral str, ResultT value) {
696 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
697 }
698 KeywordSwitch &Default(ResultT value) {
699 return Default([&](StringRef, SMLoc) { return std::move(value); });
700 }
701 /// Case that invokes the provided functor when true. The parameters passed
702 /// to the functor are the keyword, and the location of the keyword (in case
703 /// any errors need to be emitted).
704 template <typename FnT>
705 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
706 Case(StringLiteral str, FnT &&fn) {
707 if (result)
708 return *this;
709
710 // If the word was empty, record this as a completion.
711 if (keyword.empty())
712 parser.codeCompleteExpectedTokens(str);
713 else if (keyword == str)
714 result.emplace(std::move(fn(keyword, loc)));
715 return *this;
716 }
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Default(FnT &&fn) {
720 if (!result)
721 result.emplace(fn(keyword, loc));
722 return *this;
723 }
724
725 /// Returns true if this switch has a value yet.
726 bool hasValue() const { return result.has_value(); }
727
728 /// Return the result of the switch.
729 [[nodiscard]] operator ResultT() {
730 if (!result)
731 return parser.emitError(loc, "unexpected keyword: ") << keyword;
732 return std::move(*result);
733 }
734
735 private:
736 /// The parser used to construct this switch.
737 AsmParser &parser;
738
739 /// The location of the keyword, used to emit errors as necessary.
740 SMLoc loc;
741
742 /// The parsed keyword itself.
743 StringRef keyword;
744
745 /// The result of the switch statement or none if currently unknown.
746 Optional<ResultT> result;
747 };
748
749 /// Parse a given keyword.
750 ParseResult parseKeyword(StringRef keyword) {
751 return parseKeyword(keyword, "");
752 }
753 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
754
755 /// Parse a keyword into 'keyword'.
756 ParseResult parseKeyword(StringRef *keyword) {
757 auto loc = getCurrentLocation();
758 if (parseOptionalKeyword(keyword))
759 return emitError(loc, "expected valid keyword");
760 return success();
761 }
762
/// Parse the given keyword if present. A failure result presumably means the
/// keyword was not present (TODO confirm against the implementation).
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;

/// Parse a keyword, if present, into 'keyword'. `parseKeyword(StringRef *)`
/// treats a failure result as "no keyword" and emits its own diagnostic.
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;

/// Parse a keyword into 'keyword' if it is present and is one of
/// 'allowedValues'.
virtual ParseResult
parseOptionalKeyword(StringRef *keyword,
                     ArrayRef<StringRef> allowedValues) = 0;
774
775 /// Parse a keyword or a quoted string.
776 ParseResult parseKeywordOrString(std::string *result) {
777 if (failed(parseOptionalKeywordOrString(result)))
778 return emitError(getCurrentLocation())
779 << "expected valid keyword or string";
780 return success();
781 }
782
783 /// Parse an optional keyword or string.
784 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
785
786 //===--------------------------------------------------------------------===//
787 // Attribute/Type Parsing
788 //===--------------------------------------------------------------------===//
789
790 /// Invoke the `getChecked` method of the given Attribute or Type class, using
791 /// the provided location to emit errors in the case of failure. Note that
792 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
793 /// context parameter.
794 template <typename T, typename... ParamsT>
795 auto getChecked(SMLoc loc, ParamsT &&...params) {
796 return T::getChecked([&] { return emitError(loc); },
797 std::forward<ParamsT>(params)...);
798 }
799 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
800 /// errors.
801 template <typename T, typename... ParamsT>
802 auto getChecked(ParamsT &&...params) {
803 return T::getChecked([&] { return emitError(getNameLoc()); },
804 std::forward<ParamsT>(params)...);
805 }
806
807 //===--------------------------------------------------------------------===//
808 // Attribute Parsing
809 //===--------------------------------------------------------------------===//
810
/// Parse an arbitrary attribute of a given type and return it in result.
/// `type` defaults to a null Type.
virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;

/// Parse a custom attribute with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked. The callback
/// receives the attribute slot to fill and the expected `type`.
virtual ParseResult parseCustomAttributeWithFallback(
    Attribute &result, Type type,
    function_ref<ParseResult(Attribute &result, Type type)>
        parseAttribute) = 0;
820
821 /// Parse an attribute of a specific kind and type.
822 template <typename AttrType>
823 ParseResult parseAttribute(AttrType &result, Type type = {}) {
824 SMLoc loc = getCurrentLocation();
825
826 // Parse any kind of attribute.
827 Attribute attr;
828 if (parseAttribute(attr, type))
829 return failure();
830
831 // Check for the right kind of attribute.
832 if (!(result = attr.dyn_cast<AttrType>()))
833 return emitError(loc, "invalid kind of attribute specified");
834
835 return success();
836 }
837
838 /// Parse an arbitrary attribute and return it in result. This also adds the
839 /// attribute to the specified attribute list with the specified name.
840 ParseResult parseAttribute(Attribute &result, StringRef attrName,
841 NamedAttrList &attrs) {
842 return parseAttribute(result, Type(), attrName, attrs);
843 }
844
845 /// Parse an attribute of a specific kind and type.
846 template <typename AttrType>
847 ParseResult parseAttribute(AttrType &result, StringRef attrName,
848 NamedAttrList &attrs) {
849 return parseAttribute(result, Type(), attrName, attrs);
850 }
851
852 /// Parse an arbitrary attribute of a given type and populate it in `result`.
853 /// This also adds the attribute to the specified attribute list with the
854 /// specified name.
855 template <typename AttrType>
856 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
857 NamedAttrList &attrs) {
858 SMLoc loc = getCurrentLocation();
859
860 // Parse any kind of attribute.
861 Attribute attr;
862 if (parseAttribute(attr, type))
863 return failure();
864
865 // Check for the right kind of attribute.
866 result = attr.dyn_cast<AttrType>();
867 if (!result)
868 return emitError(loc, "invalid kind of attribute specified");
869
870 attrs.append(attrName, result);
871 return success();
872 }
873
/// Trait to check if `AttrType` provides a `parse` method, i.e. whether
/// `AttrType::parse(AsmParser &, Type)` is a well-formed expression.
template <typename AttrType>
using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
                                                  std::declval<Type>()));
/// `llvm::is_detected` wrapper over `has_parse_method`; used below to select
/// between the custom and the generic `parseCustomAttributeWithFallback`
/// overloads.
template <typename AttrType>
using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
880
881 /// Parse a custom attribute of a given type unless the next token is `#`, in
882 /// which case the generic parser is invoked. The parsed attribute is
883 /// populated in `result` and also added to the specified attribute list with
884 /// the specified name.
885 template <typename AttrType>
886 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
887 parseCustomAttributeWithFallback(AttrType &result, Type type,
888 StringRef attrName, NamedAttrList &attrs) {
889 SMLoc loc = getCurrentLocation();
890
891 // Parse any kind of attribute.
892 Attribute attr;
893 if (parseCustomAttributeWithFallback(
894 attr, type, [&](Attribute &result, Type type) -> ParseResult {
895 result = AttrType::parse(*this, type);
896 if (!result)
897 return failure();
898 return success();
899 }))
900 return failure();
901
902 // Check for the right kind of attribute.
903 result = attr.dyn_cast<AttrType>();
904 if (!result)
905 return emitError(loc, "invalid kind of attribute specified");
906
907 attrs.append(attrName, result);
908 return success();
909 }
910
911 /// SFINAE parsing method for Attribute that don't implement a parse method.
912 template <typename AttrType>
913 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
914 parseCustomAttributeWithFallback(AttrType &result, Type type,
915 StringRef attrName, NamedAttrList &attrs) {
916 return parseAttribute(result, type, attrName, attrs);
917 }
918
919 /// Parse a custom attribute of a given type unless the next token is `#`, in
920 /// which case the generic parser is invoked. The parsed attribute is
921 /// populated in `result`.
922 template <typename AttrType>
923 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
924 parseCustomAttributeWithFallback(AttrType &result) {
925 SMLoc loc = getCurrentLocation();
926
927 // Parse any kind of attribute.
928 Attribute attr;
929 if (parseCustomAttributeWithFallback(
930 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
931 result = AttrType::parse(*this, type);
932 return success(!!result);
933 }))
934 return failure();
935
936 // Check for the right kind of attribute.
937 result = attr.dyn_cast<AttrType>();
938 if (!result)
939 return emitError(loc, "invalid kind of attribute specified");
940 return success();
941 }
942
943 /// SFINAE parsing method for Attribute that don't implement a parse method.
944 template <typename AttrType>
945 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
946 parseCustomAttributeWithFallback(AttrType &result) {
947 return parseAttribute(result);
948 }
949
/// Parse an arbitrary optional attribute of a given type and return it in
/// result. The returned `OptionalParseResult` has no value when no attribute
/// is present.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
                                                   Type type = {}) = 0;

/// Parse an optional array attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
                                                   Type type = {}) = 0;

/// Parse an optional string attribute and return it in result.
virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
                                                   Type type = {}) = 0;
962
963 /// Parse an optional attribute of a specific type and add it to the list with
964 /// the specified name.
965 template <typename AttrType>
966 OptionalParseResult parseOptionalAttribute(AttrType &result,
967 StringRef attrName,
968 NamedAttrList &attrs) {
969 return parseOptionalAttribute(result, Type(), attrName, attrs);
970 }
971
972 /// Parse an optional attribute of a specific type and add it to the list with
973 /// the specified name.
974 template <typename AttrType>
975 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
976 StringRef attrName,
977 NamedAttrList &attrs) {
978 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
979 if (parseResult.has_value() && succeeded(*parseResult))
980 attrs.append(attrName, result);
981 return parseResult;
982 }
983
984 /// Parse a named dictionary into 'result' if it is present.
985 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
986
987 /// Parse a named dictionary into 'result' if the `attributes` keyword is
988 /// present.
989 virtual ParseResult
990 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
991
992 /// Parse an affine map instance into 'map'.
993 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
994
995 /// Parse an integer set instance into 'set'.
996 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
997
998 //===--------------------------------------------------------------------===//
999 // Identifier Parsing
1000 //===--------------------------------------------------------------------===//
1001
1002 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1003 /// attribute named 'attrName'.
1004 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1005 NamedAttrList &attrs) {
1006 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1007 return emitError(getCurrentLocation())
1008 << "expected valid '@'-identifier for symbol name";
1009 return success();
1010 }
1011
1012 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1013 /// string attribute named 'attrName'.
1014 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1015 StringRef attrName,
1016 NamedAttrList &attrs) = 0;
1017
1018 //===--------------------------------------------------------------------===//
1019 // Resource Parsing
1020 //===--------------------------------------------------------------------===//
1021
1022 /// Parse a handle to a resource within the assembly format.
1023 template <typename ResourceT>
1024 FailureOr<ResourceT> parseResourceHandle() {
1025 SMLoc handleLoc = getCurrentLocation();
1026
1027 // Try to load the dialect that owns the handle.
1028 auto *dialect =
1029 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1030 if (!dialect) {
1031 return emitError(handleLoc)
1032 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1033 << "' is unknown";
1034 }
1035
1036 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1037 if (failed(handle))
1038 return failure();
1039 if (auto *result = dyn_cast<ResourceT>(&*handle))
1040 return std::move(*result);
1041 return emitError(handleLoc) << "provided resource handle differs from the "
1042 "expected resource type";
1043 }
1044
1045 //===--------------------------------------------------------------------===//
1046 // Type Parsing
1047 //===--------------------------------------------------------------------===//
1048
/// Parse a type into `result`.
virtual ParseResult parseType(Type &result) = 0;

/// Parse a custom type with the provided callback, unless the next
/// token is `#`, in which case the generic parser is invoked. The callback
/// receives the type slot to fill.
virtual ParseResult parseCustomTypeWithFallback(
    Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;

/// Parse an optional type. The returned `OptionalParseResult` has no value
/// when no type is present.
virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1059
1060 /// Parse a type of a specific type.
1061 template <typename TypeT>
1062 ParseResult parseType(TypeT &result) {
1063 SMLoc loc = getCurrentLocation();
1064
1065 // Parse any kind of type.
1066 Type type;
1067 if (parseType(type))
1068 return failure();
1069
1070 // Check for the right kind of type.
1071 result = type.dyn_cast<TypeT>();
1072 if (!result)
1073 return emitError(loc, "invalid kind of type specified");
1074
1075 return success();
1076 }
1077
/// Trait to check if `TypeT` provides a `parse` method, i.e. whether
/// `TypeT::parse(AsmParser &)` is a well-formed expression.
template <typename TypeT>
using type_has_parse_method =
    decltype(TypeT::parse(std::declval<AsmParser &>()));
/// `llvm::is_detected` wrapper over `type_has_parse_method`; used below to
/// select between the custom and the generic `parseCustomTypeWithFallback`
/// overloads.
template <typename TypeT>
using detect_type_has_parse_method =
    llvm::is_detected<type_has_parse_method, TypeT>;
1085
1086 /// Parse a custom Type of a given type unless the next token is `#`, in
1087 /// which case the generic parser is invoked. The parsed Type is
1088 /// populated in `result`.
1089 template <typename TypeT>
1090 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1091 parseCustomTypeWithFallback(TypeT &result) {
1092 SMLoc loc = getCurrentLocation();
1093
1094 // Parse any kind of Type.
1095 Type type;
1096 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1097 result = TypeT::parse(*this);
1098 return success(!!result);
1099 }))
1100 return failure();
1101
1102 // Check for the right kind of Type.
1103 result = type.dyn_cast<TypeT>();
1104 if (!result)
1105 return emitError(loc, "invalid kind of Type specified");
1106 return success();
1107 }
1108
1109 /// SFINAE parsing method for Type that don't implement a parse method.
1110 template <typename TypeT>
1111 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1112 parseCustomTypeWithFallback(TypeT &result) {
1113 return parseType(result);
1114 }
1115
1116 /// Parse a type list.
1117 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1118 return parseCommaSeparatedList(
1119 [&]() { return parseType(result.emplace_back()); });
1120 }
1121
/// Parse an arrow followed by a type list, appending the types to `result`.
virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;

/// Parse an optional arrow followed by a type list, appending the types to
/// `result`.
virtual ParseResult
parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;

/// Parse a colon followed by a type into `result`.
virtual ParseResult parseColonType(Type &result) = 0;
1131
1132 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1133 template <typename TypeType>
1134 ParseResult parseColonType(TypeType &result) {
1135 SMLoc loc = getCurrentLocation();
1136
1137 // Parse any kind of type.
1138 Type type;
1139 if (parseColonType(type))
1140 return failure();
1141
1142 // Check for the right kind of type.
1143 result = type.dyn_cast<TypeType>();
1144 if (!result)
1145 return emitError(loc, "invalid kind of type specified");
1146
1147 return success();
1148 }
1149
/// Parse a colon followed by a type list, which must have at least one type.
/// The types are appended to `result`.
virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;

/// Parse an optional colon followed by a type list, which if present must
/// have at least one type. The types are appended to `result`.
virtual ParseResult
parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1157
1158 /// Parse a keyword followed by a type.
1159 ParseResult parseKeywordType(const char *keyword, Type &result) {
1160 return failure(parseKeyword(keyword) || parseType(result));
1161 }
1162
1163 /// Add the specified type to the end of the specified type list and return
1164 /// success. This is a helper designed to allow parse methods to be simple
1165 /// and chain through || operators.
1166 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1167 result.push_back(type);
1168 return success();
1169 }
1170
1171 /// Add the specified types to the end of the specified type list and return
1172 /// success. This is a helper designed to allow parse methods to be simple
1173 /// and chain through || operators.
1174 ParseResult addTypesToList(ArrayRef<Type> types,
1175 SmallVectorImpl<Type> &result) {
1176 result.append(types.begin(), types.end());
1177 return success();
1178 }
1179
/// Parse a dimension list of a tensor or memref type.  This populates the
/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
/// and errors out on `?` otherwise. Parsing the trailing `x` is configurable
/// via `withTrailingX`, which selects between the two grammars below.
///
///   dimension-list ::= eps | dimension (`x` dimension)*
///   dimension-list-with-trailing-x ::= (dimension `x`)*
///   dimension ::= `?` | decimal-literal
///
/// When `allowDynamic` is not set, this is used to parse:
///
///   static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
///   static-dimension-list-with-trailing-x ::= (decimal-literal `x`)*
virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
                                       bool allowDynamic = true,
                                       bool withTrailingX = true) = 0;

/// Parse an 'x' token in a dimension list, handling the case where the x is
/// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
/// next token.
virtual ParseResult parseXInDimensionList() = 0;
1200
protected:
  /// Parse a handle to a resource within the assembly format for the given
  /// dialect. The public `parseResourceHandle<ResourceT>()` template calls
  /// this with the dialect that owns `ResourceT`.
  virtual FailureOr<AsmDialectResourceHandle>
  parseResourceHandle(Dialect *dialect) = 0;

  //===--------------------------------------------------------------------===//
  // Code Completion
  //===--------------------------------------------------------------------===//

  /// Parse a keyword, or an empty string if the current location signals a code
  /// completion. `KeywordSwitch` treats the empty string as a completion
  /// request.
  virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;

  /// Signal the code completion of a set of expected tokens.
  virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1217
private:
  // Parsers are neither copy-constructible nor copy-assignable.
  AsmParser(const AsmParser &) = delete;
  void operator=(const AsmParser &) = delete;
1221};
1222
1223//===----------------------------------------------------------------------===//
1224// OpAsmParser
1225//===----------------------------------------------------------------------===//
1226
1227/// The OpAsmParser has methods for interacting with the asm parser: parsing
1228/// things from it, emitting errors etc. It has an intentionally high-level API
1229/// that is designed to reduce/constrain syntax innovation in individual
1230/// operations.
1231///
1232/// For example, consider an op like this:
1233///
1234/// %x = load %p[%1, %2] : memref<...>
1235///
1236/// The "%x = load" tokens are already parsed and therefore invisible to the
1237/// custom op parser. This can be supported by calling `parseOperandList` to
1238/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1239/// parse the indices, then calling `parseColonTypeList` to parse the result
1240/// type.
1241///
1242class OpAsmParser : public AsmParser {
1243public:
1244 using AsmParser::AsmParser;
1245 ~OpAsmParser() override;
1246
1247 /// Parse a loc(...) specifier if present, filling in result if so.
1248 /// Location for BlockArgument and Operation may be deferred with an alias, in
1249 /// which case an OpaqueLoc is set and will be resolved when parsing
1250 /// completes.
1251 virtual ParseResult
1252 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1253
1254 /// Return the name of the specified result in the specified syntax, as well
1255 /// as the sub-element in the name. It returns an empty string and ~0U for
1256 /// invalid result numbers. For example, in this operation:
1257 ///
1258 /// %x, %y:2, %z = foo.op
1259 ///
1260 /// getResultName(0) == {"x", 0 }
1261 /// getResultName(1) == {"y", 0 }
1262 /// getResultName(2) == {"y", 1 }
1263 /// getResultName(3) == {"z", 0 }
1264 /// getResultName(4) == {"", ~0U }
1265 virtual std::pair<StringRef, unsigned>
1266 getResultName(unsigned resultNo) const = 0;
1267
1268 /// Return the number of declared SSA results. This returns 4 for the foo.op
1269 /// example in the comment for `getResultName`.
1270 virtual size_t getNumResults() const = 0;
1271
1272 // These methods emit an error and return failure or success. This allows
1273 // these to be chained together into a linear sequence of || expressions in
1274 // many cases.
1275
1276 /// Parse an operation in its generic form.
1277 /// The parsed operation is parsed in the current context and inserted in the
1278 /// provided block and insertion point. The results produced by this operation
1279 /// aren't mapped to any named value in the parser. Returns nullptr on
1280 /// failure.
1281 virtual Operation *parseGenericOperation(Block *insertBlock,
1282 Block::iterator insertPt) = 0;
1283
1284 /// Parse the name of an operation, in the custom form. On success, return a
1285 /// an object of type 'OperationName'. Otherwise, failure is returned.
1286 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1287
1288 //===--------------------------------------------------------------------===//
1289 // Operand Parsing
1290 //===--------------------------------------------------------------------===//
1291
1292 /// This is the representation of an operand reference.
1293 struct UnresolvedOperand {
1294 SMLoc location; // Location of the token.
1295 StringRef name; // Value name, e.g. %42 or %abc
1296 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1297 };
1298
1299 /// Parse different components, viz., use-info of operand(s), successor(s),
1300 /// region(s), attribute(s) and function-type, of the generic form of an
1301 /// operation instance and populate the input operation-state 'result' with
1302 /// those components. If any of the components is explicitly provided, then
1303 /// skip parsing that component.
1304 virtual ParseResult parseGenericOperationAfterOpName(
1305 OperationState &result,
1306 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1307 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1308 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1309 llvm::None,
1310 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1311 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1312
1313 /// Parse a single SSA value operand name along with a result number if
1314 /// `allowResultNumber` is true.
1315 virtual ParseResult parseOperand(UnresolvedOperand &result,
1316 bool allowResultNumber = true) = 0;
1317
1318 /// Parse a single operand if present.
1319 virtual OptionalParseResult
1320 parseOptionalOperand(UnresolvedOperand &result,
1321 bool allowResultNumber = true) = 0;
1322
1323 /// Parse zero or more SSA comma-separated operand references with a specified
1324 /// surrounding delimiter, and an optional required operand count.
1325 virtual ParseResult
1326 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1327 Delimiter delimiter = Delimiter::None,
1328 bool allowResultNumber = true,
1329 int requiredOperandCount = -1) = 0;
1330
1331 /// Parse a specified number of comma separated operands.
1332 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1333 int requiredOperandCount,
1334 Delimiter delimiter = Delimiter::None) {
1335 return parseOperandList(result, delimiter,
1336 /*allowResultNumber=*/true, requiredOperandCount);
1337 }
1338
1339 /// Parse zero or more trailing SSA comma-separated trailing operand
1340 /// references with a specified surrounding delimiter, and an optional
1341 /// required operand count. A leading comma is expected before the
1342 /// operands.
1343 ParseResult
1344 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1345 Delimiter delimiter = Delimiter::None) {
1346 if (failed(parseOptionalComma()))
1347 return success(); // The comma is optional.
1348 return parseOperandList(result, delimiter);
1349 }
1350
1351 /// Resolve an operand to an SSA value, emitting an error on failure.
1352 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1353 Type type,
1354 SmallVectorImpl<Value> &result) = 0;
1355
1356 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1357 /// appending the results to the list on success. This method should be used
1358 /// when all operands have the same type.
1359 ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands, Type type,
1360 SmallVectorImpl<Value> &result) {
1361 for (auto elt : operands)
1362 if (resolveOperand(elt, type, result))
1363 return failure();
1364 return success();
1365 }
1366
1367 /// Resolve a list of operands and a list of operand types to SSA values,
1368 /// emitting an error and returning failure, or appending the results
1369 /// to the list on success.
1370 ParseResult resolveOperands(ArrayRef<UnresolvedOperand> operands,
1371 ArrayRef<Type> types, SMLoc loc,
1372 SmallVectorImpl<Value> &result) {
1373 if (operands.size() != types.size())
1374 return emitError(loc)
1375 << operands.size() << " operands present, but expected "
1376 << types.size();
1377
1378 for (unsigned i = 0, e = operands.size(); i != e; ++i)
1379 if (resolveOperand(operands[i], types[i], result))
1380 return failure();
1381 return success();
1382 }
1383 template <typename Operands>
1384 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1385 SmallVectorImpl<Value> &result) {
1386 return resolveOperands(std::forward<Operands>(operands),
1387 ArrayRef<Type>(type), loc, result);
1388 }
1389 template <typename Operands, typename Types>
1390 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1391 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1392 SmallVectorImpl<Value> &result) {
1393 size_t operandSize = std::distance(operands.begin(), operands.end());
1394 size_t typeSize = std::distance(types.begin(), types.end());
1395 if (operandSize != typeSize)
1396 return emitError(loc)
1397 << operandSize << " operands present, but expected " << typeSize;
1398
1399 for (auto it : llvm::zip(operands, types))
1400 if (resolveOperand(std::get<0>(it), std::get<1>(it), result))
1401 return failure();
1402 return success();
1403 }
1404
1405 /// Parses an affine map attribute where dims and symbols are SSA operands.
1406 /// Operand values must come from single-result sources, and be valid
1407 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1408 virtual ParseResult
1409 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1410 Attribute &map, StringRef attrName,
1411 NamedAttrList &attrs,
1412 Delimiter delimiter = Delimiter::Square) = 0;
1413
1414 /// Parses an affine expression where dims and symbols are SSA operands.
1415 /// Operand values must come from single-result sources, and be valid
1416 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1417 virtual ParseResult
1418 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1419 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1420 AffineExpr &expr) = 0;
1421
1422 //===--------------------------------------------------------------------===//
1423 // Argument Parsing
1424 //===--------------------------------------------------------------------===//
1425
1426 struct Argument {
1427 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1428 Type type; // Type.
1429 DictionaryAttr attrs; // Attributes if present.
1430 Optional<Location> sourceLoc; // Source location specifier if present.
1431 };
1432
1433 /// Parse a single argument with the following syntax:
1434 ///
1435 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1436 ///
1437 /// If `allowType` is false or `allowAttrs` are false then the respective
1438 /// parts of the grammar are not parsed.
1439 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1440 bool allowAttrs = false) = 0;
1441
1442 /// Parse a single argument if present.
1443 virtual OptionalParseResult
1444 parseOptionalArgument(Argument &result, bool allowType = false,
1445 bool allowAttrs = false) = 0;
1446
1447 /// Parse zero or more arguments with a specified surrounding delimiter.
1448 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1449 Delimiter delimiter = Delimiter::None,
1450 bool allowType = false,
1451 bool allowAttrs = false) = 0;
1452
1453 //===--------------------------------------------------------------------===//
1454 // Region Parsing
1455 //===--------------------------------------------------------------------===//
1456
1457 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1458 /// moved to the op regions after the op is created. The first block of the
1459 /// region takes 'arguments'.
1460 ///
1461 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1462 /// shadow the names of other existing SSA values defined above the region
1463 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1464 /// to operations that are 'IsolatedFromAbove'.
1465 virtual ParseResult parseRegion(Region &region,
1466 ArrayRef<Argument> arguments = {},
1467 bool enableNameShadowing = false) = 0;
1468
1469 /// Parses a region if present.
1470 virtual OptionalParseResult
1471 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1472 bool enableNameShadowing = false) = 0;
1473
1474 /// Parses a region if present. If the region is present, a new region is
1475 /// allocated and placed in `region`. If no region is present or on failure,
1476 /// `region` remains untouched.
1477 virtual OptionalParseResult
1478 parseOptionalRegion(std::unique_ptr<Region> &region,
1479 ArrayRef<Argument> arguments = {},
1480 bool enableNameShadowing = false) = 0;
1481
1482 //===--------------------------------------------------------------------===//
1483 // Successor Parsing
1484 //===--------------------------------------------------------------------===//
1485
1486 /// Parse a single operation successor.
1487 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1488
1489 /// Parse an optional operation successor.
1490 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1491
1492 /// Parse a single operation successor and its operand list.
1493 virtual ParseResult
1494 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1495
1496 //===--------------------------------------------------------------------===//
1497 // Type Parsing
1498 //===--------------------------------------------------------------------===//
1499
1500 /// Parse a list of assignments of the form
1501 /// (%x1 = %y1, %x2 = %y2, ...)
1502 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1503 SmallVectorImpl<UnresolvedOperand> &rhs) {
1504 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1505 if (!result.has_value())
1506 return emitError(getCurrentLocation(), "expected '('");
1507 return result.value();
1508 }
1509
1510 virtual OptionalParseResult
1511 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1512 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1513};
1514
1515//===--------------------------------------------------------------------===//
1516// Dialect OpAsm interface.
1517//===--------------------------------------------------------------------===//
1518
1519/// A functor used to set the name of the start of a result group of an
1520/// operation. See 'getAsmResultNames' below for more details.
1521using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1522
1523/// A functor used to set the name of blocks in regions directly nested under
1524/// an operation.
1525using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1526
1527class OpAsmDialectInterface
1528 : public DialectInterface::Base<OpAsmDialectInterface> {
1529public:
1530 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1531
1532 //===------------------------------------------------------------------===//
1533 // Aliases
1534 //===------------------------------------------------------------------===//
1535
1536 /// Holds the result of `getAlias` hook call.
1537 enum class AliasResult {
1538 /// The object (type or attribute) is not supported by the hook
1539 /// and an alias was not provided.
1540 NoAlias,
1541 /// An alias was provided, but it might be overriden by other hook.
1542 OverridableAlias,
1543 /// An alias was provided and it should be used
1544 /// (no other hooks will be checked).
1545 FinalAlias
1546 };
1547
1548 /// Hooks for getting an alias identifier alias for a given symbol, that is
1549 /// not necessarily a part of this dialect. The identifier is used in place of
1550 /// the symbol when printing textual IR. These aliases must not contain `.` or
1551 /// end with a numeric digit([0-9]+).
1552 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1553 return AliasResult::NoAlias;
1554 }
1555 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1556 return AliasResult::NoAlias;
1557 }
1558
1559 //===--------------------------------------------------------------------===//
1560 // Resources
1561 //===--------------------------------------------------------------------===//
1562
1563 /// Declare a resource with the given key, returning a handle to use for any
1564 /// references of this resource key within the IR during parsing. The result
1565 /// of `getResourceKey` on the returned handle is permitted to be different
1566 /// than `key`.
1567 virtual FailureOr<AsmDialectResourceHandle>
1568 declareResource(StringRef key) const {
1569 return failure();
1570 }
1571
1572 /// Return a key to use for the given resource. This key should uniquely
1573 /// identify this resource within the dialect.
1574 virtual std::string
1575 getResourceKey(const AsmDialectResourceHandle &handle) const {
1576 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1577)
1577 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1577)
;
1578 }
1579
1580 /// Hook for parsing resource entries. Returns failure if the entry was not
1581 /// valid, or could otherwise not be processed correctly. Any necessary errors
1582 /// can be emitted via the provided entry.
1583 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1584
1585 /// Hook for building resources to use during printing. The given `op` may be
1586 /// inspected to help determine what information to include.
1587 /// `referencedResources` contains all of the resources detected when printing
1588 /// 'op'.
1589 virtual void
1590 buildResources(Operation *op,
1591 const SetVector<AsmDialectResourceHandle> &referencedResources,
1592 AsmResourceBuilder &builder) const {}
1593};
1594} // namespace mlir
1595
1596//===--------------------------------------------------------------------===//
1597// Operation OpAsm interface.
1598//===--------------------------------------------------------------------===//
1599
1600/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1601#include "mlir/IR/OpAsmInterface.h.inc"
1602
1603namespace llvm {
1604template <>
1605struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1606 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1607 return {DenseMapInfo<void *>::getEmptyKey(),
1608 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1609 }
1610 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1611 return {DenseMapInfo<void *>::getTombstoneKey(),
1612 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1613 }
1614 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1615 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1616 }
1617 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1618 const mlir::AsmDialectResourceHandle &rhs) {
1619 return lhs.getResource() == rhs.getResource();
1620 }
1621};
1622} // namespace llvm
1623
1624#endif