Bug Summary

File:build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/IR/BuiltinAttributes.cpp
Warning:line 926, column 9
1st function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name BuiltinAttributes.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/IR -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/IR -I include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/llvm/include -I /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem 
/usr/lib/llvm-16/lib/clang/16.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-09-04-125545-48738-1 -x c++ /build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/IR/BuiltinAttributes.cpp

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/IR/BuiltinAttributes.cpp

1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Endian.h"
24
25using namespace mlir;
26using namespace mlir::detail;
27
28//===----------------------------------------------------------------------===//
29/// Tablegen Attribute Definitions
30//===----------------------------------------------------------------------===//
31
32#define GET_ATTRDEF_CLASSES
33#include "mlir/IR/BuiltinAttributes.cpp.inc"
34
35//===----------------------------------------------------------------------===//
36// BuiltinDialect
37//===----------------------------------------------------------------------===//
38
39void BuiltinDialect::registerAttributes() {
40 addAttributes<
41#define GET_ATTRDEF_LIST
42#include "mlir/IR/BuiltinAttributes.cpp.inc"
43 >();
44}
45
46//===----------------------------------------------------------------------===//
47// ArrayAttr
48//===----------------------------------------------------------------------===//
49
50void ArrayAttr::walkImmediateSubElements(
51 function_ref<void(Attribute)> walkAttrsFn,
52 function_ref<void(Type)> walkTypesFn) const {
53 for (Attribute attr : getValue())
54 walkAttrsFn(attr);
55}
56
57Attribute
58ArrayAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
59 ArrayRef<Type> replTypes) const {
60 return get(getContext(), replAttrs);
61}
62
63//===----------------------------------------------------------------------===//
64// DictionaryAttr
65//===----------------------------------------------------------------------===//
66
67/// Helper function that does either an in place sort or sorts from source array
68/// into destination. If inPlace then storage is both the source and the
69/// destination, else value is the source and storage destination. Returns
70/// whether source was sorted.
71template <bool inPlace>
72static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
73 SmallVectorImpl<NamedAttribute> &storage) {
74 // Specialize for the common case.
75 switch (value.size()) {
76 case 0:
77 // Zero already sorted.
78 if (!inPlace)
79 storage.clear();
80 break;
81 case 1:
82 // One already sorted but may need to be copied.
83 if (!inPlace)
84 storage.assign({value[0]});
85 break;
86 case 2: {
87 bool isSorted = value[0] < value[1];
88 if (inPlace) {
89 if (!isSorted)
90 std::swap(storage[0], storage[1]);
91 } else if (isSorted) {
92 storage.assign({value[0], value[1]});
93 } else {
94 storage.assign({value[1], value[0]});
95 }
96 return !isSorted;
97 }
98 default:
99 if (!inPlace)
100 storage.assign(value.begin(), value.end());
101 // Check to see they are sorted already.
102 bool isSorted = llvm::is_sorted(value);
103 // If not, do a general sort.
104 if (!isSorted)
105 llvm::array_pod_sort(storage.begin(), storage.end());
106 return !isSorted;
107 }
108 return false;
109}
110
111/// Returns an entry with a duplicate name from the given sorted array of named
112/// attributes. Returns llvm::None if all elements have unique names.
113static Optional<NamedAttribute>
114findDuplicateElement(ArrayRef<NamedAttribute> value) {
115 const Optional<NamedAttribute> none{llvm::None};
116 if (value.size() < 2)
117 return none;
118
119 if (value.size() == 2)
120 return value[0].getName() == value[1].getName() ? value[0] : none;
121
122 const auto *it = std::adjacent_find(value.begin(), value.end(),
123 [](NamedAttribute l, NamedAttribute r) {
124 return l.getName() == r.getName();
125 });
126 return it != value.end() ? *it : none;
127}
128
129bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
130 SmallVectorImpl<NamedAttribute> &storage) {
131 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
132 assert(!findDuplicateElement(storage) &&(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__
))
133 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 133, __extension__ __PRETTY_FUNCTION__
))
;
134 return isSorted;
135}
136
137bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
138 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
139 assert(!findDuplicateElement(array) &&(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__
))
140 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 140, __extension__ __PRETTY_FUNCTION__
))
;
141 return isSorted;
142}
143
144Optional<NamedAttribute>
145DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
146 bool isSorted) {
147 if (!isSorted)
148 dictionaryAttrSort</*inPlace=*/true>(array, array);
149 return findDuplicateElement(array);
150}
151
152DictionaryAttr DictionaryAttr::get(MLIRContext *context,
153 ArrayRef<NamedAttribute> value) {
154 if (value.empty())
155 return DictionaryAttr::getEmpty(context);
156
157 // We need to sort the element list to canonicalize it.
158 SmallVector<NamedAttribute, 8> storage;
159 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
160 value = storage;
161 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__
))
162 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 162, __extension__ __PRETTY_FUNCTION__
))
;
163 return Base::get(context, value);
164}
165/// Construct a dictionary with an array of values that is known to already be
166/// sorted by name and uniqued.
167DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
168 ArrayRef<NamedAttribute> value) {
169 if (value.empty())
170 return DictionaryAttr::getEmpty(context);
171 // Ensure that the attribute elements are unique and sorted.
172 assert(llvm::is_sorted((static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
173 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
174 "expected attribute values to be sorted")(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 174, __extension__ __PRETTY_FUNCTION__
))
;
175 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__
))
176 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 176, __extension__ __PRETTY_FUNCTION__
))
;
177 return Base::get(context, value);
178}
179
180/// Return the specified attribute if present, null otherwise.
181Attribute DictionaryAttr::get(StringRef name) const {
182 auto it = impl::findAttrSorted(begin(), end(), name);
183 return it.second ? it.first->getValue() : Attribute();
184}
185Attribute DictionaryAttr::get(StringAttr name) const {
186 auto it = impl::findAttrSorted(begin(), end(), name);
187 return it.second ? it.first->getValue() : Attribute();
188}
189
190/// Return the specified named attribute if present, None otherwise.
191Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
192 auto it = impl::findAttrSorted(begin(), end(), name);
193 return it.second ? *it.first : Optional<NamedAttribute>();
194}
195Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
196 auto it = impl::findAttrSorted(begin(), end(), name);
197 return it.second ? *it.first : Optional<NamedAttribute>();
198}
199
200/// Return whether the specified attribute is present.
201bool DictionaryAttr::contains(StringRef name) const {
202 return impl::findAttrSorted(begin(), end(), name).second;
203}
204bool DictionaryAttr::contains(StringAttr name) const {
205 return impl::findAttrSorted(begin(), end(), name).second;
206}
207
208DictionaryAttr::iterator DictionaryAttr::begin() const {
209 return getValue().begin();
210}
211DictionaryAttr::iterator DictionaryAttr::end() const {
212 return getValue().end();
213}
214size_t DictionaryAttr::size() const { return getValue().size(); }
215
216DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
217 return Base::get(context, ArrayRef<NamedAttribute>());
218}
219
220void DictionaryAttr::walkImmediateSubElements(
221 function_ref<void(Attribute)> walkAttrsFn,
222 function_ref<void(Type)> walkTypesFn) const {
223 for (const NamedAttribute &attr : getValue())
224 walkAttrsFn(attr.getValue());
225}
226
227Attribute
228DictionaryAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
229 ArrayRef<Type> replTypes) const {
230 std::vector<NamedAttribute> vec = getValue().vec();
231 for (auto &it : llvm::enumerate(replAttrs))
232 vec[it.index()].setValue(it.value());
233
234 // The above only modifies the mapped value, but not the key, and therefore
235 // not the order of the elements. It remains sorted
236 return getWithSorted(getContext(), vec);
237}
238
239//===----------------------------------------------------------------------===//
240// StridedLayoutAttr
241//===----------------------------------------------------------------------===//
242
243/// Prints a strided layout attribute.
244void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
245 auto printIntOrQuestion = [&](int64_t value) {
246 if (value == ShapedType::kDynamicStrideOrOffset)
247 os << "?";
248 else
249 os << value;
250 };
251
252 os << "strided<[";
253 llvm::interleaveComma(getStrides(), os, printIntOrQuestion);
254 os << "]";
255
256 if (getOffset() != 0) {
257 os << ", offset: ";
258 printIntOrQuestion(getOffset());
259 }
260 os << ">";
261}
262
263/// Returns the strided layout as an affine map.
264AffineMap StridedLayoutAttr::getAffineMap() const {
265 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
266}
267
268/// Checks that the type-agnostic strided layout invariants are satisfied.
269LogicalResult
270StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
271 int64_t offset, ArrayRef<int64_t> strides) {
272 if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset)
273 return emitError() << "offset must be non-negative or dynamic";
274
275 if (llvm::any_of(strides, [&](int64_t stride) {
276 return stride <= 0 && stride != ShapedType::kDynamicStrideOrOffset;
277 })) {
278 return emitError() << "strides must be positive or dynamic";
279 }
280 return success();
281}
282
283/// Checks that the type-specific strided layout invariants are satisfied.
284LogicalResult StridedLayoutAttr::verifyLayout(
285 ArrayRef<int64_t> shape,
286 function_ref<InFlightDiagnostic()> emitError) const {
287 if (shape.size() != getStrides().size())
288 return emitError() << "expected the number of strides to match the rank";
289
290 return success();
291}
292
293//===----------------------------------------------------------------------===//
294// StringAttr
295//===----------------------------------------------------------------------===//
296
297StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
298 return Base::get(context, "", NoneType::get(context));
299}
300
301/// Twine support for StringAttr.
302StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
303 // Fast-path empty twine.
304 if (twine.isTriviallyEmpty())
305 return get(context);
306 SmallVector<char, 32> tempStr;
307 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
308}
309
310/// Twine support for StringAttr.
311StringAttr StringAttr::get(const Twine &twine, Type type) {
312 SmallVector<char, 32> tempStr;
313 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
314}
315
316StringRef StringAttr::getValue() const { return getImpl()->value; }
317
318Type StringAttr::getType() const { return getImpl()->type; }
319
320Dialect *StringAttr::getReferencedDialect() const {
321 return getImpl()->referencedDialect;
322}
323
324//===----------------------------------------------------------------------===//
325// FloatAttr
326//===----------------------------------------------------------------------===//
327
328double FloatAttr::getValueAsDouble() const {
329 return getValueAsDouble(getValue());
330}
331double FloatAttr::getValueAsDouble(APFloat value) {
332 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
333 bool losesInfo = false;
334 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
335 &losesInfo);
336 }
337 return value.convertToDouble();
338}
339
340LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
341 Type type, APFloat value) {
342 // Verify that the type is correct.
343 if (!type.isa<FloatType>())
344 return emitError() << "expected floating point type";
345
346 // Verify that the type semantics match that of the value.
347 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
348 return emitError()
349 << "FloatAttr type doesn't match the type implied by its value";
350 }
351 return success();
352}
353
354//===----------------------------------------------------------------------===//
355// SymbolRefAttr
356//===----------------------------------------------------------------------===//
357
358SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
359 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
360 return get(StringAttr::get(ctx, value), nestedRefs);
361}
362
363FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
364 return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
365}
366
367FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
368 return get(value, {}).cast<FlatSymbolRefAttr>();
369}
370
371FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
372 auto symName =
373 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
374 assert(symName && "value does not have a valid symbol name")(static_cast <bool> (symName && "value does not have a valid symbol name"
) ? void (0) : __assert_fail ("symName && \"value does not have a valid symbol name\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 374, __extension__ __PRETTY_FUNCTION__
))
;
375 return SymbolRefAttr::get(symName);
376}
377
378StringAttr SymbolRefAttr::getLeafReference() const {
379 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
380 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
381}
382
383void SymbolRefAttr::walkImmediateSubElements(
384 function_ref<void(Attribute)> walkAttrsFn,
385 function_ref<void(Type)> walkTypesFn) const {
386 walkAttrsFn(getRootReference());
387 for (FlatSymbolRefAttr ref : getNestedReferences())
388 walkAttrsFn(ref);
389}
390
391Attribute
392SymbolRefAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
393 ArrayRef<Type> replTypes) const {
394 ArrayRef<Attribute> rawNestedRefs = replAttrs.drop_front();
395 ArrayRef<FlatSymbolRefAttr> nestedRefs(
396 static_cast<const FlatSymbolRefAttr *>(rawNestedRefs.data()),
397 rawNestedRefs.size());
398 return get(replAttrs[0].cast<StringAttr>(), nestedRefs);
399}
400
401//===----------------------------------------------------------------------===//
402// IntegerAttr
403//===----------------------------------------------------------------------===//
404
405int64_t IntegerAttr::getInt() const {
406 assert((getType().isIndex() || getType().isSignlessInteger()) &&(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 407, __extension__ __PRETTY_FUNCTION__
))
407 "must be signless integer")(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 407, __extension__ __PRETTY_FUNCTION__
))
;
408 return getValue().getSExtValue();
409}
410
411int64_t IntegerAttr::getSInt() const {
412 assert(getType().isSignedInteger() && "must be signed integer")(static_cast <bool> (getType().isSignedInteger() &&
"must be signed integer") ? void (0) : __assert_fail ("getType().isSignedInteger() && \"must be signed integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 412, __extension__ __PRETTY_FUNCTION__
))
;
413 return getValue().getSExtValue();
414}
415
416uint64_t IntegerAttr::getUInt() const {
417 assert(getType().isUnsignedInteger() && "must be unsigned integer")(static_cast <bool> (getType().isUnsignedInteger() &&
"must be unsigned integer") ? void (0) : __assert_fail ("getType().isUnsignedInteger() && \"must be unsigned integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 417, __extension__ __PRETTY_FUNCTION__
))
;
418 return getValue().getZExtValue();
419}
420
421/// Return the value as an APSInt which carries the signed from the type of
422/// the attribute. This traps on signless integers types!
423APSInt IntegerAttr::getAPSInt() const {
424 assert(!getType().isSignlessInteger() &&(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 425, __extension__ __PRETTY_FUNCTION__
))
425 "Signless integers don't carry a sign for APSInt")(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 425, __extension__ __PRETTY_FUNCTION__
))
;
426 return APSInt(getValue(), getType().isUnsignedInteger());
427}
428
429LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
430 Type type, APInt value) {
431 if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
432 if (integerType.getWidth() != value.getBitWidth())
433 return emitError() << "integer type bit width (" << integerType.getWidth()
434 << ") doesn't match value bit width ("
435 << value.getBitWidth() << ")";
436 return success();
437 }
438 if (type.isa<IndexType>()) {
439 if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
440 return emitError()
441 << "value bit width (" << value.getBitWidth()
442 << ") doesn't match index type internal storage bit width ("
443 << IndexType::kInternalStorageBitWidth << ")";
444 return success();
445 }
446 return emitError() << "expected integer or index type";
447}
448
449BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
450 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
451 return attr.cast<BoolAttr>();
452}
453
454//===----------------------------------------------------------------------===//
455// BoolAttr
456//===----------------------------------------------------------------------===//
457
458bool BoolAttr::getValue() const {
459 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
460 return storage->value.getBoolValue();
461}
462
463bool BoolAttr::classof(Attribute attr) {
464 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
465 return intAttr && intAttr.getType().isSignlessInteger(1);
466}
467
468//===----------------------------------------------------------------------===//
469// OpaqueAttr
470//===----------------------------------------------------------------------===//
471
472LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
473 StringAttr dialect, StringRef attrData,
474 Type type) {
475 if (!Dialect::isValidNamespace(dialect.strref()))
476 return emitError() << "invalid dialect namespace '" << dialect << "'";
477
478 // Check that the dialect is actually registered.
479 MLIRContext *context = dialect.getContext();
480 if (!context->allowsUnregisteredDialects() &&
481 !context->getLoadedDialect(dialect.strref())) {
482 return emitError()
483 << "#" << dialect << "<\"" << attrData << "\"> : " << type
484 << " attribute created with unregistered dialect. If this is "
485 "intended, please call allowUnregisteredDialects() on the "
486 "MLIRContext, or use -allow-unregistered-dialect with "
487 "the MLIR opt tool used";
488 }
489
490 return success();
491}
492
493//===----------------------------------------------------------------------===//
494// DenseElementsAttr Utilities
495//===----------------------------------------------------------------------===//
496
497/// Get the bitwidth of a dense element type within the buffer.
498/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
499static size_t getDenseElementStorageWidth(size_t origWidth) {
500 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
501}
502static size_t getDenseElementStorageWidth(Type elementType) {
503 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
504}
505
/// Set bit `bitPos` of the bit-packed buffer `rawData` to `value`. Bit `i`
/// lives in byte `i / CHAR_BIT`, at position `i % CHAR_BIT` counted from the
/// least significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}
513
/// Return whether bit `bitPos` of the bit-packed buffer `rawData` is set.
/// Uses the same layout as setBit: byte `bitPos / CHAR_BIT`, bit
/// `bitPos % CHAR_BIT` counted from the least significant bit.
static bool getBit(const char *rawData, size_t bitPos) {
  const unsigned char byte =
      static_cast<unsigned char>(rawData[bitPos / CHAR_BIT]);
  const unsigned shift = static_cast<unsigned>(bitPos % CHAR_BIT);
  return ((byte >> shift) & 1u) != 0;
}
518
519/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
520/// BE format.
521static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
522 char *result) {
523 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 524, __extension__ __PRETTY_FUNCTION__
))
524 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 524, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
525 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (value.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes) ? void (0) : __assert_fail ("value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes"
, "mlir/lib/IR/BuiltinAttributes.cpp", 525, __extension__ __PRETTY_FUNCTION__
))
;
526
527 // Copy the words filled with data.
528 // For example, when `value` has 2 words, the first word is filled with data.
529 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
530 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
531 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
532 numFilledWords, result);
533 // Convert last word of APInt to LE format and store it in char
534 // array(`valueLE`).
535 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
536 size_t lastWordPos = numFilledWords;
537 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
538 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
539 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
540 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
541 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
542 // and store it in `result`.
543 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
544 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
545 valueLE.begin(), result + lastWordPos,
546 (numBytes - lastWordPos) * CHAR_BIT8, 1);
547}
548
549/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
550/// format.
551static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
552 APInt &result) {
553 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 554, __extension__ __PRETTY_FUNCTION__
))
554 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 554, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
555 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (result.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes) ? void (0) : __assert_fail ("result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes"
, "mlir/lib/IR/BuiltinAttributes.cpp", 555, __extension__ __PRETTY_FUNCTION__
))
;
556
557 // Copy the data that fills the word of `result` from `inArray`.
558 // For example, when `result` has 2 words, the first word will be filled with
559 // data. So, the first 8 bytes are copied from `inArray` here.
560 // `inArray` (10 bytes, BE): |abcdefgh|ij|
561 // ==> `result` (2 words, BE): |abcdefgh|--------|
562 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
563 std::copy_n(
564 inArray, numFilledWords,
565 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
566
567 // Convert array data which will be last word of `result` to LE format, and
568 // store it in char array(`inArrayLE`).
569 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
570 size_t lastWordPos = numFilledWords;
571 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
572 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
573 inArray + lastWordPos, inArrayLE.begin(),
574 (numBytes - lastWordPos) * CHAR_BIT8, 1);
575
576 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
577 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
578 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
579 inArrayLE.begin(),
580 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
581 lastWordPos,
582 APInt::APINT_BITS_PER_WORD, 1);
583}
584
585/// Writes value to the bit position `bitPos` in array `rawData`.
586static void writeBits(char *rawData, size_t bitPos, APInt value) {
587 size_t bitWidth = value.getBitWidth();
588
589 // If the bitwidth is 1 we just toggle the specific bit.
590 if (bitWidth == 1)
591 return setBit(rawData, bitPos, value.isOneValue());
592
593 // Otherwise, the bit position is guaranteed to be byte aligned.
594 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 594, __extension__ __PRETTY_FUNCTION__
))
;
595 if (llvm::support::endian::system_endianness() ==
596 llvm::support::endianness::big) {
597 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
598 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
599 // work correctly in BE format.
600 // ex. `value` (2 words including 10 bytes)
601 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
602 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT8),
603 rawData + (bitPos / CHAR_BIT8));
604 } else {
605 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
606 llvm::divideCeil(bitWidth, CHAR_BIT8),
607 rawData + (bitPos / CHAR_BIT8));
608 }
609}
610
611/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
612/// `rawData`.
613static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
614 // Handle a boolean bit position.
615 if (bitWidth == 1)
616 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
617
618 // Otherwise, the bit position must be 8-bit aligned.
619 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 619, __extension__ __PRETTY_FUNCTION__
))
;
620 APInt result(bitWidth, 0);
621 if (llvm::support::endian::system_endianness() ==
622 llvm::support::endianness::big) {
623 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
624 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
625 // work correctly in BE format.
626 // ex. `result` (2 words including 10 bytes)
627 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
628 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT8),
629 llvm::divideCeil(bitWidth, CHAR_BIT8), result);
630 } else {
631 std::copy_n(rawData + (bitPos / CHAR_BIT8),
632 llvm::divideCeil(bitWidth, CHAR_BIT8),
633 const_cast<char *>(
634 reinterpret_cast<const char *>(result.getRawData())));
635 }
636 return result;
637}
638
639/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
640/// the same element count as 'type'.
641template <typename Values>
642static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
643 return (values.size() == 1) ||
644 (type.getNumElements() == static_cast<int64_t>(values.size()));
645}
646
647//===----------------------------------------------------------------------===//
648// DenseElementsAttr Iterators
649//===----------------------------------------------------------------------===//
650
651//===----------------------------------------------------------------------===//
652// AttributeElementIterator
653
654DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
655 DenseElementsAttr attr, size_t index)
656 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
657 Attribute, Attribute, Attribute>(
658 attr.getAsOpaquePointer(), index) {}
659
660Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
661 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
662 Type eltTy = owner.getElementType();
663 if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
664 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
665 if (eltTy.isa<IndexType>())
666 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
667 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
668 IntElementIterator intIt(owner, index);
669 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
670 return FloatAttr::get(eltTy, *floatIt);
671 }
672 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
673 auto complexEltTy = complexTy.getElementType();
674 ComplexIntElementIterator complexIntIt(owner, index);
675 if (complexEltTy.isa<IntegerType>()) {
676 auto value = *complexIntIt;
677 auto real = IntegerAttr::get(complexEltTy, value.real());
678 auto imag = IntegerAttr::get(complexEltTy, value.imag());
679 return ArrayAttr::get(complexTy.getContext(),
680 ArrayRef<Attribute>{real, imag});
681 }
682
683 ComplexFloatElementIterator complexFloatIt(
684 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
685 auto value = *complexFloatIt;
686 auto real = FloatAttr::get(complexEltTy, value.real());
687 auto imag = FloatAttr::get(complexEltTy, value.imag());
688 return ArrayAttr::get(complexTy.getContext(),
689 ArrayRef<Attribute>{real, imag});
690 }
691 if (owner.isa<DenseStringElementsAttr>()) {
692 ArrayRef<StringRef> vals = owner.getRawStringData();
693 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
694 }
695 llvm_unreachable("unexpected element type")::llvm::llvm_unreachable_internal("unexpected element type", "mlir/lib/IR/BuiltinAttributes.cpp"
, 695)
;
696}
697
698//===----------------------------------------------------------------------===//
699// BoolElementIterator
700
701DenseElementsAttr::BoolElementIterator::BoolElementIterator(
702 DenseElementsAttr attr, size_t dataIndex)
703 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
704 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
705
706bool DenseElementsAttr::BoolElementIterator::operator*() const {
707 return getBit(getData(), getDataIndex());
708}
709
710//===----------------------------------------------------------------------===//
711// IntElementIterator
712
713DenseElementsAttr::IntElementIterator::IntElementIterator(
714 DenseElementsAttr attr, size_t dataIndex)
715 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
716 attr.getRawData().data(), attr.isSplat(), dataIndex),
717 bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
718
719APInt DenseElementsAttr::IntElementIterator::operator*() const {
720 return readBits(getData(),
721 getDataIndex() * getDenseElementStorageWidth(bitWidth),
722 bitWidth);
723}
724
725//===----------------------------------------------------------------------===//
726// ComplexIntElementIterator
727
728DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
729 DenseElementsAttr attr, size_t dataIndex)
730 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
731 std::complex<APInt>, std::complex<APInt>,
732 std::complex<APInt>>(
733 attr.getRawData().data(), attr.isSplat(), dataIndex) {
734 auto complexType = attr.getElementType().cast<ComplexType>();
735 bitWidth = getDenseElementBitWidth(complexType.getElementType());
736}
737
738std::complex<APInt>
739DenseElementsAttr::ComplexIntElementIterator::operator*() const {
740 size_t storageWidth = getDenseElementStorageWidth(bitWidth);
741 size_t offset = getDataIndex() * storageWidth * 2;
742 return {readBits(getData(), offset, bitWidth),
743 readBits(getData(), offset + storageWidth, bitWidth)};
744}
745
746//===----------------------------------------------------------------------===//
747// DenseArrayAttr
748//===----------------------------------------------------------------------===//
749
750LogicalResult
751DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
752 RankedTensorType type, ArrayRef<char> rawData) {
753 if (type.getRank() != 1)
754 return emitError() << "expected rank 1 tensor type";
755 if (!type.getElementType().isIntOrIndexOrFloat())
756 return emitError() << "expected integer or floating point element type";
757 int64_t dataSize = rawData.size();
758 int64_t size = type.getShape().front();
759 if (type.getElementType().isInteger(1)) {
760 if (size != dataSize)
761 return emitError() << "expected " << size
762 << " bytes for i1 array but got " << dataSize;
763 } else if (size * type.getElementTypeBitWidth() != dataSize * 8) {
764 return emitError() << "expected data size (" << size << " elements, "
765 << type.getElementTypeBitWidth()
766 << " bits each) does not match: " << dataSize
767 << " bytes";
768 }
769 return success();
770}
771
772FailureOr<const bool *>
773DenseArrayAttr::try_value_begin_impl(OverloadToken<bool>) const {
774 if (auto attr = dyn_cast<DenseBoolArrayAttr>())
775 return attr.asArrayRef().begin();
776 return failure();
777}
778FailureOr<const int8_t *>
779DenseArrayAttr::try_value_begin_impl(OverloadToken<int8_t>) const {
780 if (auto attr = dyn_cast<DenseI8ArrayAttr>())
781 return attr.asArrayRef().begin();
782 return failure();
783}
784FailureOr<const int16_t *>
785DenseArrayAttr::try_value_begin_impl(OverloadToken<int16_t>) const {
786 if (auto attr = dyn_cast<DenseI16ArrayAttr>())
787 return attr.asArrayRef().begin();
788 return failure();
789}
790FailureOr<const int32_t *>
791DenseArrayAttr::try_value_begin_impl(OverloadToken<int32_t>) const {
792 if (auto attr = dyn_cast<DenseI32ArrayAttr>())
793 return attr.asArrayRef().begin();
794 return failure();
795}
796FailureOr<const int64_t *>
797DenseArrayAttr::try_value_begin_impl(OverloadToken<int64_t>) const {
798 if (auto attr = dyn_cast<DenseI64ArrayAttr>())
799 return attr.asArrayRef().begin();
800 return failure();
801}
802FailureOr<const float *>
803DenseArrayAttr::try_value_begin_impl(OverloadToken<float>) const {
804 if (auto attr = dyn_cast<DenseF32ArrayAttr>())
805 return attr.asArrayRef().begin();
806 return failure();
807}
808FailureOr<const double *>
809DenseArrayAttr::try_value_begin_impl(OverloadToken<double>) const {
810 if (auto attr = dyn_cast<DenseF64ArrayAttr>())
811 return attr.asArrayRef().begin();
812 return failure();
813}
814
815namespace {
816/// Instantiations of this class provide utilities for interacting with native
817/// data types in the context of DenseArrayAttr.
818template <size_t width,
819 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
820struct DenseArrayAttrIntUtil {
821 static bool checkElementType(Type eltType) {
822 auto type = eltType.dyn_cast<IntegerType>();
823 if (!type || type.getWidth() != width)
824 return false;
825 return type.getSignedness() == signedness;
826 }
827
828 static Type getElementType(MLIRContext *ctx) {
829 return IntegerType::get(ctx, width, signedness);
830 }
831
832 template <typename T>
833 static void printElement(raw_ostream &os, T value) {
834 os << value;
835 }
836
837 template <typename T>
838 static ParseResult parseElement(AsmParser &parser, T &value) {
839 return parser.parseInteger(value);
3
Calling 'AsmParser::parseInteger'
11
Returning from 'AsmParser::parseInteger'
12
Returning without writing to 'value'
840 }
841};
842template <typename T>
843struct DenseArrayAttrUtil;
844
845/// Specialization for boolean elements to print 'true' and 'false' literals for
846/// elements.
847template <>
848struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
849 static void printElement(raw_ostream &os, bool value) {
850 os << (value ? "true" : "false");
851 }
852};
853
854/// Specialization for 8-bit integers to ensure values are printed as integers
855/// and not characters.
856template <>
857struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
858 static void printElement(raw_ostream &os, int8_t value) {
859 os << static_cast<int>(value);
860 }
861};
862template <>
863struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
864template <>
865struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
866template <>
867struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
868
869/// Specialization for 32-bit floats.
870template <>
871struct DenseArrayAttrUtil<float> {
872 static bool checkElementType(Type eltType) { return eltType.isF32(); }
873 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
874 static void printElement(raw_ostream &os, float value) { os << value; }
875
876 /// Parse a double and cast it to a float.
877 static ParseResult parseElement(AsmParser &parser, float &value) {
878 double doubleVal;
879 if (parser.parseFloat(doubleVal))
880 return failure();
881 value = doubleVal;
882 return success();
883 }
884};
885
886/// Specialization for 64-bit floats.
887template <>
888struct DenseArrayAttrUtil<double> {
889 static bool checkElementType(Type eltType) { return eltType.isF64(); }
890 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
891 static void printElement(raw_ostream &os, float value) { os << value; }
892 static ParseResult parseElement(AsmParser &parser, double &value) {
893 return parser.parseFloat(value);
894 }
895};
896} // namespace
897
898template <typename T>
899void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
900 print(printer.getStream());
901}
902
903template <typename T>
904void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
905 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
906 DenseArrayAttrUtil<T>::printElement(os, value);
907 });
908}
909
910template <typename T>
911void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
912 os << "[";
913 printWithoutBraces(os);
914 os << "]";
915}
916
917/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
918template <typename T>
919Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
920 Type odsType) {
921 SmallVector<T> data;
922 if (failed(parser.parseCommaSeparatedList([&]() {
923 T value;
1
'value' declared without an initial value
924 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
2
Calling 'DenseArrayAttrIntUtil::parseElement'
13
Returning from 'DenseArrayAttrIntUtil::parseElement'
14
Taking false branch
925 return failure();
926 data.push_back(value);
15
1st function call argument is an uninitialized value
927 return success();
928 })))
929 return {};
930 return get(parser.getContext(), data);
931}
932
933/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
934template <typename T>
935Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
936 if (parser.parseLSquare())
937 return {};
938 // Handle empty list case.
939 if (succeeded(parser.parseOptionalRSquare()))
940 return get(parser.getContext(), {});
941 Attribute result = parseWithoutBraces(parser, odsType);
942 if (parser.parseRSquare())
943 return {};
944 return result;
945}
946
947/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
948template <typename T>
949DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
950 ArrayRef<char> raw = getRawData();
951 assert((raw.size() % sizeof(T)) == 0)(static_cast <bool> ((raw.size() % sizeof(T)) == 0) ? void
(0) : __assert_fail ("(raw.size() % sizeof(T)) == 0", "mlir/lib/IR/BuiltinAttributes.cpp"
, 951, __extension__ __PRETTY_FUNCTION__))
;
952 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
953 raw.size() / sizeof(T));
954}
955
956/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
957template <typename T>
958DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
959 ArrayRef<T> content) {
960 auto shapedType = RankedTensorType::get(
961 content.size(), DenseArrayAttrUtil<T>::getElementType(context));
962 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
963 content.size() * sizeof(T));
964 return Base::get(context, shapedType, rawArray)
965 .template cast<DenseArrayAttrImpl<T>>();
966}
967
968template <typename T>
969bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
970 if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
971 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
972 return false;
973}
974
975namespace mlir {
976namespace detail {
977// Explicit instantiation for all the supported DenseArrayAttr.
978template class DenseArrayAttrImpl<bool>;
979template class DenseArrayAttrImpl<int8_t>;
980template class DenseArrayAttrImpl<int16_t>;
981template class DenseArrayAttrImpl<int32_t>;
982template class DenseArrayAttrImpl<int64_t>;
983template class DenseArrayAttrImpl<float>;
984template class DenseArrayAttrImpl<double>;
985} // namespace detail
986} // namespace mlir
987
988//===----------------------------------------------------------------------===//
989// DenseElementsAttr
990//===----------------------------------------------------------------------===//
991
992/// Method for support type inquiry through isa, cast and dyn_cast.
993bool DenseElementsAttr::classof(Attribute attr) {
994 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
995}
996
997DenseElementsAttr DenseElementsAttr::get(ShapedType type,
998 ArrayRef<Attribute> values) {
999 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 999, __extension__ __PRETTY_FUNCTION__
))
;
1000
1001 // If the element type is not based on int/float/index, assume it is a string
1002 // type.
1003 Type eltType = type.getElementType();
1004 if (!eltType.isIntOrIndexOrFloat()) {
1005 SmallVector<StringRef, 8> stringValues;
1006 stringValues.reserve(values.size());
1007 for (Attribute attr : values) {
1008 assert(attr.isa<StringAttr>() &&(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1009, __extension__ __PRETTY_FUNCTION__
))
1009 "expected string value for non integer/index/float element")(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1009, __extension__ __PRETTY_FUNCTION__
))
;
1010 stringValues.push_back(attr.cast<StringAttr>().getValue());
1011 }
1012 return get(type, stringValues);
1013 }
1014
1015 // Otherwise, get the raw storage width to use for the allocation.
1016 size_t bitWidth = getDenseElementBitWidth(eltType);
1017 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1018
1019 // Compress the attribute values into a character buffer.
1020 SmallVector<char, 8> data(
1021 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT8));
1022 APInt intVal;
1023 for (unsigned i = 0, e = values.size(); i < e; ++i) {
1024 if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
1025 assert(floatAttr.getType() == eltType &&(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1026, __extension__ __PRETTY_FUNCTION__
))
1026 "expected float attribute type to equal element type")(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1026, __extension__ __PRETTY_FUNCTION__
))
;
1027 intVal = floatAttr.getValue().bitcastToAPInt();
1028 } else {
1029 auto intAttr = values[i].cast<IntegerAttr>();
1030 assert(intAttr.getType() == eltType &&(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1031, __extension__ __PRETTY_FUNCTION__
))
1031 "expected integer attribute type to equal element type")(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1031, __extension__ __PRETTY_FUNCTION__
))
;
1032 intVal = intAttr.getValue();
1033 }
1034
1035 assert(intVal.getBitWidth() == bitWidth &&(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1036, __extension__ __PRETTY_FUNCTION__
))
1036 "expected value to have same bitwidth as element type")(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1036, __extension__ __PRETTY_FUNCTION__
))
;
1037 writeBits(data.data(), i * storageBitWidth, intVal);
1038 }
1039
1040 // Handle the special encoding of splat of bool.
1041 if (values.size() == 1 && eltType.isInteger(1))
1042 data[0] = data[0] ? -1 : 0;
1043
1044 return DenseIntOrFPElementsAttr::getRaw(type, data);
1045}
1046
1047DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1048 ArrayRef<bool> values) {
1049 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1049, __extension__ __PRETTY_FUNCTION__
))
;
1050 assert(type.getElementType().isInteger(1))(static_cast <bool> (type.getElementType().isInteger(1)
) ? void (0) : __assert_fail ("type.getElementType().isInteger(1)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1050, __extension__ __PRETTY_FUNCTION__
))
;
1051
1052 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT8));
1053
1054 if (!values.empty()) {
1055 bool isSplat = true;
1056 bool firstValue = values[0];
1057 for (int i = 0, e = values.size(); i != e; ++i) {
1058 isSplat &= values[i] == firstValue;
1059 setBit(buff.data(), i, values[i]);
1060 }
1061
1062 // Splat of bool is encoded as a byte with all-ones in it.
1063 if (isSplat) {
1064 buff.resize(1);
1065 buff[0] = values[0] ? -1 : 0;
1066 }
1067 }
1068
1069 return DenseIntOrFPElementsAttr::getRaw(type, buff);
1070}
1071
1072DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1073 ArrayRef<StringRef> values) {
1074 assert(!type.getElementType().isIntOrFloat())(static_cast <bool> (!type.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("!type.getElementType().isIntOrFloat()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1074, __extension__ __PRETTY_FUNCTION__
))
;
1075 return DenseStringElementsAttr::get(type, values);
1076}
1077
1078/// Constructs a dense integer elements attribute from an array of APInt
1079/// values. Each APInt value is expected to have the same bitwidth as the
1080/// element type of 'type'.
1081DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1082 ArrayRef<APInt> values) {
1083 assert(type.getElementType().isIntOrIndex())(static_cast <bool> (type.getElementType().isIntOrIndex
()) ? void (0) : __assert_fail ("type.getElementType().isIntOrIndex()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1083, __extension__ __PRETTY_FUNCTION__
))
;
1084 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1084, __extension__ __PRETTY_FUNCTION__
))
;
1085 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1086 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1087}
1088DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1089 ArrayRef<std::complex<APInt>> values) {
1090 ComplexType complex = type.getElementType().cast<ComplexType>();
1091 assert(complex.getElementType().isa<IntegerType>())(static_cast <bool> (complex.getElementType().isa<IntegerType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<IntegerType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1091, __extension__ __PRETTY_FUNCTION__
))
;
1092 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1092, __extension__ __PRETTY_FUNCTION__
))
;
1093 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1094 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
1095 values.size() * 2);
1096 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
1097}
1098
1099// Constructs a dense float elements attribute from an array of APFloat
1100// values. Each APFloat value is expected to have the same bitwidth as the
1101// element type of 'type'.
1102DenseElementsAttr DenseElementsAttr::get(ShapedType type,
1103 ArrayRef<APFloat> values) {
1104 assert(type.getElementType().isa<FloatType>())(static_cast <bool> (type.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("type.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1104, __extension__ __PRETTY_FUNCTION__
))
;
1105 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1105, __extension__ __PRETTY_FUNCTION__
))
;
1106 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
1107 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1108}
1109DenseElementsAttr
1110DenseElementsAttr::get(ShapedType type,
1111 ArrayRef<std::complex<APFloat>> values) {
1112 ComplexType complex = type.getElementType().cast<ComplexType>();
1113 assert(complex.getElementType().isa<FloatType>())(static_cast <bool> (complex.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1113, __extension__ __PRETTY_FUNCTION__
))
;
1114 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1114, __extension__ __PRETTY_FUNCTION__
))
;
1115 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1116 values.size() * 2);
1117 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1118 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1119}
1120
1121/// Construct a dense elements attribute from a raw buffer representing the
1122/// data for this attribute. Users should generally not use this methods as
1123/// the expected buffer format may not be a form the user expects.
1124DenseElementsAttr
1125DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1126 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1127}
1128
1129/// Returns true if the given buffer is a valid raw buffer for the given type.
1130bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1131 ArrayRef<char> rawBuffer,
1132 bool &detectedSplat) {
1133 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1134 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT8;
1135 int64_t numElements = type.getNumElements();
1136
1137 // The initializer is always a splat if the result type has a single element.
1138 detectedSplat = numElements == 1;
1139
1140 // Storage width of 1 is special as it is packed by the bit.
1141 if (storageWidth == 1) {
1142 // Check for a splat, or a buffer equal to the number of elements which
1143 // consists of either all 0's or all 1's.
1144 if (rawBuffer.size() == 1) {
1145 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1146 if (rawByte == 0 || rawByte == 0xff) {
1147 detectedSplat = true;
1148 return true;
1149 }
1150 }
1151
1152 // This is a valid non-splat buffer if it has the right size.
1153 return rawBufferWidth == llvm::alignTo<8>(numElements);
1154 }
1155
1156 // All other types are 8-bit aligned, so we can just check the buffer width
1157 // to know if only a single initializer element was passed in.
1158 if (rawBufferWidth == storageWidth) {
1159 detectedSplat = true;
1160 return true;
1161 }
1162
1163 // The raw buffer is valid if it has the right size.
1164 return rawBufferWidth == storageWidth * numElements;
1165}
1166
1167/// Check the information for a C++ data type, check if this type is valid for
1168/// the current attribute. This method is used to verify specific type
1169/// invariants that the templatized 'getValues' method cannot.
1170static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1171 bool isSigned) {
1172 // Make sure that the data element size is the same as the type element width.
1173 if (getDenseElementBitWidth(type) !=
1174 static_cast<size_t>(dataEltSize * CHAR_BIT8))
1175 return false;
1176
1177 // Check that the element type is either float or integer or index.
1178 if (!isInt)
1179 return type.isa<FloatType>();
1180 if (type.isIndex())
1181 return true;
1182
1183 auto intType = type.dyn_cast<IntegerType>();
1184 if (!intType)
1185 return false;
1186
1187 // Make sure signedness semantics is consistent.
1188 if (intType.isSignless())
1189 return true;
1190 return intType.isSigned() ? isSigned : !isSigned;
1191}
1192
1193/// Defaults down the subclass implementation.
1194DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1195 ArrayRef<char> data,
1196 int64_t dataEltSize,
1197 bool isInt, bool isSigned) {
1198 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1199 isSigned);
1200}
1201DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1202 ArrayRef<char> data,
1203 int64_t dataEltSize,
1204 bool isInt,
1205 bool isSigned) {
1206 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1207 isInt, isSigned);
1208}
1209
1210bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1211 bool isSigned) const {
1212 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1213}
1214bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1215 bool isSigned) const {
1216 return ::isValidIntOrFloat(
1217 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1218 isInt, isSigned);
1219}
1220
1221/// Returns true if this attribute corresponds to a splat, i.e. if all element
1222/// values are the same.
1223bool DenseElementsAttr::isSplat() const {
1224 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1225}
1226
1227/// Return if the given complex type has an integer element type.
1228static bool isComplexOfIntType(Type type) {
1229 return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1230}
1231
1232auto DenseElementsAttr::tryGetComplexIntValues() const
1233 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1234 if (!isComplexOfIntType(getElementType()))
1235 return failure();
1236 return iterator_range_impl<ComplexIntElementIterator>(
1237 getType(), ComplexIntElementIterator(*this, 0),
1238 ComplexIntElementIterator(*this, getNumElements()));
1239}
1240
1241auto DenseElementsAttr::tryGetFloatValues() const
1242 -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1243 auto eltTy = getElementType().dyn_cast<FloatType>();
1244 if (!eltTy)
1245 return failure();
1246 const auto &elementSemantics = eltTy.getFloatSemantics();
1247 return iterator_range_impl<FloatElementIterator>(
1248 getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1249 FloatElementIterator(elementSemantics, raw_int_end()));
1250}
1251
1252auto DenseElementsAttr::tryGetComplexFloatValues() const
1253 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1254 auto complexTy = getElementType().dyn_cast<ComplexType>();
1255 if (!complexTy)
1256 return failure();
1257 auto eltTy = complexTy.getElementType().dyn_cast<FloatType>();
1258 if (!eltTy)
1259 return failure();
1260 const auto &semantics = eltTy.getFloatSemantics();
1261 return iterator_range_impl<ComplexFloatElementIterator>(
1262 getType(), {semantics, {*this, 0}},
1263 {semantics, {*this, static_cast<size_t>(getNumElements())}});
1264}
1265
1266/// Return the raw storage data held by this attribute.
1267ArrayRef<char> DenseElementsAttr::getRawData() const {
1268 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1269}
1270
1271ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1272 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1273}
1274
1275/// Return a new DenseElementsAttr that has the same data as the current
1276/// attribute, but has been reshaped to 'newType'. The new type must have the
1277/// same total number of elements as well as element type.
1278DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1279 ShapedType curType = getType();
1280 if (curType == newType)
1281 return *this;
1282
1283 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1284, __extension__ __PRETTY_FUNCTION__
))
1284 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1284, __extension__ __PRETTY_FUNCTION__
))
;
1285 assert(newType.getNumElements() == curType.getNumElements() &&(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1286, __extension__ __PRETTY_FUNCTION__
))
1286 "expected the same number of elements")(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1286, __extension__ __PRETTY_FUNCTION__
))
;
1287 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1288}
1289
1290DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1291 assert(isSplat() && "expected a splat type")(static_cast <bool> (isSplat() && "expected a splat type"
) ? void (0) : __assert_fail ("isSplat() && \"expected a splat type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1291, __extension__ __PRETTY_FUNCTION__
))
;
1292
1293 ShapedType curType = getType();
1294 if (curType == newType)
1295 return *this;
1296
1297 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1298, __extension__ __PRETTY_FUNCTION__
))
1298 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1298, __extension__ __PRETTY_FUNCTION__
))
;
1299 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1300}
1301
1302/// Return a new DenseElementsAttr that has the same data as the current
1303/// attribute, but has bitcast elements such that it is now 'newType'. The new
1304/// type must have the same shape and element types of the same bitwidth as the
1305/// current type.
1306DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1307 ShapedType curType = getType();
1308 Type curElType = curType.getElementType();
1309 if (curElType == newElType)
1310 return *this;
1311
1312 assert(getDenseElementBitWidth(newElType) ==(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1314, __extension__ __PRETTY_FUNCTION__
))
1313 getDenseElementBitWidth(curElType) &&(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1314, __extension__ __PRETTY_FUNCTION__
))
1314 "expected element types with the same bitwidth")(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1314, __extension__ __PRETTY_FUNCTION__
))
;
1315 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1316 getRawData());
1317}
1318
1319DenseElementsAttr
1320DenseElementsAttr::mapValues(Type newElementType,
1321 function_ref<APInt(const APInt &)> mapping) const {
1322 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1323}
1324
1325DenseElementsAttr DenseElementsAttr::mapValues(
1326 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1327 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1328}
1329
1330ShapedType DenseElementsAttr::getType() const {
1331 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1332}
1333
1334Type DenseElementsAttr::getElementType() const {
1335 return getType().getElementType();
1336}
1337
1338int64_t DenseElementsAttr::getNumElements() const {
1339 return getType().getNumElements();
1340}
1341
1342//===----------------------------------------------------------------------===//
1343// DenseIntOrFPElementsAttr
1344//===----------------------------------------------------------------------===//
1345
1346/// Utility method to write a range of APInt values to a buffer.
1347template <typename APRangeT>
1348static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1349 APRangeT &&values) {
1350 size_t numValues = llvm::size(values);
1351 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT8));
1352 size_t offset = 0;
1353 for (auto it = values.begin(), e = values.end(); it != e;
1354 ++it, offset += storageWidth) {
1355 assert((*it).getBitWidth() <= storageWidth)(static_cast <bool> ((*it).getBitWidth() <= storageWidth
) ? void (0) : __assert_fail ("(*it).getBitWidth() <= storageWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1355, __extension__ __PRETTY_FUNCTION__
))
;
1356 writeBits(data.data(), offset, *it);
1357 }
1358
1359 // Handle the special encoding of splat of a boolean.
1360 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1361 data[0] = data[0] ? -1 : 0;
1362}
1363
1364/// Constructs a dense elements attribute from an array of raw APFloat values.
1365/// Each APFloat value is expected to have the same bitwidth as the element
1366/// type of 'type'. 'type' must be a vector or tensor with static shape.
1367DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1368 size_t storageWidth,
1369 ArrayRef<APFloat> values) {
1370 std::vector<char> data;
1371 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1372 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1373 return DenseIntOrFPElementsAttr::getRaw(type, data);
1374}
1375
1376/// Constructs a dense elements attribute from an array of raw APInt values.
1377/// Each APInt value is expected to have the same bitwidth as the element type
1378/// of 'type'.
1379DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1380 size_t storageWidth,
1381 ArrayRef<APInt> values) {
1382 std::vector<char> data;
1383 writeAPIntsToBuffer(storageWidth, data, values);
1384 return DenseIntOrFPElementsAttr::getRaw(type, data);
1385}
1386
1387DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1388 ArrayRef<char> data) {
1389 assert((type.isa<RankedTensorType, VectorType>()) &&(static_cast <bool> ((type.isa<RankedTensorType, VectorType
>()) && "type must be ranked tensor or vector") ? void
(0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1390, __extension__ __PRETTY_FUNCTION__
))
1390 "type must be ranked tensor or vector")(static_cast <bool> ((type.isa<RankedTensorType, VectorType
>()) && "type must be ranked tensor or vector") ? void
(0) : __assert_fail ("(type.isa<RankedTensorType, VectorType>()) && \"type must be ranked tensor or vector\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1390, __extension__ __PRETTY_FUNCTION__
))
;
1391 assert(type.hasStaticShape() && "type must have static shape")(static_cast <bool> (type.hasStaticShape() && "type must have static shape"
) ? void (0) : __assert_fail ("type.hasStaticShape() && \"type must have static shape\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1391, __extension__ __PRETTY_FUNCTION__
))
;
1392 bool isSplat = false;
1393 bool isValid = isValidRawBuffer(type, data, isSplat);
1394 assert(isValid)(static_cast <bool> (isValid) ? void (0) : __assert_fail
("isValid", "mlir/lib/IR/BuiltinAttributes.cpp", 1394, __extension__
__PRETTY_FUNCTION__))
;
1395 (void)isValid;
1396 return Base::get(type.getContext(), type, data, isSplat);
1397}
1398
1399/// Overload of the raw 'get' method that asserts that the given type is of
1400/// complex type. This method is used to verify type invariants that the
1401/// templatized 'get' method cannot.
1402DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1403 ArrayRef<char> data,
1404 int64_t dataEltSize,
1405 bool isInt,
1406 bool isSigned) {
1407 assert(::isValidIntOrFloat((static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1409, __extension__ __PRETTY_FUNCTION__
))
1408 type.getElementType().cast<ComplexType>().getElementType(),(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1409, __extension__ __PRETTY_FUNCTION__
))
1409 dataEltSize / 2, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1409, __extension__ __PRETTY_FUNCTION__
))
;
1410
1411 int64_t numElements = data.size() / dataEltSize;
1412 (void)numElements;
1413 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1413, __extension__ __PRETTY_FUNCTION__
))
;
1414 return getRaw(type, data);
1415}
1416
1417/// Overload of the 'getRaw' method that asserts that the given type is of
1418/// integer type. This method is used to verify type invariants that the
1419/// templatized 'get' method cannot.
1420DenseElementsAttr
1421DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1422 int64_t dataEltSize, bool isInt,
1423 bool isSigned) {
1424 assert((static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1425, __extension__ __PRETTY_FUNCTION__
))
1425 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1425, __extension__ __PRETTY_FUNCTION__
))
;
1426
1427 int64_t numElements = data.size() / dataEltSize;
1428 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1428, __extension__ __PRETTY_FUNCTION__
))
;
1429 (void)numElements;
1430 return getRaw(type, data);
1431}
1432
1433void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1434 const char *inRawData, char *outRawData, size_t elementBitWidth,
1435 size_t numElements) {
1436 using llvm::support::ulittle16_t;
1437 using llvm::support::ulittle32_t;
1438 using llvm::support::ulittle64_t;
1439
1440 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1441, __extension__ __PRETTY_FUNCTION__
))
1441 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1441, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
1442 // NOLINT to avoid warning message about replacing by static_assert()
1443
1444 // Following std::copy_n always converts endianness on BE machine.
1445 switch (elementBitWidth) {
1446 case 16: {
1447 const ulittle16_t *inRawDataPos =
1448 reinterpret_cast<const ulittle16_t *>(inRawData);
1449 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1450 std::copy_n(inRawDataPos, numElements, outDataPos);
1451 break;
1452 }
1453 case 32: {
1454 const ulittle32_t *inRawDataPos =
1455 reinterpret_cast<const ulittle32_t *>(inRawData);
1456 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1457 std::copy_n(inRawDataPos, numElements, outDataPos);
1458 break;
1459 }
1460 case 64: {
1461 const ulittle64_t *inRawDataPos =
1462 reinterpret_cast<const ulittle64_t *>(inRawData);
1463 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1464 std::copy_n(inRawDataPos, numElements, outDataPos);
1465 break;
1466 }
1467 default: {
1468 size_t nBytes = elementBitWidth / CHAR_BIT8;
1469 for (size_t i = 0; i < nBytes; i++)
1470 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1471 break;
1472 }
1473 }
1474}
1475
1476void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1477 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1478 ShapedType type) {
1479 size_t numElements = type.getNumElements();
1480 Type elementType = type.getElementType();
1481 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1482 elementType = complexTy.getElementType();
1483 numElements = numElements * 2;
1484 }
1485 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1486 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1487, __extension__ __PRETTY_FUNCTION__
))
1487 inRawData.size() <= outRawData.size())(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1487, __extension__ __PRETTY_FUNCTION__
))
;
1488 if (elementBitWidth <= CHAR_BIT8)
1489 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1490 else
1491 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1492 elementBitWidth, numElements);
1493}
1494
1495//===----------------------------------------------------------------------===//
1496// DenseFPElementsAttr
1497//===----------------------------------------------------------------------===//
1498
1499template <typename Fn, typename Attr>
1500static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1501 Type newElementType,
1502 llvm::SmallVectorImpl<char> &data) {
1503 size_t bitWidth = getDenseElementBitWidth(newElementType);
1504 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1505
1506 ShapedType newArrayType;
1507 if (inType.isa<RankedTensorType>())
1508 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1509 else if (inType.isa<UnrankedTensorType>())
1510 newArrayType = RankedTensorType::get(inType.getShape(), newElementType);
1511 else if (auto vType = inType.dyn_cast<VectorType>())
1512 newArrayType = VectorType::get(vType.getShape(), newElementType,
1513 vType.getNumScalableDims());
1514 else
1515 assert(newArrayType && "Unhandled tensor type")(static_cast <bool> (newArrayType && "Unhandled tensor type"
) ? void (0) : __assert_fail ("newArrayType && \"Unhandled tensor type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1515, __extension__ __PRETTY_FUNCTION__
))
;
1516
1517 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1518 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT8));
1519
1520 // Functor used to process a single element value of the attribute.
1521 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1522 auto newInt = mapping(value);
1523 assert(newInt.getBitWidth() == bitWidth)(static_cast <bool> (newInt.getBitWidth() == bitWidth) ?
void (0) : __assert_fail ("newInt.getBitWidth() == bitWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1523, __extension__ __PRETTY_FUNCTION__
))
;
1524 writeBits(data.data(), index * storageBitWidth, newInt);
1525 };
1526
1527 // Check for the splat case.
1528 if (attr.isSplat()) {
1529 processElt(*attr.begin(), /*index=*/0);
1530 return newArrayType;
1531 }
1532
1533 // Otherwise, process all of the element values.
1534 uint64_t elementIdx = 0;
1535 for (auto value : attr)
1536 processElt(value, elementIdx++);
1537 return newArrayType;
1538}
1539
1540DenseElementsAttr DenseFPElementsAttr::mapValues(
1541 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1542 llvm::SmallVector<char, 8> elementData;
1543 auto newArrayType =
1544 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1545
1546 return getRaw(newArrayType, elementData);
1547}
1548
1549/// Method for supporting type inquiry through isa, cast and dyn_cast.
1550bool DenseFPElementsAttr::classof(Attribute attr) {
1551 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1552 return denseAttr.getType().getElementType().isa<FloatType>();
1553 return false;
1554}
1555
1556//===----------------------------------------------------------------------===//
1557// DenseIntElementsAttr
1558//===----------------------------------------------------------------------===//
1559
1560DenseElementsAttr DenseIntElementsAttr::mapValues(
1561 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1562 llvm::SmallVector<char, 8> elementData;
1563 auto newArrayType =
1564 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1565 return getRaw(newArrayType, elementData);
1566}
1567
1568/// Method for supporting type inquiry through isa, cast and dyn_cast.
1569bool DenseIntElementsAttr::classof(Attribute attr) {
1570 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1571 return denseAttr.getType().getElementType().isIntOrIndex();
1572 return false;
1573}
1574
1575//===----------------------------------------------------------------------===//
1576// DenseResourceElementsAttr
1577//===----------------------------------------------------------------------===//
1578
1579DenseResourceElementsAttr
1580DenseResourceElementsAttr::get(ShapedType type,
1581 DenseResourceElementsHandle handle) {
1582 return Base::get(type.getContext(), type, handle);
1583}
1584
1585DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1586 StringRef blobName,
1587 AsmResourceBlob blob) {
1588 // Extract the builtin dialect resource manager from context and construct a
1589 // handle by inserting a new resource using the provided blob.
1590 auto &manager =
1591 DenseResourceElementsHandle::getManagerInterface(type.getContext());
1592 return get(type, manager.insert(blobName, std::move(blob)));
1593}
1594
1595//===----------------------------------------------------------------------===//
1596// DenseResourceElementsAttrBase
1597
1598namespace {
1599/// Instantiations of this class provide utilities for interacting with native
1600/// data types in the context of DenseResourceElementsAttr.
1601template <typename T>
1602struct DenseResourceAttrUtil;
1603template <size_t width, bool isSigned>
1604struct DenseResourceElementsAttrIntUtil {
1605 static bool checkElementType(Type eltType) {
1606 IntegerType type = eltType.dyn_cast<IntegerType>();
1607 if (!type || type.getWidth() != width)
1608 return false;
1609 return isSigned ? !type.isUnsigned() : !type.isSigned();
1610 }
1611};
1612template <>
1613struct DenseResourceAttrUtil<bool> {
1614 static bool checkElementType(Type eltType) {
1615 return eltType.isSignlessInteger(1);
1616 }
1617};
1618template <>
1619struct DenseResourceAttrUtil<int8_t>
1620 : public DenseResourceElementsAttrIntUtil<8, true> {};
1621template <>
1622struct DenseResourceAttrUtil<uint8_t>
1623 : public DenseResourceElementsAttrIntUtil<8, false> {};
1624template <>
1625struct DenseResourceAttrUtil<int16_t>
1626 : public DenseResourceElementsAttrIntUtil<16, true> {};
1627template <>
1628struct DenseResourceAttrUtil<uint16_t>
1629 : public DenseResourceElementsAttrIntUtil<16, false> {};
1630template <>
1631struct DenseResourceAttrUtil<int32_t>
1632 : public DenseResourceElementsAttrIntUtil<32, true> {};
1633template <>
1634struct DenseResourceAttrUtil<uint32_t>
1635 : public DenseResourceElementsAttrIntUtil<32, false> {};
1636template <>
1637struct DenseResourceAttrUtil<int64_t>
1638 : public DenseResourceElementsAttrIntUtil<64, true> {};
1639template <>
1640struct DenseResourceAttrUtil<uint64_t>
1641 : public DenseResourceElementsAttrIntUtil<64, false> {};
1642template <>
1643struct DenseResourceAttrUtil<float> {
1644 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1645};
1646template <>
1647struct DenseResourceAttrUtil<double> {
1648 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1649};
1650} // namespace
1651
1652template <typename T>
1653DenseResourceElementsAttrBase<T>
1654DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1655 AsmResourceBlob blob) {
1656 // Check that the blob is in the form we were expecting.
1657 assert(blob.getDataAlignment() == alignof(T) &&(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1658, __extension__ __PRETTY_FUNCTION__
))
1658 "alignment mismatch between expected alignment and blob alignment")(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1658, __extension__ __PRETTY_FUNCTION__
))
;
1659 assert(((blob.getData().size() % sizeof(T)) == 0) &&(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1660, __extension__ __PRETTY_FUNCTION__
))
1660 "size mismatch between expected element width and blob size")(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1660, __extension__ __PRETTY_FUNCTION__
))
;
1661 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1662, __extension__ __PRETTY_FUNCTION__
))
1662 "invalid shape element type for provided type `T`")(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1662, __extension__ __PRETTY_FUNCTION__
))
;
1663 return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
1664 .template cast<DenseResourceElementsAttrBase<T>>();
1665}
1666
1667template <typename T>
1668Optional<ArrayRef<T>>
1669DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1670 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1671 return blob->template getDataAs<T>();
1672 return llvm::None;
1673}
1674
1675template <typename T>
1676bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1677 auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
1678 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1679 resourceAttr.getElementType());
1680}
1681
1682namespace mlir {
1683namespace detail {
1684// Explicit instantiation for all the supported DenseResourceElementsAttr.
1685template class DenseResourceElementsAttrBase<bool>;
1686template class DenseResourceElementsAttrBase<int8_t>;
1687template class DenseResourceElementsAttrBase<int16_t>;
1688template class DenseResourceElementsAttrBase<int32_t>;
1689template class DenseResourceElementsAttrBase<int64_t>;
1690template class DenseResourceElementsAttrBase<uint8_t>;
1691template class DenseResourceElementsAttrBase<uint16_t>;
1692template class DenseResourceElementsAttrBase<uint32_t>;
1693template class DenseResourceElementsAttrBase<uint64_t>;
1694template class DenseResourceElementsAttrBase<float>;
1695template class DenseResourceElementsAttrBase<double>;
1696} // namespace detail
1697} // namespace mlir
1698
1699//===----------------------------------------------------------------------===//
1700// SparseElementsAttr
1701//===----------------------------------------------------------------------===//
1702
1703/// Get a zero APFloat for the given sparse attribute.
1704APFloat SparseElementsAttr::getZeroAPFloat() const {
1705 auto eltType = getElementType().cast<FloatType>();
1706 return APFloat(eltType.getFloatSemantics());
1707}
1708
1709/// Get a zero APInt for the given sparse attribute.
1710APInt SparseElementsAttr::getZeroAPInt() const {
1711 auto eltType = getElementType().cast<IntegerType>();
1712 return APInt::getZero(eltType.getWidth());
1713}
1714
1715/// Get a zero attribute for the given attribute type.
1716Attribute SparseElementsAttr::getZeroAttr() const {
1717 auto eltType = getElementType();
1718
1719 // Handle floating point elements.
1720 if (eltType.isa<FloatType>())
1721 return FloatAttr::get(eltType, 0);
1722
1723 // Handle complex elements.
1724 if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1725 auto eltType = complexTy.getElementType();
1726 Attribute zero;
1727 if (eltType.isa<FloatType>())
1728 zero = FloatAttr::get(eltType, 0);
1729 else // must be integer
1730 zero = IntegerAttr::get(eltType, 0);
1731 return ArrayAttr::get(complexTy.getContext(),
1732 ArrayRef<Attribute>{zero, zero});
1733 }
1734
1735 // Handle string type.
1736 if (getValues().isa<DenseStringElementsAttr>())
1737 return StringAttr::get("", eltType);
1738
1739 // Otherwise, this is an integer.
1740 return IntegerAttr::get(eltType, 0);
1741}
1742
1743/// Flatten, and return, all of the sparse indices in this attribute in
1744/// row-major order.
1745std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1746 std::vector<ptrdiff_t> flatSparseIndices;
1747
1748 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1749 // as a 1-D index array.
1750 auto sparseIndices = getIndices();
1751 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1752 if (sparseIndices.isSplat()) {
1753 SmallVector<uint64_t, 8> indices(getType().getRank(),
1754 *sparseIndexValues.begin());
1755 flatSparseIndices.push_back(getFlattenedIndex(indices));
1756 return flatSparseIndices;
1757 }
1758
1759 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1760 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1761 size_t rank = getType().getRank();
1762 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1763 flatSparseIndices.push_back(getFlattenedIndex(
1764 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1765 return flatSparseIndices;
1766}
1767
1768LogicalResult
1769SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1770 ShapedType type, DenseIntElementsAttr sparseIndices,
1771 DenseElementsAttr values) {
1772 ShapedType valuesType = values.getType();
1773 if (valuesType.getRank() != 1)
1774 return emitError() << "expected 1-d tensor for sparse element values";
1775
1776 // Verify the indices and values shape.
1777 ShapedType indicesType = sparseIndices.getType();
1778 auto emitShapeError = [&]() {
1779 return emitError() << "expected shape ([" << type.getShape()
1780 << "]); inferred shape of indices literal (["
1781 << indicesType.getShape()
1782 << "]); inferred shape of values literal (["
1783 << valuesType.getShape() << "])";
1784 };
1785 // Verify indices shape.
1786 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1787 if (indicesRank == 2) {
1788 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1789 return emitShapeError();
1790 } else if (indicesRank != 1 || rank != 1) {
1791 return emitShapeError();
1792 }
1793 // Verify the values shape.
1794 int64_t numSparseIndices = indicesType.getDimSize(0);
1795 if (numSparseIndices != valuesType.getDimSize(0))
1796 return emitShapeError();
1797
1798 // Verify that the sparse indices are within the value shape.
1799 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1800 return emitError()
1801 << "sparse index #" << indexNum
1802 << " is not contained within the value shape, with index=[" << index
1803 << "], and type=" << type;
1804 };
1805
1806 // Handle the case where the index values are a splat.
1807 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1808 if (sparseIndices.isSplat()) {
1809 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1810 if (!ElementsAttr::isValidIndex(type, indices))
1811 return emitIndexError(0, indices);
1812 return success();
1813 }
1814
1815 // Otherwise, reinterpret each index as an ArrayRef.
1816 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1817 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1818 rank);
1819 if (!ElementsAttr::isValidIndex(type, index))
1820 return emitIndexError(i, index);
1821 }
1822
1823 return success();
1824}
1825
1826//===----------------------------------------------------------------------===//
1827// TypeAttr
1828//===----------------------------------------------------------------------===//
1829
1830void TypeAttr::walkImmediateSubElements(
1831 function_ref<void(Attribute)> walkAttrsFn,
1832 function_ref<void(Type)> walkTypesFn) const {
1833 walkTypesFn(getValue());
1834}
1835
1836Attribute
1837TypeAttr::replaceImmediateSubElements(ArrayRef<Attribute> replAttrs,
1838 ArrayRef<Type> replTypes) const {
1839 return get(replTypes[0]);
1840}
1841
1842//===----------------------------------------------------------------------===//
1843// Attribute Utilities
1844//===----------------------------------------------------------------------===//
1845
1846AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1847 int64_t offset,
1848 MLIRContext *context) {
1849 AffineExpr expr;
1850 unsigned nSymbols = 0;
1851
1852 // AffineExpr for offset.
1853 // Static case.
1854 if (offset != MemRefType::getDynamicStrideOrOffset()) {
1855 auto cst = getAffineConstantExpr(offset, context);
1856 expr = cst;
1857 } else {
1858 // Dynamic case, new symbol for the offset.
1859 auto sym = getAffineSymbolExpr(nSymbols++, context);
1860 expr = sym;
1861 }
1862
1863 // AffineExpr for strides.
1864 for (const auto &en : llvm::enumerate(strides)) {
1865 auto dim = en.index();
1866 auto stride = en.value();
1867 assert(stride != 0 && "Invalid stride specification")(static_cast <bool> (stride != 0 && "Invalid stride specification"
) ? void (0) : __assert_fail ("stride != 0 && \"Invalid stride specification\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1867, __extension__ __PRETTY_FUNCTION__
))
;
1868 auto d = getAffineDimExpr(dim, context);
1869 AffineExpr mult;
1870 // Static case.
1871 if (stride != MemRefType::getDynamicStrideOrOffset())
1872 mult = getAffineConstantExpr(stride, context);
1873 else
1874 // Dynamic case, new symbol for each new stride.
1875 mult = getAffineSymbolExpr(nSymbols++, context);
1876 expr = expr + d * mult;
1877 }
1878
1879 return AffineMap::get(strides.size(), nSymbols, expr);
1880}

/build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
105class AsmPrinter {
106public:
107 /// This class contains the internal default implementation of the base
108 /// printer methods.
109 class Impl;
110
111 /// Initialize the printer with the given internal implementation.
112 AsmPrinter(Impl &impl) : impl(&impl) {}
113 virtual ~AsmPrinter();
114
115 /// Return the raw output stream used by this printer.
116 virtual raw_ostream &getStream() const;
117
118 /// Print the given floating point value in a stabilized form that can be
119 /// roundtripped through the IR. This is the companion to the 'parseFloat'
120 /// hook on the AsmParser.
121 virtual void printFloat(const APFloat &value);
122
123 virtual void printType(Type type);
124 virtual void printAttribute(Attribute attr);
125
126 /// Trait to check if `AttrType` provides a `print` method.
127 template <typename AttrOrType>
128 using has_print_method =
129 decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
130 template <typename AttrOrType>
131 using detect_has_print_method =
132 llvm::is_detected<has_print_method, AttrOrType>;
133
134 /// Print the provided attribute in the context of an operation custom
135 /// printer/parser: this will invoke directly the print method on the
136 /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
137 template <typename AttrOrType,
138 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
139 *sfinae = nullptr>
140 void printStrippedAttrOrType(AttrOrType attrOrType) {
141 if (succeeded(printAlias(attrOrType)))
142 return;
143 attrOrType.print(*this);
144 }
145
146 /// Print the provided array of attributes or types in the context of an
147 /// operation custom printer/parser: this will invoke directly the print
148 /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
149 /// most cases.
150 template <typename AttrOrType,
151 std::enable_if_t<detect_has_print_method<AttrOrType>::value>
152 *sfinae = nullptr>
153 void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
154 llvm::interleaveComma(
155 attrOrTypes, getStream(),
156 [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
157 }
158
159 /// SFINAE for printing the provided attribute in the context of an operation
160 /// custom printer in the case where the attribute does not define a print
161 /// method.
162 template <typename AttrOrType,
163 std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
164 *sfinae = nullptr>
165 void printStrippedAttrOrType(AttrOrType attrOrType) {
166 *this << attrOrType;
167 }
168
169 /// Print the given attribute without its type. The corresponding parser must
170 /// provide a valid type for the attribute.
171 virtual void printAttributeWithoutType(Attribute attr);
172
173 /// Print the given string as a keyword, or a quoted and escaped string if it
174 /// has any special or non-printable characters in it.
175 virtual void printKeywordOrString(StringRef keyword);
176
177 /// Print the given string as a symbol reference, i.e. a form representable by
178 /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
179 /// with '@'. The reference is surrounded with ""'s and escaped if it has any
180 /// special or non-printable characters in it.
181 virtual void printSymbolName(StringRef symbolRef);
182
183 /// Print a handle to the given dialect resource.
184 void printResourceHandle(const AsmDialectResourceHandle &resource);
185
186 /// Print an optional arrow followed by a type list.
187 template <typename TypeRange>
188 void printOptionalArrowTypeList(TypeRange &&types) {
189 if (types.begin() != types.end())
190 printArrowTypeList(types);
191 }
192 template <typename TypeRange>
193 void printArrowTypeList(TypeRange &&types) {
194 auto &os = getStream() << " -> ";
195
196 bool wrapped = !llvm::hasSingleElement(types) ||
197 (*types.begin()).template isa<FunctionType>();
198 if (wrapped)
199 os << '(';
200 llvm::interleaveComma(types, *this);
201 if (wrapped)
202 os << ')';
203 }
204
205 /// Print the two given type ranges in a functional form.
206 template <typename InputRangeT, typename ResultRangeT>
207 void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
208 auto &os = getStream();
209 os << '(';
210 llvm::interleaveComma(inputs, *this);
211 os << ')';
212 printArrowTypeList(results);
213 }
214
215protected:
216 /// Initialize the printer with no internal implementation. In this case, all
217 /// virtual methods of this class must be overriden.
218 AsmPrinter() {}
219
220private:
221 AsmPrinter(const AsmPrinter &) = delete;
222 void operator=(const AsmPrinter &) = delete;
223
224 /// Print the alias for the given attribute, return failure if no alias could
225 /// be printed.
226 virtual LogicalResult printAlias(Attribute attr);
227
228 /// Print the alias for the given type, return failure if no alias could
229 /// be printed.
230 virtual LogicalResult printAlias(Type type);
231
232 /// The internal implementation of the printer.
233 Impl *impl{nullptr};
234};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
272// Support printing anything that isn't convertible to one of the other
273// streamable types, even if it isn't exactly one of them. For example, we want
274// to print FunctionType with the Type version above, not have it match this.
275template <typename AsmPrinterT, typename T,
276 std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
277 !std::is_convertible<T &, Type &>::value &&
278 !std::is_convertible<T &, Attribute &>::value &&
279 !std::is_convertible<T &, ValueRange>::value &&
280 !std::is_convertible<T &, APFloat &>::value &&
281 !llvm::is_one_of<T, bool, float, double>::value,
282 T> * = nullptr>
283inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
284 AsmPrinterT &>
285operator<<(AsmPrinterT &p, const T &other) {
286 p.getStream() << other;
287 return p;
288}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
325class OpAsmPrinter : public AsmPrinter {
326public:
327 using AsmPrinter::AsmPrinter;
328 ~OpAsmPrinter() override;
329
330 /// Print a newline and indent the printer to the start of the current
331 /// operation.
332 virtual void printNewline() = 0;
333
334 /// Print a block argument in the usual format of:
335 /// %ssaName : type {attr1=42} loc("here")
336 /// where location printing is controlled by the standard internal option.
337 /// You may pass omitType=true to not print a type, and pass an empty
338 /// attribute list if you don't care for attributes.
339 virtual void printRegionArgument(BlockArgument arg,
340 ArrayRef<NamedAttribute> argAttrs = {},
341 bool omitType = false) = 0;
342
343 /// Print implementations for various things an operation contains.
344 virtual void printOperand(Value value) = 0;
345 virtual void printOperand(Value value, raw_ostream &os) = 0;
346
347 /// Print a comma separated list of operands.
348 template <typename ContainerType>
349 void printOperands(const ContainerType &container) {
350 printOperands(container.begin(), container.end());
351 }
352
353 /// Print a comma separated list of operands.
354 template <typename IteratorType>
355 void printOperands(IteratorType it, IteratorType end) {
356 if (it == end)
357 return;
358 printOperand(*it);
359 for (++it; it != end; ++it) {
360 getStream() << ", ";
361 printOperand(*it);
362 }
363 }
364
365 /// Print the given successor.
366 virtual void printSuccessor(Block *successor) = 0;
367
368 /// Print the successor and its operands.
369 virtual void printSuccessorAndUseList(Block *successor,
370 ValueRange succOperands) = 0;
371
372 /// If the specified operation has attributes, print out an attribute
373 /// dictionary with their values. elidedAttrs allows the client to ignore
374 /// specific well known attributes, commonly used if the attribute value is
375 /// printed some other way (like as a fixed operand).
376 virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
377 ArrayRef<StringRef> elidedAttrs = {}) = 0;
378
379 /// If the specified operation has attributes, print out an attribute
380 /// dictionary prefixed with 'attributes'.
381 virtual void
382 printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
383 ArrayRef<StringRef> elidedAttrs = {}) = 0;
384
385 /// Print the entire operation with the default generic assembly form.
386 /// If `printOpName` is true, then the operation name is printed (the default)
387 /// otherwise it is omitted and the print will start with the operand list.
388 virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;
389
390 /// Prints a region.
391 /// If 'printEntryBlockArgs' is false, the arguments of the
392 /// block are not printed. If 'printBlockTerminator' is false, the terminator
393 /// operation of the block is not printed. If printEmptyBlock is true, then
394 /// the block header is printed even if the block is empty.
395 virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
396 bool printBlockTerminators = true,
397 bool printEmptyBlock = false) = 0;
398
399 /// Renumber the arguments for the specified region to the same names as the
400 /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
401 /// operations. If any entry in namesToUse is null, the corresponding
402 /// argument name is left alone.
403 virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;
404
405 /// Prints an affine map of SSA ids, where SSA id names are used in place
406 /// of dims/symbols.
407 /// Operand values must come from single-result sources, and be valid
408 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
409 virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
410 ValueRange operands) = 0;
411
412 /// Prints an affine expression of SSA ids with SSA id names used instead of
413 /// dims and symbols.
414 /// Operand values must come from single-result sources, and be valid
415 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
416 virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
417 ValueRange symOperands) = 0;
418
419 /// Print the complete type of an operation in functional form.
420 void printFunctionalType(Operation *op);
421 using AsmPrinter::printFunctionalType;
422};
423
424// Make the implementations convenient to use.
425inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
426 p.printOperand(value);
427 return p;
428}
429
430template <typename T,
431 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
432 !std::is_convertible<T &, Value &>::value,
433 T> * = nullptr>
434inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
435 p.printOperands(values);
436 return p;
437}
438
439inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
440 p.printSuccessor(value);
441 return p;
442}
443
444//===----------------------------------------------------------------------===//
445// AsmParser
446//===----------------------------------------------------------------------===//
447
448/// This base class exposes generic asm parser hooks, usable across the various
449/// derived parsers.
450class AsmParser {
451public:
452 AsmParser() = default;
453 virtual ~AsmParser();
454
455 MLIRContext *getContext() const;
456
457 /// Return the location of the original name token.
458 virtual SMLoc getNameLoc() const = 0;
459
460 //===--------------------------------------------------------------------===//
461 // Utilities
462 //===--------------------------------------------------------------------===//
463
464 /// Emit a diagnostic at the specified location and return failure.
465 virtual InFlightDiagnostic emitError(SMLoc loc,
466 const Twine &message = {}) = 0;
467
468 /// Return a builder which provides useful access to MLIRContext, global
469 /// objects like types and attributes.
470 virtual Builder &getBuilder() const = 0;
471
472 /// Get the location of the next token and store it into the argument. This
473 /// always succeeds.
474 virtual SMLoc getCurrentLocation() = 0;
475 ParseResult getCurrentLocation(SMLoc *loc) {
476 *loc = getCurrentLocation();
477 return success();
478 }
479
480 /// Re-encode the given source location as an MLIR location and return it.
481 /// Note: This method should only be used when a `Location` is necessary, as
482 /// the encoding process is not efficient.
483 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
484
485 //===--------------------------------------------------------------------===//
486 // Token Parsing
487 //===--------------------------------------------------------------------===//
488
489 /// Parse a '->' token.
490 virtual ParseResult parseArrow() = 0;
491
492 /// Parse a '->' token if present
493 virtual ParseResult parseOptionalArrow() = 0;
494
495 /// Parse a `{` token.
496 virtual ParseResult parseLBrace() = 0;
497
498 /// Parse a `{` token if present.
499 virtual ParseResult parseOptionalLBrace() = 0;
500
501 /// Parse a `}` token.
502 virtual ParseResult parseRBrace() = 0;
503
504 /// Parse a `}` token if present.
505 virtual ParseResult parseOptionalRBrace() = 0;
506
507 /// Parse a `:` token.
508 virtual ParseResult parseColon() = 0;
509
510 /// Parse a `:` token if present.
511 virtual ParseResult parseOptionalColon() = 0;
512
513 /// Parse a `,` token.
514 virtual ParseResult parseComma() = 0;
515
516 /// Parse a `,` token if present.
517 virtual ParseResult parseOptionalComma() = 0;
518
519 /// Parse a `=` token.
520 virtual ParseResult parseEqual() = 0;
521
522 /// Parse a `=` token if present.
523 virtual ParseResult parseOptionalEqual() = 0;
524
525 /// Parse a '<' token.
526 virtual ParseResult parseLess() = 0;
527
528 /// Parse a '<' token if present.
529 virtual ParseResult parseOptionalLess() = 0;
530
531 /// Parse a '>' token.
532 virtual ParseResult parseGreater() = 0;
533
534 /// Parse a '>' token if present.
535 virtual ParseResult parseOptionalGreater() = 0;
536
537 /// Parse a '?' token.
538 virtual ParseResult parseQuestion() = 0;
539
540 /// Parse a '?' token if present.
541 virtual ParseResult parseOptionalQuestion() = 0;
542
543 /// Parse a '+' token.
544 virtual ParseResult parsePlus() = 0;
545
546 /// Parse a '+' token if present.
547 virtual ParseResult parseOptionalPlus() = 0;
548
549 /// Parse a '*' token.
550 virtual ParseResult parseStar() = 0;
551
552 /// Parse a '*' token if present.
553 virtual ParseResult parseOptionalStar() = 0;
554
555 /// Parse a '|' token.
556 virtual ParseResult parseVerticalBar() = 0;
557
558 /// Parse a '|' token if present.
559 virtual ParseResult parseOptionalVerticalBar() = 0;
560
561 /// Parse a quoted string token.
562 ParseResult parseString(std::string *string) {
563 auto loc = getCurrentLocation();
564 if (parseOptionalString(string))
565 return emitError(loc, "expected string");
566 return success();
567 }
568
569 /// Parse a quoted string token if present.
570 virtual ParseResult parseOptionalString(std::string *string) = 0;
571
572 /// Parse a `(` token.
573 virtual ParseResult parseLParen() = 0;
574
575 /// Parse a `(` token if present.
576 virtual ParseResult parseOptionalLParen() = 0;
577
578 /// Parse a `)` token.
579 virtual ParseResult parseRParen() = 0;
580
581 /// Parse a `)` token if present.
582 virtual ParseResult parseOptionalRParen() = 0;
583
584 /// Parse a `[` token.
585 virtual ParseResult parseLSquare() = 0;
586
587 /// Parse a `[` token if present.
588 virtual ParseResult parseOptionalLSquare() = 0;
589
590 /// Parse a `]` token.
591 virtual ParseResult parseRSquare() = 0;
592
593 /// Parse a `]` token if present.
594 virtual ParseResult parseOptionalRSquare() = 0;
595
596 /// Parse a `...` token if present;
597 virtual ParseResult parseOptionalEllipsis() = 0;
598
599 /// Parse a floating point value from the stream.
600 virtual ParseResult parseFloat(double &result) = 0;
601
602 /// Parse an integer value from the stream.
603 template <typename IntT>
604 ParseResult parseInteger(IntT &result) {
605 auto loc = getCurrentLocation();
606 OptionalParseResult parseResult = parseOptionalInteger(result);
4
Calling 'AsmParser::parseOptionalInteger'
8
Returning from 'AsmParser::parseOptionalInteger'
607 if (!parseResult.has_value())
9
Taking true branch
608 return emitError(loc, "expected integer value");
10
Returning without writing to 'result'
609 return *parseResult;
610 }
611
612 /// Parse an optional integer value from the stream.
613 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
614
615 template <typename IntT>
616 OptionalParseResult parseOptionalInteger(IntT &result) {
617 auto loc = getCurrentLocation();
618
619 // Parse the unsigned variant.
620 APInt uintResult;
621 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
622 if (!parseResult.has_value() || failed(*parseResult))
5
Assuming the condition is true
6
Taking true branch
623 return parseResult;
7
Returning without writing to 'result'
624
625 // Try to convert to the provided integer type. sextOrTrunc is correct even
626 // for unsigned types because parseOptionalInteger ensures the sign bit is
627 // zero for non-negated integers.
628 result =
629 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
630 if (APInt(uintResult.getBitWidth(), result) != uintResult)
631 return emitError(loc, "integer value too large");
632 return success();
633 }
634
635 /// These are the supported delimiters around operand lists and region
636 /// argument lists, used by parseOperandList.
637 enum class Delimiter {
638 /// Zero or more operands with no delimiters.
639 None,
640 /// Parens surrounding zero or more operands.
641 Paren,
642 /// Square brackets surrounding zero or more operands.
643 Square,
644 /// <> brackets surrounding zero or more operands.
645 LessGreater,
646 /// {} brackets surrounding zero or more operands.
647 Braces,
648 /// Parens supporting zero or more operands, or nothing.
649 OptionalParen,
650 /// Square brackets supporting zero or more ops, or nothing.
651 OptionalSquare,
652 /// <> brackets supporting zero or more ops, or nothing.
653 OptionalLessGreater,
654 /// {} brackets surrounding zero or more operands, or nothing.
655 OptionalBraces,
656 };
657
658 /// Parse a list of comma-separated items with an optional delimiter. If a
659 /// delimiter is provided, then an empty list is allowed. If not, then at
660 /// least one element will be parsed.
661 ///
662 /// contextMessage is an optional message appended to "expected '('" sorts of
663 /// diagnostics when parsing the delimeters.
664 virtual ParseResult
665 parseCommaSeparatedList(Delimiter delimiter,
666 function_ref<ParseResult()> parseElementFn,
667 StringRef contextMessage = StringRef()) = 0;
668
669 /// Parse a comma separated list of elements that must have at least one entry
670 /// in it.
671 ParseResult
672 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
673 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
674 }
675
676 //===--------------------------------------------------------------------===//
677 // Keyword Parsing
678 //===--------------------------------------------------------------------===//
679
680 /// This class represents a StringSwitch like class that is useful for parsing
681 /// expected keywords. On construction, it invokes `parseKeyword` and
682 /// processes each of the provided cases statements until a match is hit. The
683 /// provided `ResultT` must be assignable from `failure()`.
684 template <typename ResultT = ParseResult>
685 class KeywordSwitch {
686 public:
687 KeywordSwitch(AsmParser &parser)
688 : parser(parser), loc(parser.getCurrentLocation()) {
689 if (failed(parser.parseKeywordOrCompletion(&keyword)))
690 result = failure();
691 }
692
693 /// Case that uses the provided value when true.
694 KeywordSwitch &Case(StringLiteral str, ResultT value) {
695 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
696 }
697 KeywordSwitch &Default(ResultT value) {
698 return Default([&](StringRef, SMLoc) { return std::move(value); });
699 }
700 /// Case that invokes the provided functor when true. The parameters passed
701 /// to the functor are the keyword, and the location of the keyword (in case
702 /// any errors need to be emitted).
703 template <typename FnT>
704 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
705 Case(StringLiteral str, FnT &&fn) {
706 if (result)
707 return *this;
708
709 // If the word was empty, record this as a completion.
710 if (keyword.empty())
711 parser.codeCompleteExpectedTokens(str);
712 else if (keyword == str)
713 result.emplace(std::move(fn(keyword, loc)));
714 return *this;
715 }
716 template <typename FnT>
717 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
718 Default(FnT &&fn) {
719 if (!result)
720 result.emplace(fn(keyword, loc));
721 return *this;
722 }
723
724 /// Returns true if this switch has a value yet.
725 bool hasValue() const { return result.has_value(); }
726
727 /// Return the result of the switch.
728 [[nodiscard]] operator ResultT() {
729 if (!result)
730 return parser.emitError(loc, "unexpected keyword: ") << keyword;
731 return std::move(*result);
732 }
733
734 private:
735 /// The parser used to construct this switch.
736 AsmParser &parser;
737
738 /// The location of the keyword, used to emit errors as necessary.
739 SMLoc loc;
740
741 /// The parsed keyword itself.
742 StringRef keyword;
743
744 /// The result of the switch statement or none if currently unknown.
745 Optional<ResultT> result;
746 };
747
748 /// Parse a given keyword.
749 ParseResult parseKeyword(StringRef keyword) {
750 return parseKeyword(keyword, "");
751 }
752 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
753
754 /// Parse a keyword into 'keyword'.
755 ParseResult parseKeyword(StringRef *keyword) {
756 auto loc = getCurrentLocation();
757 if (parseOptionalKeyword(keyword))
758 return emitError(loc, "expected valid keyword");
759 return success();
760 }
761
762 /// Parse the given keyword if present.
763 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
764
765 /// Parse a keyword, if present, into 'keyword'.
766 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
767
768 /// Parse a keyword, if present, and if one of the 'allowedValues',
769 /// into 'keyword'
770 virtual ParseResult
771 parseOptionalKeyword(StringRef *keyword,
772 ArrayRef<StringRef> allowedValues) = 0;
773
774 /// Parse a keyword or a quoted string.
775 ParseResult parseKeywordOrString(std::string *result) {
776 if (failed(parseOptionalKeywordOrString(result)))
777 return emitError(getCurrentLocation())
778 << "expected valid keyword or string";
779 return success();
780 }
781
782 /// Parse an optional keyword or string.
783 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
784
785 //===--------------------------------------------------------------------===//
786 // Attribute/Type Parsing
787 //===--------------------------------------------------------------------===//
788
789 /// Invoke the `getChecked` method of the given Attribute or Type class, using
790 /// the provided location to emit errors in the case of failure. Note that
791 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
792 /// context parameter.
793 template <typename T, typename... ParamsT>
794 auto getChecked(SMLoc loc, ParamsT &&...params) {
795 return T::getChecked([&] { return emitError(loc); },
796 std::forward<ParamsT>(params)...);
797 }
798 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
799 /// errors.
800 template <typename T, typename... ParamsT>
801 auto getChecked(ParamsT &&...params) {
802 return T::getChecked([&] { return emitError(getNameLoc()); },
803 std::forward<ParamsT>(params)...);
804 }
805
806 //===--------------------------------------------------------------------===//
807 // Attribute Parsing
808 //===--------------------------------------------------------------------===//
809
810 /// Parse an arbitrary attribute of a given type and return it in result.
811 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
812
813 /// Parse a custom attribute with the provided callback, unless the next
814 /// token is `#`, in which case the generic parser is invoked.
815 virtual ParseResult parseCustomAttributeWithFallback(
816 Attribute &result, Type type,
817 function_ref<ParseResult(Attribute &result, Type type)>
818 parseAttribute) = 0;
819
820 /// Parse an attribute of a specific kind and type.
821 template <typename AttrType>
822 ParseResult parseAttribute(AttrType &result, Type type = {}) {
823 SMLoc loc = getCurrentLocation();
824
825 // Parse any kind of attribute.
826 Attribute attr;
827 if (parseAttribute(attr, type))
828 return failure();
829
830 // Check for the right kind of attribute.
831 if (!(result = attr.dyn_cast<AttrType>()))
832 return emitError(loc, "invalid kind of attribute specified");
833
834 return success();
835 }
836
837 /// Parse an arbitrary attribute and return it in result. This also adds the
838 /// attribute to the specified attribute list with the specified name.
839 ParseResult parseAttribute(Attribute &result, StringRef attrName,
840 NamedAttrList &attrs) {
841 return parseAttribute(result, Type(), attrName, attrs);
842 }
843
844 /// Parse an attribute of a specific kind and type.
845 template <typename AttrType>
846 ParseResult parseAttribute(AttrType &result, StringRef attrName,
847 NamedAttrList &attrs) {
848 return parseAttribute(result, Type(), attrName, attrs);
849 }
850
851 /// Parse an arbitrary attribute of a given type and populate it in `result`.
852 /// This also adds the attribute to the specified attribute list with the
853 /// specified name.
854 template <typename AttrType>
855 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
856 NamedAttrList &attrs) {
857 SMLoc loc = getCurrentLocation();
858
859 // Parse any kind of attribute.
860 Attribute attr;
861 if (parseAttribute(attr, type))
862 return failure();
863
864 // Check for the right kind of attribute.
865 result = attr.dyn_cast<AttrType>();
866 if (!result)
867 return emitError(loc, "invalid kind of attribute specified");
868
869 attrs.append(attrName, result);
870 return success();
871 }
872
873 /// Trait to check if `AttrType` provides a `parse` method.
874 template <typename AttrType>
875 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
876 std::declval<Type>()));
877 template <typename AttrType>
878 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
879
880 /// Parse a custom attribute of a given type unless the next token is `#`, in
881 /// which case the generic parser is invoked. The parsed attribute is
882 /// populated in `result` and also added to the specified attribute list with
883 /// the specified name.
884 template <typename AttrType>
885 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
886 parseCustomAttributeWithFallback(AttrType &result, Type type,
887 StringRef attrName, NamedAttrList &attrs) {
888 SMLoc loc = getCurrentLocation();
889
890 // Parse any kind of attribute.
891 Attribute attr;
892 if (parseCustomAttributeWithFallback(
893 attr, type, [&](Attribute &result, Type type) -> ParseResult {
894 result = AttrType::parse(*this, type);
895 if (!result)
896 return failure();
897 return success();
898 }))
899 return failure();
900
901 // Check for the right kind of attribute.
902 result = attr.dyn_cast<AttrType>();
903 if (!result)
904 return emitError(loc, "invalid kind of attribute specified");
905
906 attrs.append(attrName, result);
907 return success();
908 }
909
910 /// SFINAE parsing method for Attribute that don't implement a parse method.
911 template <typename AttrType>
912 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
913 parseCustomAttributeWithFallback(AttrType &result, Type type,
914 StringRef attrName, NamedAttrList &attrs) {
915 return parseAttribute(result, type, attrName, attrs);
916 }
917
918 /// Parse a custom attribute of a given type unless the next token is `#`, in
919 /// which case the generic parser is invoked. The parsed attribute is
920 /// populated in `result`.
921 template <typename AttrType>
922 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
923 parseCustomAttributeWithFallback(AttrType &result) {
924 SMLoc loc = getCurrentLocation();
925
926 // Parse any kind of attribute.
927 Attribute attr;
928 if (parseCustomAttributeWithFallback(
929 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
930 result = AttrType::parse(*this, type);
931 return success(!!result);
932 }))
933 return failure();
934
935 // Check for the right kind of attribute.
936 result = attr.dyn_cast<AttrType>();
937 if (!result)
938 return emitError(loc, "invalid kind of attribute specified");
939 return success();
940 }
941
942 /// SFINAE parsing method for Attribute that don't implement a parse method.
943 template <typename AttrType>
944 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
945 parseCustomAttributeWithFallback(AttrType &result) {
946 return parseAttribute(result);
947 }
948
949 /// Parse an arbitrary optional attribute of a given type and return it in
950 /// result.
951 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
952 Type type = {}) = 0;
953
954 /// Parse an optional array attribute and return it in result.
955 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
956 Type type = {}) = 0;
957
958 /// Parse an optional string attribute and return it in result.
959 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
960 Type type = {}) = 0;
961
962 /// Parse an optional attribute of a specific type and add it to the list with
963 /// the specified name.
964 template <typename AttrType>
965 OptionalParseResult parseOptionalAttribute(AttrType &result,
966 StringRef attrName,
967 NamedAttrList &attrs) {
968 return parseOptionalAttribute(result, Type(), attrName, attrs);
969 }
970
971 /// Parse an optional attribute of a specific type and add it to the list with
972 /// the specified name.
973 template <typename AttrType>
974 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
975 StringRef attrName,
976 NamedAttrList &attrs) {
977 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
978 if (parseResult.has_value() && succeeded(*parseResult))
979 attrs.append(attrName, result);
980 return parseResult;
981 }
982
983 /// Parse a named dictionary into 'result' if it is present.
984 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
985
986 /// Parse a named dictionary into 'result' if the `attributes` keyword is
987 /// present.
988 virtual ParseResult
989 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
990
991 /// Parse an affine map instance into 'map'.
992 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
993
994 /// Parse an integer set instance into 'set'.
995 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
996
997 //===--------------------------------------------------------------------===//
998 // Identifier Parsing
999 //===--------------------------------------------------------------------===//
1000
1001 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1002 /// attribute named 'attrName'.
1003 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1004 NamedAttrList &attrs) {
1005 if (failed(parseOptionalSymbolName(result, attrName, attrs)))
1006 return emitError(getCurrentLocation())
1007 << "expected valid '@'-identifier for symbol name";
1008 return success();
1009 }
1010
1011 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1012 /// string attribute named 'attrName'.
1013 virtual ParseResult parseOptionalSymbolName(StringAttr &result,
1014 StringRef attrName,
1015 NamedAttrList &attrs) = 0;
1016
1017 //===--------------------------------------------------------------------===//
1018 // Resource Parsing
1019 //===--------------------------------------------------------------------===//
1020
1021 /// Parse a handle to a resource within the assembly format.
1022 template <typename ResourceT>
1023 FailureOr<ResourceT> parseResourceHandle() {
1024 SMLoc handleLoc = getCurrentLocation();
1025
1026 // Try to load the dialect that owns the handle.
1027 auto *dialect =
1028 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1029 if (!dialect) {
1030 return emitError(handleLoc)
1031 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1032 << "' is unknown";
1033 }
1034
1035 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1036 if (failed(handle))
1037 return failure();
1038 if (auto *result = dyn_cast<ResourceT>(&*handle))
1039 return std::move(*result);
1040 return emitError(handleLoc) << "provided resource handle differs from the "
1041 "expected resource type";
1042 }
1043
1044 //===--------------------------------------------------------------------===//
1045 // Type Parsing
1046 //===--------------------------------------------------------------------===//
1047
1048 /// Parse a type.
1049 virtual ParseResult parseType(Type &result) = 0;
1050
1051 /// Parse a custom type with the provided callback, unless the next
1052 /// token is `#`, in which case the generic parser is invoked.
1053 virtual ParseResult parseCustomTypeWithFallback(
1054 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1055
1056 /// Parse an optional type.
1057 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1058
1059 /// Parse a type of a specific type.
1060 template <typename TypeT>
1061 ParseResult parseType(TypeT &result) {
1062 SMLoc loc = getCurrentLocation();
1063
1064 // Parse any kind of type.
1065 Type type;
1066 if (parseType(type))
1067 return failure();
1068
1069 // Check for the right kind of type.
1070 result = type.dyn_cast<TypeT>();
1071 if (!result)
1072 return emitError(loc, "invalid kind of type specified");
1073
1074 return success();
1075 }
1076
1077 /// Trait to check if `TypeT` provides a `parse` method.
1078 template <typename TypeT>
1079 using type_has_parse_method =
1080 decltype(TypeT::parse(std::declval<AsmParser &>()));
1081 template <typename TypeT>
1082 using detect_type_has_parse_method =
1083 llvm::is_detected<type_has_parse_method, TypeT>;
1084
1085 /// Parse a custom Type of a given type unless the next token is `#`, in
1086 /// which case the generic parser is invoked. The parsed Type is
1087 /// populated in `result`.
1088 template <typename TypeT>
1089 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1090 parseCustomTypeWithFallback(TypeT &result) {
1091 SMLoc loc = getCurrentLocation();
1092
1093 // Parse any kind of Type.
1094 Type type;
1095 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1096 result = TypeT::parse(*this);
1097 return success(!!result);
1098 }))
1099 return failure();
1100
1101 // Check for the right kind of Type.
1102 result = type.dyn_cast<TypeT>();
1103 if (!result)
1104 return emitError(loc, "invalid kind of Type specified");
1105 return success();
1106 }
1107
1108 /// SFINAE parsing method for Type that don't implement a parse method.
1109 template <typename TypeT>
1110 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1111 parseCustomTypeWithFallback(TypeT &result) {
1112 return parseType(result);
1113 }
1114
1115 /// Parse a type list.
1116 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1117 return parseCommaSeparatedList(
1118 [&]() { return parseType(result.emplace_back()); });
1119 }
1120
1121 /// Parse an arrow followed by a type list.
1122 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1123
1124 /// Parse an optional arrow followed by a type list.
1125 virtual ParseResult
1126 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1127
1128 /// Parse a colon followed by a type.
1129 virtual ParseResult parseColonType(Type &result) = 0;
1130
1131 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1132 template <typename TypeType>
1133 ParseResult parseColonType(TypeType &result) {
1134 SMLoc loc = getCurrentLocation();
1135
1136 // Parse any kind of type.
1137 Type type;
1138 if (parseColonType(type))
1139 return failure();
1140
1141 // Check for the right kind of type.
1142 result = type.dyn_cast<TypeType>();
1143 if (!result)
1144 return emitError(loc, "invalid kind of type specified");
1145
1146 return success();
1147 }
1148
1149 /// Parse a colon followed by a type list, which must have at least one type.
1150 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1151
1152 /// Parse an optional colon followed by a type list, which if present must
1153 /// have at least one type.
1154 virtual ParseResult
1155 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1156
1157 /// Parse a keyword followed by a type.
1158 ParseResult parseKeywordType(const char *keyword, Type &result) {
1159 return failure(parseKeyword(keyword) || parseType(result));
1160 }
1161
1162 /// Add the specified type to the end of the specified type list and return
1163 /// success. This is a helper designed to allow parse methods to be simple
1164 /// and chain through || operators.
1165 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1166 result.push_back(type);
1167 return success();
1168 }
1169
1170 /// Add the specified types to the end of the specified type list and return
1171 /// success. This is a helper designed to allow parse methods to be simple
1172 /// and chain through || operators.
1173 ParseResult addTypesToList(ArrayRef<Type> types,
1174 SmallVectorImpl<Type> &result) {
1175 result.append(types.begin(), types.end());
1176 return success();
1177 }
1178
1179 /// Parse a dimension list of a tensor or memref type. This populates the
1180 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1181 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1182 ///
1183 /// dimension-list ::= eps | dimension (`x` dimension)*
1184 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1185 /// dimension ::= `?` | decimal-literal
1186 ///
1187 /// When `allowDynamic` is not set, this is used to parse:
1188 ///
1189 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1190 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
1191 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1192 bool allowDynamic = true,
1193 bool withTrailingX = true) = 0;
1194
1195 /// Parse an 'x' token in a dimension list, handling the case where the x is
1196 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1197 /// next token.
1198 virtual ParseResult parseXInDimensionList() = 0;
1199
1200protected:
1201 /// Parse a handle to a resource within the assembly format for the given
1202 /// dialect.
1203 virtual FailureOr<AsmDialectResourceHandle>
1204 parseResourceHandle(Dialect *dialect) = 0;
1205
1206 //===--------------------------------------------------------------------===//
1207 // Code Completion
1208 //===--------------------------------------------------------------------===//
1209
1210 /// Parse a keyword, or an empty string if the current location signals a code
1211 /// completion.
1212 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1213
1214 /// Signal the code completion of a set of expected tokens.
1215 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1216
1217private:
1218 AsmParser(const AsmParser &) = delete;
1219 void operator=(const AsmParser &) = delete;
1220};
1221
1222//===----------------------------------------------------------------------===//
1223// OpAsmParser
1224//===----------------------------------------------------------------------===//
1225
1226/// The OpAsmParser has methods for interacting with the asm parser: parsing
1227/// things from it, emitting errors etc. It has an intentionally high-level API
1228/// that is designed to reduce/constrain syntax innovation in individual
1229/// operations.
1230///
1231/// For example, consider an op like this:
1232///
1233/// %x = load %p[%1, %2] : memref<...>
1234///
1235/// The "%x = load" tokens are already parsed and therefore invisible to the
1236/// custom op parser. This can be supported by calling `parseOperandList` to
1237/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1238/// parse the indices, then calling `parseColonTypeList` to parse the result
1239/// type.
1240///
1241class OpAsmParser : public AsmParser {
1242public:
1243 using AsmParser::AsmParser;
1244 ~OpAsmParser() override;
1245
1246 /// Parse a loc(...) specifier if present, filling in result if so.
1247 /// Location for BlockArgument and Operation may be deferred with an alias, in
1248 /// which case an OpaqueLoc is set and will be resolved when parsing
1249 /// completes.
1250 virtual ParseResult
1251 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1252
1253 /// Return the name of the specified result in the specified syntax, as well
1254 /// as the sub-element in the name. It returns an empty string and ~0U for
1255 /// invalid result numbers. For example, in this operation:
1256 ///
1257 /// %x, %y:2, %z = foo.op
1258 ///
1259 /// getResultName(0) == {"x", 0 }
1260 /// getResultName(1) == {"y", 0 }
1261 /// getResultName(2) == {"y", 1 }
1262 /// getResultName(3) == {"z", 0 }
1263 /// getResultName(4) == {"", ~0U }
1264 virtual std::pair<StringRef, unsigned>
1265 getResultName(unsigned resultNo) const = 0;
1266
1267 /// Return the number of declared SSA results. This returns 4 for the foo.op
1268 /// example in the comment for `getResultName`.
1269 virtual size_t getNumResults() const = 0;
1270
1271 // These methods emit an error and return failure or success. This allows
1272 // these to be chained together into a linear sequence of || expressions in
1273 // many cases.
1274
1275 /// Parse an operation in its generic form.
1276 /// The parsed operation is parsed in the current context and inserted in the
1277 /// provided block and insertion point. The results produced by this operation
1278 /// aren't mapped to any named value in the parser. Returns nullptr on
1279 /// failure.
1280 virtual Operation *parseGenericOperation(Block *insertBlock,
1281 Block::iterator insertPt) = 0;
1282
1283 /// Parse the name of an operation, in the custom form. On success, return a
1284 /// an object of type 'OperationName'. Otherwise, failure is returned.
1285 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1286
1287 //===--------------------------------------------------------------------===//
1288 // Operand Parsing
1289 //===--------------------------------------------------------------------===//
1290
1291 /// This is the representation of an operand reference.
1292 struct UnresolvedOperand {
1293 SMLoc location; // Location of the token.
1294 StringRef name; // Value name, e.g. %42 or %abc
1295 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1296 };
1297
1298 /// Parse different components, viz., use-info of operand(s), successor(s),
1299 /// region(s), attribute(s) and function-type, of the generic form of an
1300 /// operation instance and populate the input operation-state 'result' with
1301 /// those components. If any of the components is explicitly provided, then
1302 /// skip parsing that component.
1303 virtual ParseResult parseGenericOperationAfterOpName(
1304 OperationState &result,
1305 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = llvm::None,
1306 Optional<ArrayRef<Block *>> parsedSuccessors = llvm::None,
1307 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1308 llvm::None,
1309 Optional<ArrayRef<NamedAttribute>> parsedAttributes = llvm::None,
1310 Optional<FunctionType> parsedFnType = llvm::None) = 0;
1311
1312 /// Parse a single SSA value operand name along with a result number if
1313 /// `allowResultNumber` is true.
1314 virtual ParseResult parseOperand(UnresolvedOperand &result,
1315 bool allowResultNumber = true) = 0;
1316
1317 /// Parse a single operand if present.
1318 virtual OptionalParseResult
1319 parseOptionalOperand(UnresolvedOperand &result,
1320 bool allowResultNumber = true) = 0;
1321
1322 /// Parse zero or more SSA comma-separated operand references with a specified
1323 /// surrounding delimiter, and an optional required operand count.
1324 virtual ParseResult
1325 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1326 Delimiter delimiter = Delimiter::None,
1327 bool allowResultNumber = true,
1328 int requiredOperandCount = -1) = 0;
1329
1330 /// Parse a specified number of comma separated operands.
1331 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1332 int requiredOperandCount,
1333 Delimiter delimiter = Delimiter::None) {
1334 return parseOperandList(result, delimiter,
1335 /*allowResultNumber=*/true, requiredOperandCount);
1336 }
1337
1338 /// Parse zero or more trailing SSA comma-separated trailing operand
1339 /// references with a specified surrounding delimiter, and an optional
1340 /// required operand count. A leading comma is expected before the
1341 /// operands.
1342 ParseResult
1343 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1344 Delimiter delimiter = Delimiter::None) {
1345 if (failed(parseOptionalComma()))
1346 return success(); // The comma is optional.
1347 return parseOperandList(result, delimiter);
1348 }
1349
1350 /// Resolve an operand to an SSA value, emitting an error on failure.
1351 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1352 Type type,
1353 SmallVectorImpl<Value> &result) = 0;
1354
1355 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1356 /// appending the results to the list on success. This method should be used
1357 /// when all operands have the same type.
1358 template <typename Operands = ArrayRef<UnresolvedOperand>>
1359 ParseResult resolveOperands(Operands &&operands, Type type,
1360 SmallVectorImpl<Value> &result) {
1361 for (const UnresolvedOperand &operand : operands)
1362 if (resolveOperand(operand, type, result))
1363 return failure();
1364 return success();
1365 }
1366 template <typename Operands = ArrayRef<UnresolvedOperand>>
1367 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1368 SmallVectorImpl<Value> &result) {
1369 return resolveOperands(std::forward<Operands>(operands), type, result);
1370 }
1371
1372 /// Resolve a list of operands and a list of operand types to SSA values,
1373 /// emitting an error and returning failure, or appending the results
1374 /// to the list on success.
1375 template <typename Operands = ArrayRef<UnresolvedOperand>,
1376 typename Types = ArrayRef<Type>>
1377 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1378 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1379 SmallVectorImpl<Value> &result) {
1380 size_t operandSize = std::distance(operands.begin(), operands.end());
1381 size_t typeSize = std::distance(types.begin(), types.end());
1382 if (operandSize != typeSize)
1383 return emitError(loc)
1384 << operandSize << " operands present, but expected " << typeSize;
1385
1386 for (auto [operand, type] : llvm::zip(operands, types))
1387 if (resolveOperand(operand, type, result))
1388 return failure();
1389 return success();
1390 }
1391
1392 /// Parses an affine map attribute where dims and symbols are SSA operands.
1393 /// Operand values must come from single-result sources, and be valid
1394 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1395 virtual ParseResult
1396 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1397 Attribute &map, StringRef attrName,
1398 NamedAttrList &attrs,
1399 Delimiter delimiter = Delimiter::Square) = 0;
1400
1401 /// Parses an affine expression where dims and symbols are SSA operands.
1402 /// Operand values must come from single-result sources, and be valid
1403 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1404 virtual ParseResult
1405 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1406 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1407 AffineExpr &expr) = 0;
1408
1409 //===--------------------------------------------------------------------===//
1410 // Argument Parsing
1411 //===--------------------------------------------------------------------===//
1412
1413 struct Argument {
1414 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1415 Type type; // Type.
1416 DictionaryAttr attrs; // Attributes if present.
1417 Optional<Location> sourceLoc; // Source location specifier if present.
1418 };
1419
1420 /// Parse a single argument with the following syntax:
1421 ///
1422 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1423 ///
1424 /// If `allowType` is false or `allowAttrs` are false then the respective
1425 /// parts of the grammar are not parsed.
1426 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1427 bool allowAttrs = false) = 0;
1428
1429 /// Parse a single argument if present.
1430 virtual OptionalParseResult
1431 parseOptionalArgument(Argument &result, bool allowType = false,
1432 bool allowAttrs = false) = 0;
1433
1434 /// Parse zero or more arguments with a specified surrounding delimiter.
1435 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1436 Delimiter delimiter = Delimiter::None,
1437 bool allowType = false,
1438 bool allowAttrs = false) = 0;
1439
1440 //===--------------------------------------------------------------------===//
1441 // Region Parsing
1442 //===--------------------------------------------------------------------===//
1443
1444 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1445 /// moved to the op regions after the op is created. The first block of the
1446 /// region takes 'arguments'.
1447 ///
1448 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1449 /// shadow the names of other existing SSA values defined above the region
1450 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1451 /// to operations that are 'IsolatedFromAbove'.
1452 virtual ParseResult parseRegion(Region &region,
1453 ArrayRef<Argument> arguments = {},
1454 bool enableNameShadowing = false) = 0;
1455
1456 /// Parses a region if present.
1457 virtual OptionalParseResult
1458 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1459 bool enableNameShadowing = false) = 0;
1460
1461 /// Parses a region if present. If the region is present, a new region is
1462 /// allocated and placed in `region`. If no region is present or on failure,
1463 /// `region` remains untouched.
1464 virtual OptionalParseResult
1465 parseOptionalRegion(std::unique_ptr<Region> &region,
1466 ArrayRef<Argument> arguments = {},
1467 bool enableNameShadowing = false) = 0;
1468
1469 //===--------------------------------------------------------------------===//
1470 // Successor Parsing
1471 //===--------------------------------------------------------------------===//
1472
1473 /// Parse a single operation successor.
1474 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1475
1476 /// Parse an optional operation successor.
1477 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1478
1479 /// Parse a single operation successor and its operand list.
1480 virtual ParseResult
1481 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1482
1483 //===--------------------------------------------------------------------===//
1484 // Type Parsing
1485 //===--------------------------------------------------------------------===//
1486
1487 /// Parse a list of assignments of the form
1488 /// (%x1 = %y1, %x2 = %y2, ...)
1489 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1490 SmallVectorImpl<UnresolvedOperand> &rhs) {
1491 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1492 if (!result.has_value())
1493 return emitError(getCurrentLocation(), "expected '('");
1494 return result.value();
1495 }
1496
1497 virtual OptionalParseResult
1498 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1499 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1500};
1501
1502//===--------------------------------------------------------------------===//
1503// Dialect OpAsm interface.
1504//===--------------------------------------------------------------------===//
1505
1506/// A functor used to set the name of the start of a result group of an
1507/// operation. See 'getAsmResultNames' below for more details.
1508using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1509
1510/// A functor used to set the name of blocks in regions directly nested under
1511/// an operation.
1512using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1513
1514class OpAsmDialectInterface
1515 : public DialectInterface::Base<OpAsmDialectInterface> {
1516public:
1517 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1518
1519 //===------------------------------------------------------------------===//
1520 // Aliases
1521 //===------------------------------------------------------------------===//
1522
1523 /// Holds the result of `getAlias` hook call.
1524 enum class AliasResult {
1525 /// The object (type or attribute) is not supported by the hook
1526 /// and an alias was not provided.
1527 NoAlias,
1528 /// An alias was provided, but it might be overriden by other hook.
1529 OverridableAlias,
1530 /// An alias was provided and it should be used
1531 /// (no other hooks will be checked).
1532 FinalAlias
1533 };
1534
1535 /// Hooks for getting an alias identifier alias for a given symbol, that is
1536 /// not necessarily a part of this dialect. The identifier is used in place of
1537 /// the symbol when printing textual IR. These aliases must not contain `.` or
1538 /// end with a numeric digit([0-9]+).
1539 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1540 return AliasResult::NoAlias;
1541 }
1542 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1543 return AliasResult::NoAlias;
1544 }
1545
1546 //===--------------------------------------------------------------------===//
1547 // Resources
1548 //===--------------------------------------------------------------------===//
1549
1550 /// Declare a resource with the given key, returning a handle to use for any
1551 /// references of this resource key within the IR during parsing. The result
1552 /// of `getResourceKey` on the returned handle is permitted to be different
1553 /// than `key`.
1554 virtual FailureOr<AsmDialectResourceHandle>
1555 declareResource(StringRef key) const {
1556 return failure();
1557 }
1558
1559 /// Return a key to use for the given resource. This key should uniquely
1560 /// identify this resource within the dialect.
1561 virtual std::string
1562 getResourceKey(const AsmDialectResourceHandle &handle) const {
1563 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1564)
1564 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1564)
;
1565 }
1566
1567 /// Hook for parsing resource entries. Returns failure if the entry was not
1568 /// valid, or could otherwise not be processed correctly. Any necessary errors
1569 /// can be emitted via the provided entry.
1570 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1571
1572 /// Hook for building resources to use during printing. The given `op` may be
1573 /// inspected to help determine what information to include.
1574 /// `referencedResources` contains all of the resources detected when printing
1575 /// 'op'.
1576 virtual void
1577 buildResources(Operation *op,
1578 const SetVector<AsmDialectResourceHandle> &referencedResources,
1579 AsmResourceBuilder &builder) const {}
1580};
1581} // namespace mlir
1582
1583//===--------------------------------------------------------------------===//
1584// Operation OpAsm interface.
1585//===--------------------------------------------------------------------===//
1586
1587/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1588#include "mlir/IR/OpAsmInterface.h.inc"
1589
1590namespace llvm {
1591template <>
1592struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1593 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1594 return {DenseMapInfo<void *>::getEmptyKey(),
1595 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1596 }
1597 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1598 return {DenseMapInfo<void *>::getTombstoneKey(),
1599 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1600 }
1601 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1602 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1603 }
1604 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1605 const mlir::AsmDialectResourceHandle &rhs) {
1606 return lhs.getResource() == rhs.getResource();
1607 }
1608};
1609} // namespace llvm
1610
1611#endif