Bug Summary

File: build/source/mlir/lib/IR/BuiltinAttributes.cpp
Warning: line 819, column 9
1st function call argument is an uninitialized value

Annotated Source Code

Press '?' to see keyboard shortcuts

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name BuiltinAttributes.cpp -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-16/lib/clang/16 -I tools/mlir/lib/IR -I /build/source/mlir/lib/IR -I include -I /build/source/llvm/include -I /build/source/mlir/include -I tools/mlir/include -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-16/lib/clang/16/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu 
-internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/source/= -fcoverage-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/source/= -source-date-epoch 1671487667 -O2 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -Wno-misleading-indentation -std=c++17 -fdeprecated-macro -fdebug-compilation-dir=/build/source/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/source/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 -fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-12-20-010714-16201-1 -x c++ /build/source/mlir/lib/IR/BuiltinAttributes.cpp

/build/source/mlir/lib/IR/BuiltinAttributes.cpp

1//===- BuiltinAttributes.cpp - MLIR Builtin Attribute Classes -------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "mlir/IR/BuiltinAttributes.h"
10#include "AttributeDetail.h"
11#include "mlir/IR/AffineMap.h"
12#include "mlir/IR/BuiltinDialect.h"
13#include "mlir/IR/Dialect.h"
14#include "mlir/IR/DialectResourceBlobManager.h"
15#include "mlir/IR/IntegerSet.h"
16#include "mlir/IR/OpImplementation.h"
17#include "mlir/IR/Operation.h"
18#include "mlir/IR/SymbolTable.h"
19#include "mlir/IR/Types.h"
20#include "llvm/ADT/APSInt.h"
21#include "llvm/ADT/Sequence.h"
22#include "llvm/ADT/TypeSwitch.h"
23#include "llvm/Support/Endian.h"
24
25using namespace mlir;
26using namespace mlir::detail;
27
28//===----------------------------------------------------------------------===//
29/// Tablegen Attribute Definitions
30//===----------------------------------------------------------------------===//
31
32#define GET_ATTRDEF_CLASSES
33#include "mlir/IR/BuiltinAttributes.cpp.inc"
34
35//===----------------------------------------------------------------------===//
36// BuiltinDialect
37//===----------------------------------------------------------------------===//
38
// Registers every attribute class generated by TableGen for the builtin
// dialect. Defining GET_ATTRDEF_LIST right before including the generated
// BuiltinAttributes.cpp.inc makes that include expand inside the template
// argument list of addAttributes<> (presumably to a comma-separated list of
// the generated attribute class names -- confirm against the .inc file).
// The #define/#include pair must stay exactly where it is for this to work.
39void BuiltinDialect::registerAttributes() {
40 addAttributes<
41#define GET_ATTRDEF_LIST
42#include "mlir/IR/BuiltinAttributes.cpp.inc"
43 >();
44}
45
46//===----------------------------------------------------------------------===//
47// DictionaryAttr
48//===----------------------------------------------------------------------===//
49
50/// Helper function that does either an in place sort or sorts from source array
51/// into destination. If inPlace then storage is both the source and the
52/// destination, else value is the source and storage destination. Returns
53/// whether source was sorted.
54template <bool inPlace>
55static bool dictionaryAttrSort(ArrayRef<NamedAttribute> value,
56 SmallVectorImpl<NamedAttribute> &storage) {
57 // Specialize for the common case.
58 switch (value.size()) {
59 case 0:
60 // Zero already sorted.
61 if (!inPlace)
62 storage.clear();
63 break;
64 case 1:
65 // One already sorted but may need to be copied.
66 if (!inPlace)
67 storage.assign({value[0]});
68 break;
69 case 2: {
70 bool isSorted = value[0] < value[1];
71 if (inPlace) {
72 if (!isSorted)
73 std::swap(storage[0], storage[1]);
74 } else if (isSorted) {
75 storage.assign({value[0], value[1]});
76 } else {
77 storage.assign({value[1], value[0]});
78 }
79 return !isSorted;
80 }
81 default:
82 if (!inPlace)
83 storage.assign(value.begin(), value.end());
84 // Check to see they are sorted already.
85 bool isSorted = llvm::is_sorted(value);
86 // If not, do a general sort.
87 if (!isSorted)
88 llvm::array_pod_sort(storage.begin(), storage.end());
89 return !isSorted;
90 }
91 return false;
92}
93
94/// Returns an entry with a duplicate name from the given sorted array of named
95/// attributes. Returns std::nullopt if all elements have unique names.
96static Optional<NamedAttribute>
97findDuplicateElement(ArrayRef<NamedAttribute> value) {
98 const Optional<NamedAttribute> none{std::nullopt};
99 if (value.size() < 2)
100 return none;
101
102 if (value.size() == 2)
103 return value[0].getName() == value[1].getName() ? value[0] : none;
104
105 const auto *it = std::adjacent_find(value.begin(), value.end(),
106 [](NamedAttribute l, NamedAttribute r) {
107 return l.getName() == r.getName();
108 });
109 return it != value.end() ? *it : none;
110}
111
112bool DictionaryAttr::sort(ArrayRef<NamedAttribute> value,
113 SmallVectorImpl<NamedAttribute> &storage) {
114 bool isSorted = dictionaryAttrSort</*inPlace=*/false>(value, storage);
115 assert(!findDuplicateElement(storage) &&(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 116, __extension__ __PRETTY_FUNCTION__
))
116 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(storage) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(storage) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 116, __extension__ __PRETTY_FUNCTION__
))
;
117 return isSorted;
118}
119
120bool DictionaryAttr::sortInPlace(SmallVectorImpl<NamedAttribute> &array) {
121 bool isSorted = dictionaryAttrSort</*inPlace=*/true>(array, array);
122 assert(!findDuplicateElement(array) &&(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 123, __extension__ __PRETTY_FUNCTION__
))
123 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(array) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(array) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 123, __extension__ __PRETTY_FUNCTION__
))
;
124 return isSorted;
125}
126
127Optional<NamedAttribute>
128DictionaryAttr::findDuplicate(SmallVectorImpl<NamedAttribute> &array,
129 bool isSorted) {
130 if (!isSorted)
131 dictionaryAttrSort</*inPlace=*/true>(array, array);
132 return findDuplicateElement(array);
133}
134
135DictionaryAttr DictionaryAttr::get(MLIRContext *context,
136 ArrayRef<NamedAttribute> value) {
137 if (value.empty())
138 return DictionaryAttr::getEmpty(context);
139
140 // We need to sort the element list to canonicalize it.
141 SmallVector<NamedAttribute, 8> storage;
142 if (dictionaryAttrSort</*inPlace=*/false>(value, storage))
143 value = storage;
144 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 145, __extension__ __PRETTY_FUNCTION__
))
145 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 145, __extension__ __PRETTY_FUNCTION__
))
;
146 return Base::get(context, value);
147}
148/// Construct a dictionary with an array of values that is known to already be
149/// sorted by name and uniqued.
150DictionaryAttr DictionaryAttr::getWithSorted(MLIRContext *context,
151 ArrayRef<NamedAttribute> value) {
152 if (value.empty())
153 return DictionaryAttr::getEmpty(context);
154 // Ensure that the attribute elements are unique and sorted.
155 assert(llvm::is_sorted((static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__
))
156 value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) &&(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__
))
157 "expected attribute values to be sorted")(static_cast <bool> (llvm::is_sorted( value, [](NamedAttribute
l, NamedAttribute r) { return l < r; }) && "expected attribute values to be sorted"
) ? void (0) : __assert_fail ("llvm::is_sorted( value, [](NamedAttribute l, NamedAttribute r) { return l < r; }) && \"expected attribute values to be sorted\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 157, __extension__ __PRETTY_FUNCTION__
))
;
158 assert(!findDuplicateElement(value) &&(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 159, __extension__ __PRETTY_FUNCTION__
))
159 "DictionaryAttr element names must be unique")(static_cast <bool> (!findDuplicateElement(value) &&
"DictionaryAttr element names must be unique") ? void (0) : __assert_fail
("!findDuplicateElement(value) && \"DictionaryAttr element names must be unique\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 159, __extension__ __PRETTY_FUNCTION__
))
;
160 return Base::get(context, value);
161}
162
163/// Return the specified attribute if present, null otherwise.
164Attribute DictionaryAttr::get(StringRef name) const {
165 auto it = impl::findAttrSorted(begin(), end(), name);
166 return it.second ? it.first->getValue() : Attribute();
167}
168Attribute DictionaryAttr::get(StringAttr name) const {
169 auto it = impl::findAttrSorted(begin(), end(), name);
170 return it.second ? it.first->getValue() : Attribute();
171}
172
173/// Return the specified named attribute if present, std::nullopt otherwise.
174Optional<NamedAttribute> DictionaryAttr::getNamed(StringRef name) const {
175 auto it = impl::findAttrSorted(begin(), end(), name);
176 return it.second ? *it.first : Optional<NamedAttribute>();
177}
178Optional<NamedAttribute> DictionaryAttr::getNamed(StringAttr name) const {
179 auto it = impl::findAttrSorted(begin(), end(), name);
180 return it.second ? *it.first : Optional<NamedAttribute>();
181}
182
183/// Return whether the specified attribute is present.
184bool DictionaryAttr::contains(StringRef name) const {
185 return impl::findAttrSorted(begin(), end(), name).second;
186}
187bool DictionaryAttr::contains(StringAttr name) const {
188 return impl::findAttrSorted(begin(), end(), name).second;
189}
190
191DictionaryAttr::iterator DictionaryAttr::begin() const {
192 return getValue().begin();
193}
194DictionaryAttr::iterator DictionaryAttr::end() const {
195 return getValue().end();
196}
197size_t DictionaryAttr::size() const { return getValue().size(); }
198
199DictionaryAttr DictionaryAttr::getEmptyUnchecked(MLIRContext *context) {
200 return Base::get(context, ArrayRef<NamedAttribute>());
201}
202
203//===----------------------------------------------------------------------===//
204// StridedLayoutAttr
205//===----------------------------------------------------------------------===//
206
207/// Prints a strided layout attribute.
208void StridedLayoutAttr::print(llvm::raw_ostream &os) const {
209 auto printIntOrQuestion = [&](int64_t value) {
210 if (ShapedType::isDynamic(value))
211 os << "?";
212 else
213 os << value;
214 };
215
216 os << "strided<[";
217 llvm::interleaveComma(getStrides(), os, printIntOrQuestion);
218 os << "]";
219
220 if (getOffset() != 0) {
221 os << ", offset: ";
222 printIntOrQuestion(getOffset());
223 }
224 os << ">";
225}
226
227/// Returns the strided layout as an affine map.
228AffineMap StridedLayoutAttr::getAffineMap() const {
229 return makeStridedLinearLayoutMap(getStrides(), getOffset(), getContext());
230}
231
232/// Checks that the type-agnostic strided layout invariants are satisfied.
233LogicalResult
234StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
235 int64_t offset, ArrayRef<int64_t> strides) {
236 if (llvm::any_of(strides, [&](int64_t stride) { return stride == 0; }))
237 return emitError() << "strides must not be zero";
238
239 return success();
240}
241
242/// Checks that the type-specific strided layout invariants are satisfied.
243LogicalResult StridedLayoutAttr::verifyLayout(
244 ArrayRef<int64_t> shape,
245 function_ref<InFlightDiagnostic()> emitError) const {
246 if (shape.size() != getStrides().size())
247 return emitError() << "expected the number of strides to match the rank";
248
249 return success();
250}
251
252//===----------------------------------------------------------------------===//
253// StringAttr
254//===----------------------------------------------------------------------===//
255
256StringAttr StringAttr::getEmptyStringAttrUnchecked(MLIRContext *context) {
257 return Base::get(context, "", NoneType::get(context));
258}
259
260/// Twine support for StringAttr.
261StringAttr StringAttr::get(MLIRContext *context, const Twine &twine) {
262 // Fast-path empty twine.
263 if (twine.isTriviallyEmpty())
264 return get(context);
265 SmallVector<char, 32> tempStr;
266 return Base::get(context, twine.toStringRef(tempStr), NoneType::get(context));
267}
268
269/// Twine support for StringAttr.
270StringAttr StringAttr::get(const Twine &twine, Type type) {
271 SmallVector<char, 32> tempStr;
272 return Base::get(type.getContext(), twine.toStringRef(tempStr), type);
273}
274
275StringRef StringAttr::getValue() const { return getImpl()->value; }
276
277Type StringAttr::getType() const { return getImpl()->type; }
278
279Dialect *StringAttr::getReferencedDialect() const {
280 return getImpl()->referencedDialect;
281}
282
283//===----------------------------------------------------------------------===//
284// FloatAttr
285//===----------------------------------------------------------------------===//
286
287double FloatAttr::getValueAsDouble() const {
288 return getValueAsDouble(getValue());
289}
290double FloatAttr::getValueAsDouble(APFloat value) {
291 if (&value.getSemantics() != &APFloat::IEEEdouble()) {
292 bool losesInfo = false;
293 value.convert(APFloat::IEEEdouble(), APFloat::rmNearestTiesToEven,
294 &losesInfo);
295 }
296 return value.convertToDouble();
297}
298
299LogicalResult FloatAttr::verify(function_ref<InFlightDiagnostic()> emitError,
300 Type type, APFloat value) {
301 // Verify that the type is correct.
302 if (!type.isa<FloatType>())
303 return emitError() << "expected floating point type";
304
305 // Verify that the type semantics match that of the value.
306 if (&type.cast<FloatType>().getFloatSemantics() != &value.getSemantics()) {
307 return emitError()
308 << "FloatAttr type doesn't match the type implied by its value";
309 }
310 return success();
311}
312
313//===----------------------------------------------------------------------===//
314// SymbolRefAttr
315//===----------------------------------------------------------------------===//
316
317SymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value,
318 ArrayRef<FlatSymbolRefAttr> nestedRefs) {
319 return get(StringAttr::get(ctx, value), nestedRefs);
320}
321
322FlatSymbolRefAttr SymbolRefAttr::get(MLIRContext *ctx, StringRef value) {
323 return get(ctx, value, {}).cast<FlatSymbolRefAttr>();
324}
325
326FlatSymbolRefAttr SymbolRefAttr::get(StringAttr value) {
327 return get(value, {}).cast<FlatSymbolRefAttr>();
328}
329
330FlatSymbolRefAttr SymbolRefAttr::get(Operation *symbol) {
331 auto symName =
332 symbol->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
333 assert(symName && "value does not have a valid symbol name")(static_cast <bool> (symName && "value does not have a valid symbol name"
) ? void (0) : __assert_fail ("symName && \"value does not have a valid symbol name\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 333, __extension__ __PRETTY_FUNCTION__
))
;
334 return SymbolRefAttr::get(symName);
335}
336
337StringAttr SymbolRefAttr::getLeafReference() const {
338 ArrayRef<FlatSymbolRefAttr> nestedRefs = getNestedReferences();
339 return nestedRefs.empty() ? getRootReference() : nestedRefs.back().getAttr();
340}
341
342//===----------------------------------------------------------------------===//
343// IntegerAttr
344//===----------------------------------------------------------------------===//
345
346int64_t IntegerAttr::getInt() const {
347 assert((getType().isIndex() || getType().isSignlessInteger()) &&(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 348, __extension__ __PRETTY_FUNCTION__
))
348 "must be signless integer")(static_cast <bool> ((getType().isIndex() || getType().
isSignlessInteger()) && "must be signless integer") ?
void (0) : __assert_fail ("(getType().isIndex() || getType().isSignlessInteger()) && \"must be signless integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 348, __extension__ __PRETTY_FUNCTION__
))
;
349 return getValue().getSExtValue();
350}
351
352int64_t IntegerAttr::getSInt() const {
353 assert(getType().isSignedInteger() && "must be signed integer")(static_cast <bool> (getType().isSignedInteger() &&
"must be signed integer") ? void (0) : __assert_fail ("getType().isSignedInteger() && \"must be signed integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 353, __extension__ __PRETTY_FUNCTION__
))
;
354 return getValue().getSExtValue();
355}
356
357uint64_t IntegerAttr::getUInt() const {
358 assert(getType().isUnsignedInteger() && "must be unsigned integer")(static_cast <bool> (getType().isUnsignedInteger() &&
"must be unsigned integer") ? void (0) : __assert_fail ("getType().isUnsignedInteger() && \"must be unsigned integer\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 358, __extension__ __PRETTY_FUNCTION__
))
;
359 return getValue().getZExtValue();
360}
361
362/// Return the value as an APSInt which carries the signed from the type of
363/// the attribute. This traps on signless integers types!
364APSInt IntegerAttr::getAPSInt() const {
365 assert(!getType().isSignlessInteger() &&(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 366, __extension__ __PRETTY_FUNCTION__
))
366 "Signless integers don't carry a sign for APSInt")(static_cast <bool> (!getType().isSignlessInteger() &&
"Signless integers don't carry a sign for APSInt") ? void (0
) : __assert_fail ("!getType().isSignlessInteger() && \"Signless integers don't carry a sign for APSInt\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 366, __extension__ __PRETTY_FUNCTION__
))
;
367 return APSInt(getValue(), getType().isUnsignedInteger());
368}
369
370LogicalResult IntegerAttr::verify(function_ref<InFlightDiagnostic()> emitError,
371 Type type, APInt value) {
372 if (IntegerType integerType = type.dyn_cast<IntegerType>()) {
373 if (integerType.getWidth() != value.getBitWidth())
374 return emitError() << "integer type bit width (" << integerType.getWidth()
375 << ") doesn't match value bit width ("
376 << value.getBitWidth() << ")";
377 return success();
378 }
379 if (type.isa<IndexType>()) {
380 if (value.getBitWidth() != IndexType::kInternalStorageBitWidth)
381 return emitError()
382 << "value bit width (" << value.getBitWidth()
383 << ") doesn't match index type internal storage bit width ("
384 << IndexType::kInternalStorageBitWidth << ")";
385 return success();
386 }
387 return emitError() << "expected integer or index type";
388}
389
390BoolAttr IntegerAttr::getBoolAttrUnchecked(IntegerType type, bool value) {
391 auto attr = Base::get(type.getContext(), type, APInt(/*numBits=*/1, value));
392 return attr.cast<BoolAttr>();
393}
394
395//===----------------------------------------------------------------------===//
396// BoolAttr
397//===----------------------------------------------------------------------===//
398
399bool BoolAttr::getValue() const {
400 auto *storage = reinterpret_cast<IntegerAttrStorage *>(impl);
401 return storage->value.getBoolValue();
402}
403
404bool BoolAttr::classof(Attribute attr) {
405 IntegerAttr intAttr = attr.dyn_cast<IntegerAttr>();
406 return intAttr && intAttr.getType().isSignlessInteger(1);
407}
408
409//===----------------------------------------------------------------------===//
410// OpaqueAttr
411//===----------------------------------------------------------------------===//
412
413LogicalResult OpaqueAttr::verify(function_ref<InFlightDiagnostic()> emitError,
414 StringAttr dialect, StringRef attrData,
415 Type type) {
416 if (!Dialect::isValidNamespace(dialect.strref()))
417 return emitError() << "invalid dialect namespace '" << dialect << "'";
418
419 // Check that the dialect is actually registered.
420 MLIRContext *context = dialect.getContext();
421 if (!context->allowsUnregisteredDialects() &&
422 !context->getLoadedDialect(dialect.strref())) {
423 return emitError()
424 << "#" << dialect << "<\"" << attrData << "\"> : " << type
425 << " attribute created with unregistered dialect. If this is "
426 "intended, please call allowUnregisteredDialects() on the "
427 "MLIRContext, or use -allow-unregistered-dialect with "
428 "the MLIR opt tool used";
429 }
430
431 return success();
432}
433
434//===----------------------------------------------------------------------===//
435// DenseElementsAttr Utilities
436//===----------------------------------------------------------------------===//
437
438/// Get the bitwidth of a dense element type within the buffer.
439/// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
440static size_t getDenseElementStorageWidth(size_t origWidth) {
441 return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
442}
443static size_t getDenseElementStorageWidth(Type elementType) {
444 return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
445}
446
/// Sets (when `value` is true) or clears (when false) bit `bitPos` of the byte
/// array `rawData`. Bit `bitPos` lives in byte `bitPos / CHAR_BIT`, at bit
/// `bitPos % CHAR_BIT` counted from that byte's least-significant bit.
static void setBit(char *rawData, size_t bitPos, bool value) {
  char &byte = rawData[bitPos / CHAR_BIT];
  const int mask = 1 << (bitPos % CHAR_BIT);
  if (value)
    byte |= mask;
  else
    byte &= ~mask;
}
454
/// Reports whether bit `bitPos` of the byte array `rawData` is set, using the
/// same numbering as setBit: byte `bitPos / CHAR_BIT`, bit `bitPos % CHAR_BIT`
/// from the least-significant end.
static bool getBit(const char *rawData, size_t bitPos) {
  const int mask = 1 << (bitPos % CHAR_BIT);
  const char byte = rawData[bitPos / CHAR_BIT];
  return (byte & mask) != 0;
}
459
460/// Copy actual `numBytes` data from `value` (APInt) to char array(`result`) for
461/// BE format.
462static void copyAPIntToArrayForBEmachine(APInt value, size_t numBytes,
463 char *result) {
464 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 465, __extension__ __PRETTY_FUNCTION__
))
465 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 465, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
466 assert(value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (value.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes) ? void (0) : __assert_fail ("value.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes"
, "mlir/lib/IR/BuiltinAttributes.cpp", 466, __extension__ __PRETTY_FUNCTION__
))
;
467
468 // Copy the words filled with data.
469 // For example, when `value` has 2 words, the first word is filled with data.
470 // `value` (10 bytes, BE):|abcdefgh|------ij| ==> `result` (BE):|abcdefgh|--|
471 size_t numFilledWords = (value.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
472 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
473 numFilledWords, result);
474 // Convert last word of APInt to LE format and store it in char
475 // array(`valueLE`).
476 // ex. last word of `value` (BE): |------ij| ==> `valueLE` (LE): |ji------|
477 size_t lastWordPos = numFilledWords;
478 SmallVector<char, 8> valueLE(APInt::APINT_WORD_SIZE);
479 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
480 reinterpret_cast<const char *>(value.getRawData()) + lastWordPos,
481 valueLE.begin(), APInt::APINT_BITS_PER_WORD, 1);
482 // Extract actual APInt data from `valueLE`, convert endianness to BE format,
483 // and store it in `result`.
484 // ex. `valueLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|ij|
485 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
486 valueLE.begin(), result + lastWordPos,
487 (numBytes - lastWordPos) * CHAR_BIT8, 1);
488}
489
490/// Copy `numBytes` data from `inArray`(char array) to `result`(APINT) for BE
491/// format.
492static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes,
493 APInt &result) {
494 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 495, __extension__ __PRETTY_FUNCTION__
))
495 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 495, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
496 assert(result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes)(static_cast <bool> (result.getNumWords() * APInt::APINT_WORD_SIZE
>= numBytes) ? void (0) : __assert_fail ("result.getNumWords() * APInt::APINT_WORD_SIZE >= numBytes"
, "mlir/lib/IR/BuiltinAttributes.cpp", 496, __extension__ __PRETTY_FUNCTION__
))
;
497
498 // Copy the data that fills the word of `result` from `inArray`.
499 // For example, when `result` has 2 words, the first word will be filled with
500 // data. So, the first 8 bytes are copied from `inArray` here.
501 // `inArray` (10 bytes, BE): |abcdefgh|ij|
502 // ==> `result` (2 words, BE): |abcdefgh|--------|
503 size_t numFilledWords = (result.getNumWords() - 1) * APInt::APINT_WORD_SIZE;
504 std::copy_n(
505 inArray, numFilledWords,
506 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())));
507
508 // Convert array data which will be last word of `result` to LE format, and
509 // store it in char array(`inArrayLE`).
510 // ex. `inArray` (last two bytes, BE): |ij| ==> `inArrayLE` (LE): |ji------|
511 size_t lastWordPos = numFilledWords;
512 SmallVector<char, 8> inArrayLE(APInt::APINT_WORD_SIZE);
513 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
514 inArray + lastWordPos, inArrayLE.begin(),
515 (numBytes - lastWordPos) * CHAR_BIT8, 1);
516
517 // Convert `inArrayLE` to BE format, and store it in last word of `result`.
518 // ex. `inArrayLE` (LE): |ji------| ==> `result` (BE): |abcdefgh|------ij|
519 DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
520 inArrayLE.begin(),
521 const_cast<char *>(reinterpret_cast<const char *>(result.getRawData())) +
522 lastWordPos,
523 APInt::APINT_BITS_PER_WORD, 1);
524}
525
526/// Writes value to the bit position `bitPos` in array `rawData`.
527static void writeBits(char *rawData, size_t bitPos, APInt value) {
528 size_t bitWidth = value.getBitWidth();
529
530 // If the bitwidth is 1 we just toggle the specific bit.
531 if (bitWidth == 1)
532 return setBit(rawData, bitPos, value.isOneValue());
533
534 // Otherwise, the bit position is guaranteed to be byte aligned.
535 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 535, __extension__ __PRETTY_FUNCTION__
))
;
536 if (llvm::support::endian::system_endianness() ==
537 llvm::support::endianness::big) {
538 // Copy from `value` to `rawData + (bitPos / CHAR_BIT)`.
539 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
540 // work correctly in BE format.
541 // ex. `value` (2 words including 10 bytes)
542 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------|
543 copyAPIntToArrayForBEmachine(value, llvm::divideCeil(bitWidth, CHAR_BIT8),
544 rawData + (bitPos / CHAR_BIT8));
545 } else {
546 std::copy_n(reinterpret_cast<const char *>(value.getRawData()),
547 llvm::divideCeil(bitWidth, CHAR_BIT8),
548 rawData + (bitPos / CHAR_BIT8));
549 }
550}
551
552/// Reads the next `bitWidth` bits from the bit position `bitPos` in array
553/// `rawData`.
554static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) {
555 // Handle a boolean bit position.
556 if (bitWidth == 1)
557 return APInt(1, getBit(rawData, bitPos) ? 1 : 0);
558
559 // Otherwise, the bit position must be 8-bit aligned.
560 assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned")(static_cast <bool> ((bitPos % 8) == 0 && "expected bitPos to be 8-bit aligned"
) ? void (0) : __assert_fail ("(bitPos % CHAR_BIT) == 0 && \"expected bitPos to be 8-bit aligned\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 560, __extension__ __PRETTY_FUNCTION__
))
;
561 APInt result(bitWidth, 0);
562 if (llvm::support::endian::system_endianness() ==
563 llvm::support::endianness::big) {
564 // Copy from `rawData + (bitPos / CHAR_BIT)` to `result`.
565 // Copying the first `llvm::divideCeil(bitWidth, CHAR_BIT)` bytes doesn't
566 // work correctly in BE format.
567 // ex. `result` (2 words including 10 bytes)
568 // ==> BE: |abcdefgh|------ij|, LE: |hgfedcba|ji------| This function
569 copyArrayToAPIntForBEmachine(rawData + (bitPos / CHAR_BIT8),
570 llvm::divideCeil(bitWidth, CHAR_BIT8), result);
571 } else {
572 std::copy_n(rawData + (bitPos / CHAR_BIT8),
573 llvm::divideCeil(bitWidth, CHAR_BIT8),
574 const_cast<char *>(
575 reinterpret_cast<const char *>(result.getRawData())));
576 }
577 return result;
578}
579
580/// Returns true if 'values' corresponds to a splat, i.e. one element, or has
581/// the same element count as 'type'.
582template <typename Values>
583static bool hasSameElementsOrSplat(ShapedType type, const Values &values) {
584 return (values.size() == 1) ||
585 (type.getNumElements() == static_cast<int64_t>(values.size()));
586}
587
588//===----------------------------------------------------------------------===//
589// DenseElementsAttr Iterators
590//===----------------------------------------------------------------------===//
591
592//===----------------------------------------------------------------------===//
593// AttributeElementIterator
594
595DenseElementsAttr::AttributeElementIterator::AttributeElementIterator(
596 DenseElementsAttr attr, size_t index)
597 : llvm::indexed_accessor_iterator<AttributeElementIterator, const void *,
598 Attribute, Attribute, Attribute>(
599 attr.getAsOpaquePointer(), index) {}
600
601Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
602 auto owner = getFromOpaquePointer(base).cast<DenseElementsAttr>();
603 Type eltTy = owner.getElementType();
604 if (auto intEltTy = eltTy.dyn_cast<IntegerType>())
605 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
606 if (eltTy.isa<IndexType>())
607 return IntegerAttr::get(eltTy, *IntElementIterator(owner, index));
608 if (auto floatEltTy = eltTy.dyn_cast<FloatType>()) {
609 IntElementIterator intIt(owner, index);
610 FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt);
611 return FloatAttr::get(eltTy, *floatIt);
612 }
613 if (auto complexTy = eltTy.dyn_cast<ComplexType>()) {
614 auto complexEltTy = complexTy.getElementType();
615 ComplexIntElementIterator complexIntIt(owner, index);
616 if (complexEltTy.isa<IntegerType>()) {
617 auto value = *complexIntIt;
618 auto real = IntegerAttr::get(complexEltTy, value.real());
619 auto imag = IntegerAttr::get(complexEltTy, value.imag());
620 return ArrayAttr::get(complexTy.getContext(),
621 ArrayRef<Attribute>{real, imag});
622 }
623
624 ComplexFloatElementIterator complexFloatIt(
625 complexEltTy.cast<FloatType>().getFloatSemantics(), complexIntIt);
626 auto value = *complexFloatIt;
627 auto real = FloatAttr::get(complexEltTy, value.real());
628 auto imag = FloatAttr::get(complexEltTy, value.imag());
629 return ArrayAttr::get(complexTy.getContext(),
630 ArrayRef<Attribute>{real, imag});
631 }
632 if (owner.isa<DenseStringElementsAttr>()) {
633 ArrayRef<StringRef> vals = owner.getRawStringData();
634 return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy);
635 }
636 llvm_unreachable("unexpected element type")::llvm::llvm_unreachable_internal("unexpected element type", "mlir/lib/IR/BuiltinAttributes.cpp"
, 636)
;
637}
638
639//===----------------------------------------------------------------------===//
640// BoolElementIterator
641
642DenseElementsAttr::BoolElementIterator::BoolElementIterator(
643 DenseElementsAttr attr, size_t dataIndex)
644 : DenseElementIndexedIteratorImpl<BoolElementIterator, bool, bool, bool>(
645 attr.getRawData().data(), attr.isSplat(), dataIndex) {}
646
647bool DenseElementsAttr::BoolElementIterator::operator*() const {
648 return getBit(getData(), getDataIndex());
649}
650
651//===----------------------------------------------------------------------===//
652// IntElementIterator
653
654DenseElementsAttr::IntElementIterator::IntElementIterator(
655 DenseElementsAttr attr, size_t dataIndex)
656 : DenseElementIndexedIteratorImpl<IntElementIterator, APInt, APInt, APInt>(
657 attr.getRawData().data(), attr.isSplat(), dataIndex),
658 bitWidth(getDenseElementBitWidth(attr.getElementType())) {}
659
660APInt DenseElementsAttr::IntElementIterator::operator*() const {
661 return readBits(getData(),
662 getDataIndex() * getDenseElementStorageWidth(bitWidth),
663 bitWidth);
664}
665
666//===----------------------------------------------------------------------===//
667// ComplexIntElementIterator
668
669DenseElementsAttr::ComplexIntElementIterator::ComplexIntElementIterator(
670 DenseElementsAttr attr, size_t dataIndex)
671 : DenseElementIndexedIteratorImpl<ComplexIntElementIterator,
672 std::complex<APInt>, std::complex<APInt>,
673 std::complex<APInt>>(
674 attr.getRawData().data(), attr.isSplat(), dataIndex) {
675 auto complexType = attr.getElementType().cast<ComplexType>();
676 bitWidth = getDenseElementBitWidth(complexType.getElementType());
677}
678
679std::complex<APInt>
680DenseElementsAttr::ComplexIntElementIterator::operator*() const {
681 size_t storageWidth = getDenseElementStorageWidth(bitWidth);
682 size_t offset = getDataIndex() * storageWidth * 2;
683 return {readBits(getData(), offset, bitWidth),
684 readBits(getData(), offset + storageWidth, bitWidth)};
685}
686
687//===----------------------------------------------------------------------===//
688// DenseArrayAttr
689//===----------------------------------------------------------------------===//
690
691LogicalResult
692DenseArrayAttr::verify(function_ref<InFlightDiagnostic()> emitError,
693 Type elementType, int64_t size, ArrayRef<char> rawData) {
694 if (!elementType.isIntOrIndexOrFloat())
695 return emitError() << "expected integer or floating point element type";
696 int64_t dataSize = rawData.size();
697 int64_t elementSize =
698 llvm::divideCeil(elementType.getIntOrFloatBitWidth(), CHAR_BIT8);
699 if (size * elementSize != dataSize) {
700 return emitError() << "expected data size (" << size << " elements, "
701 << elementSize
702 << " bytes each) does not match: " << dataSize
703 << " bytes";
704 }
705 return success();
706}
707
708namespace {
709/// Instantiations of this class provide utilities for interacting with native
710/// data types in the context of DenseArrayAttr.
711template <size_t width,
712 IntegerType::SignednessSemantics signedness = IntegerType::Signless>
713struct DenseArrayAttrIntUtil {
714 static bool checkElementType(Type eltType) {
715 auto type = eltType.dyn_cast<IntegerType>();
716 if (!type || type.getWidth() != width)
717 return false;
718 return type.getSignedness() == signedness;
719 }
720
721 static Type getElementType(MLIRContext *ctx) {
722 return IntegerType::get(ctx, width, signedness);
723 }
724
725 template <typename T>
726 static void printElement(raw_ostream &os, T value) {
727 os << value;
728 }
729
730 template <typename T>
731 static ParseResult parseElement(AsmParser &parser, T &value) {
732 return parser.parseInteger(value);
3
Calling 'AsmParser::parseInteger'
11
Returning from 'AsmParser::parseInteger'
12
Returning without writing to 'value'
733 }
734};
735template <typename T>
736struct DenseArrayAttrUtil;
737
738/// Specialization for boolean elements to print 'true' and 'false' literals for
739/// elements.
740template <>
741struct DenseArrayAttrUtil<bool> : public DenseArrayAttrIntUtil<1> {
742 static void printElement(raw_ostream &os, bool value) {
743 os << (value ? "true" : "false");
744 }
745};
746
747/// Specialization for 8-bit integers to ensure values are printed as integers
748/// and not characters.
749template <>
750struct DenseArrayAttrUtil<int8_t> : public DenseArrayAttrIntUtil<8> {
751 static void printElement(raw_ostream &os, int8_t value) {
752 os << static_cast<int>(value);
753 }
754};
755template <>
756struct DenseArrayAttrUtil<int16_t> : public DenseArrayAttrIntUtil<16> {};
757template <>
758struct DenseArrayAttrUtil<int32_t> : public DenseArrayAttrIntUtil<32> {};
759template <>
760struct DenseArrayAttrUtil<int64_t> : public DenseArrayAttrIntUtil<64> {};
761
762/// Specialization for 32-bit floats.
763template <>
764struct DenseArrayAttrUtil<float> {
765 static bool checkElementType(Type eltType) { return eltType.isF32(); }
766 static Type getElementType(MLIRContext *ctx) { return Float32Type::get(ctx); }
767 static void printElement(raw_ostream &os, float value) { os << value; }
768
769 /// Parse a double and cast it to a float.
770 static ParseResult parseElement(AsmParser &parser, float &value) {
771 double doubleVal;
772 if (parser.parseFloat(doubleVal))
773 return failure();
774 value = doubleVal;
775 return success();
776 }
777};
778
779/// Specialization for 64-bit floats.
780template <>
781struct DenseArrayAttrUtil<double> {
782 static bool checkElementType(Type eltType) { return eltType.isF64(); }
783 static Type getElementType(MLIRContext *ctx) { return Float64Type::get(ctx); }
784 static void printElement(raw_ostream &os, float value) { os << value; }
785 static ParseResult parseElement(AsmParser &parser, double &value) {
786 return parser.parseFloat(value);
787 }
788};
789} // namespace
790
791template <typename T>
792void DenseArrayAttrImpl<T>::print(AsmPrinter &printer) const {
793 print(printer.getStream());
794}
795
796template <typename T>
797void DenseArrayAttrImpl<T>::printWithoutBraces(raw_ostream &os) const {
798 llvm::interleaveComma(asArrayRef(), os, [&](T value) {
799 DenseArrayAttrUtil<T>::printElement(os, value);
800 });
801}
802
803template <typename T>
804void DenseArrayAttrImpl<T>::print(raw_ostream &os) const {
805 os << "[";
806 printWithoutBraces(os);
807 os << "]";
808}
809
810/// Parse a DenseArrayAttr without the braces: `1, 2, 3`
811template <typename T>
812Attribute DenseArrayAttrImpl<T>::parseWithoutBraces(AsmParser &parser,
813 Type odsType) {
814 SmallVector<T> data;
815 if (failed(parser.parseCommaSeparatedList([&]() {
816 T value;
1
'value' declared without an initial value
817 if (DenseArrayAttrUtil<T>::parseElement(parser, value))
2
Calling 'DenseArrayAttrIntUtil::parseElement'
13
Returning from 'DenseArrayAttrIntUtil::parseElement'
14
Taking false branch
818 return failure();
819 data.push_back(value);
15
1st function call argument is an uninitialized value
820 return success();
821 })))
822 return {};
823 return get(parser.getContext(), data);
824}
825
826/// Parse a DenseArrayAttr: `[ 1, 2, 3 ]`
827template <typename T>
828Attribute DenseArrayAttrImpl<T>::parse(AsmParser &parser, Type odsType) {
829 if (parser.parseLSquare())
830 return {};
831 // Handle empty list case.
832 if (succeeded(parser.parseOptionalRSquare()))
833 return get(parser.getContext(), {});
834 Attribute result = parseWithoutBraces(parser, odsType);
835 if (parser.parseRSquare())
836 return {};
837 return result;
838}
839
840/// Conversion from DenseArrayAttr<T> to ArrayRef<T>.
841template <typename T>
842DenseArrayAttrImpl<T>::operator ArrayRef<T>() const {
843 ArrayRef<char> raw = getRawData();
844 assert((raw.size() % sizeof(T)) == 0)(static_cast <bool> ((raw.size() % sizeof(T)) == 0) ? void
(0) : __assert_fail ("(raw.size() % sizeof(T)) == 0", "mlir/lib/IR/BuiltinAttributes.cpp"
, 844, __extension__ __PRETTY_FUNCTION__))
;
845 return ArrayRef<T>(reinterpret_cast<const T *>(raw.data()),
846 raw.size() / sizeof(T));
847}
848
849/// Builds a DenseArrayAttr<T> from an ArrayRef<T>.
850template <typename T>
851DenseArrayAttrImpl<T> DenseArrayAttrImpl<T>::get(MLIRContext *context,
852 ArrayRef<T> content) {
853 Type elementType = DenseArrayAttrUtil<T>::getElementType(context);
854 auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
855 content.size() * sizeof(T));
856 return llvm::cast<DenseArrayAttrImpl<T>>(
857 Base::get(context, elementType, content.size(), rawArray));
858}
859
860template <typename T>
861bool DenseArrayAttrImpl<T>::classof(Attribute attr) {
862 if (auto denseArray = attr.dyn_cast<DenseArrayAttr>())
863 return DenseArrayAttrUtil<T>::checkElementType(denseArray.getElementType());
864 return false;
865}
866
867namespace mlir {
868namespace detail {
869// Explicit instantiation for all the supported DenseArrayAttr.
870template class DenseArrayAttrImpl<bool>;
871template class DenseArrayAttrImpl<int8_t>;
872template class DenseArrayAttrImpl<int16_t>;
873template class DenseArrayAttrImpl<int32_t>;
874template class DenseArrayAttrImpl<int64_t>;
875template class DenseArrayAttrImpl<float>;
876template class DenseArrayAttrImpl<double>;
877} // namespace detail
878} // namespace mlir
879
880//===----------------------------------------------------------------------===//
881// DenseElementsAttr
882//===----------------------------------------------------------------------===//
883
884/// Method for support type inquiry through isa, cast and dyn_cast.
885bool DenseElementsAttr::classof(Attribute attr) {
886 return attr.isa<DenseIntOrFPElementsAttr, DenseStringElementsAttr>();
887}
888
889DenseElementsAttr DenseElementsAttr::get(ShapedType type,
890 ArrayRef<Attribute> values) {
891 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 891, __extension__ __PRETTY_FUNCTION__
))
;
892
893 // If the element type is not based on int/float/index, assume it is a string
894 // type.
895 Type eltType = type.getElementType();
896 if (!eltType.isIntOrIndexOrFloat()) {
897 SmallVector<StringRef, 8> stringValues;
898 stringValues.reserve(values.size());
899 for (Attribute attr : values) {
900 assert(attr.isa<StringAttr>() &&(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 901, __extension__ __PRETTY_FUNCTION__
))
901 "expected string value for non integer/index/float element")(static_cast <bool> (attr.isa<StringAttr>() &&
"expected string value for non integer/index/float element")
? void (0) : __assert_fail ("attr.isa<StringAttr>() && \"expected string value for non integer/index/float element\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 901, __extension__ __PRETTY_FUNCTION__
))
;
902 stringValues.push_back(attr.cast<StringAttr>().getValue());
903 }
904 return get(type, stringValues);
905 }
906
907 // Otherwise, get the raw storage width to use for the allocation.
908 size_t bitWidth = getDenseElementBitWidth(eltType);
909 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
910
911 // Compress the attribute values into a character buffer.
912 SmallVector<char, 8> data(
913 llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT8));
914 APInt intVal;
915 for (unsigned i = 0, e = values.size(); i < e; ++i) {
916 if (auto floatAttr = values[i].dyn_cast<FloatAttr>()) {
917 assert(floatAttr.getType() == eltType &&(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__
))
918 "expected float attribute type to equal element type")(static_cast <bool> (floatAttr.getType() == eltType &&
"expected float attribute type to equal element type") ? void
(0) : __assert_fail ("floatAttr.getType() == eltType && \"expected float attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 918, __extension__ __PRETTY_FUNCTION__
))
;
919 intVal = floatAttr.getValue().bitcastToAPInt();
920 } else {
921 auto intAttr = values[i].cast<IntegerAttr>();
922 assert(intAttr.getType() == eltType &&(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 923, __extension__ __PRETTY_FUNCTION__
))
923 "expected integer attribute type to equal element type")(static_cast <bool> (intAttr.getType() == eltType &&
"expected integer attribute type to equal element type") ? void
(0) : __assert_fail ("intAttr.getType() == eltType && \"expected integer attribute type to equal element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 923, __extension__ __PRETTY_FUNCTION__
))
;
924 intVal = intAttr.getValue();
925 }
926
927 assert(intVal.getBitWidth() == bitWidth &&(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__
))
928 "expected value to have same bitwidth as element type")(static_cast <bool> (intVal.getBitWidth() == bitWidth &&
"expected value to have same bitwidth as element type") ? void
(0) : __assert_fail ("intVal.getBitWidth() == bitWidth && \"expected value to have same bitwidth as element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 928, __extension__ __PRETTY_FUNCTION__
))
;
929 writeBits(data.data(), i * storageBitWidth, intVal);
930 }
931
932 // Handle the special encoding of splat of bool.
933 if (values.size() == 1 && eltType.isInteger(1))
934 data[0] = data[0] ? -1 : 0;
935
936 return DenseIntOrFPElementsAttr::getRaw(type, data);
937}
938
939DenseElementsAttr DenseElementsAttr::get(ShapedType type,
940 ArrayRef<bool> values) {
941 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 941, __extension__ __PRETTY_FUNCTION__
))
;
942 assert(type.getElementType().isInteger(1))(static_cast <bool> (type.getElementType().isInteger(1)
) ? void (0) : __assert_fail ("type.getElementType().isInteger(1)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 942, __extension__ __PRETTY_FUNCTION__
))
;
943
944 std::vector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT8));
945
946 if (!values.empty()) {
947 bool isSplat = true;
948 bool firstValue = values[0];
949 for (int i = 0, e = values.size(); i != e; ++i) {
950 isSplat &= values[i] == firstValue;
951 setBit(buff.data(), i, values[i]);
952 }
953
954 // Splat of bool is encoded as a byte with all-ones in it.
955 if (isSplat) {
956 buff.resize(1);
957 buff[0] = values[0] ? -1 : 0;
958 }
959 }
960
961 return DenseIntOrFPElementsAttr::getRaw(type, buff);
962}
963
964DenseElementsAttr DenseElementsAttr::get(ShapedType type,
965 ArrayRef<StringRef> values) {
966 assert(!type.getElementType().isIntOrFloat())(static_cast <bool> (!type.getElementType().isIntOrFloat
()) ? void (0) : __assert_fail ("!type.getElementType().isIntOrFloat()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 966, __extension__ __PRETTY_FUNCTION__
))
;
967 return DenseStringElementsAttr::get(type, values);
968}
969
970/// Constructs a dense integer elements attribute from an array of APInt
971/// values. Each APInt value is expected to have the same bitwidth as the
972/// element type of 'type'.
973DenseElementsAttr DenseElementsAttr::get(ShapedType type,
974 ArrayRef<APInt> values) {
975 assert(type.getElementType().isIntOrIndex())(static_cast <bool> (type.getElementType().isIntOrIndex
()) ? void (0) : __assert_fail ("type.getElementType().isIntOrIndex()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 975, __extension__ __PRETTY_FUNCTION__
))
;
976 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 976, __extension__ __PRETTY_FUNCTION__
))
;
977 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
978 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
979}
980DenseElementsAttr DenseElementsAttr::get(ShapedType type,
981 ArrayRef<std::complex<APInt>> values) {
982 ComplexType complex = type.getElementType().cast<ComplexType>();
983 assert(complex.getElementType().isa<IntegerType>())(static_cast <bool> (complex.getElementType().isa<IntegerType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<IntegerType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 983, __extension__ __PRETTY_FUNCTION__
))
;
984 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 984, __extension__ __PRETTY_FUNCTION__
))
;
985 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
986 ArrayRef<APInt> intVals(reinterpret_cast<const APInt *>(values.data()),
987 values.size() * 2);
988 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, intVals);
989}
990
991// Constructs a dense float elements attribute from an array of APFloat
992// values. Each APFloat value is expected to have the same bitwidth as the
993// element type of 'type'.
994DenseElementsAttr DenseElementsAttr::get(ShapedType type,
995 ArrayRef<APFloat> values) {
996 assert(type.getElementType().isa<FloatType>())(static_cast <bool> (type.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("type.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 996, __extension__ __PRETTY_FUNCTION__
))
;
997 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 997, __extension__ __PRETTY_FUNCTION__
))
;
998 size_t storageBitWidth = getDenseElementStorageWidth(type.getElementType());
999 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, values);
1000}
1001DenseElementsAttr
1002DenseElementsAttr::get(ShapedType type,
1003 ArrayRef<std::complex<APFloat>> values) {
1004 ComplexType complex = type.getElementType().cast<ComplexType>();
1005 assert(complex.getElementType().isa<FloatType>())(static_cast <bool> (complex.getElementType().isa<FloatType
>()) ? void (0) : __assert_fail ("complex.getElementType().isa<FloatType>()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1005, __extension__ __PRETTY_FUNCTION__
))
;
1006 assert(hasSameElementsOrSplat(type, values))(static_cast <bool> (hasSameElementsOrSplat(type, values
)) ? void (0) : __assert_fail ("hasSameElementsOrSplat(type, values)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1006, __extension__ __PRETTY_FUNCTION__
))
;
1007 ArrayRef<APFloat> apVals(reinterpret_cast<const APFloat *>(values.data()),
1008 values.size() * 2);
1009 size_t storageBitWidth = getDenseElementStorageWidth(complex) / 2;
1010 return DenseIntOrFPElementsAttr::getRaw(type, storageBitWidth, apVals);
1011}
1012
1013/// Construct a dense elements attribute from a raw buffer representing the
1014/// data for this attribute. Users should generally not use this methods as
1015/// the expected buffer format may not be a form the user expects.
1016DenseElementsAttr
1017DenseElementsAttr::getFromRawBuffer(ShapedType type, ArrayRef<char> rawBuffer) {
1018 return DenseIntOrFPElementsAttr::getRaw(type, rawBuffer);
1019}
1020
1021/// Returns true if the given buffer is a valid raw buffer for the given type.
1022bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
1023 ArrayRef<char> rawBuffer,
1024 bool &detectedSplat) {
1025 size_t storageWidth = getDenseElementStorageWidth(type.getElementType());
1026 size_t rawBufferWidth = rawBuffer.size() * CHAR_BIT8;
1027 int64_t numElements = type.getNumElements();
1028
1029 // The initializer is always a splat if the result type has a single element.
1030 detectedSplat = numElements == 1;
1031
1032 // Storage width of 1 is special as it is packed by the bit.
1033 if (storageWidth == 1) {
1034 // Check for a splat, or a buffer equal to the number of elements which
1035 // consists of either all 0's or all 1's.
1036 if (rawBuffer.size() == 1) {
1037 auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
1038 if (rawByte == 0 || rawByte == 0xff) {
1039 detectedSplat = true;
1040 return true;
1041 }
1042 }
1043
1044 // This is a valid non-splat buffer if it has the right size.
1045 return rawBufferWidth == llvm::alignTo<8>(numElements);
1046 }
1047
1048 // All other types are 8-bit aligned, so we can just check the buffer width
1049 // to know if only a single initializer element was passed in.
1050 if (rawBufferWidth == storageWidth) {
1051 detectedSplat = true;
1052 return true;
1053 }
1054
1055 // The raw buffer is valid if it has the right size.
1056 return rawBufferWidth == storageWidth * numElements;
1057}
1058
1059/// Check the information for a C++ data type, check if this type is valid for
1060/// the current attribute. This method is used to verify specific type
1061/// invariants that the templatized 'getValues' method cannot.
1062static bool isValidIntOrFloat(Type type, int64_t dataEltSize, bool isInt,
1063 bool isSigned) {
1064 // Make sure that the data element size is the same as the type element width.
1065 if (getDenseElementBitWidth(type) !=
1066 static_cast<size_t>(dataEltSize * CHAR_BIT8))
1067 return false;
1068
1069 // Check that the element type is either float or integer or index.
1070 if (!isInt)
1071 return type.isa<FloatType>();
1072 if (type.isIndex())
1073 return true;
1074
1075 auto intType = type.dyn_cast<IntegerType>();
1076 if (!intType)
1077 return false;
1078
1079 // Make sure signedness semantics is consistent.
1080 if (intType.isSignless())
1081 return true;
1082 return intType.isSigned() ? isSigned : !isSigned;
1083}
1084
1085/// Defaults down the subclass implementation.
1086DenseElementsAttr DenseElementsAttr::getRawComplex(ShapedType type,
1087 ArrayRef<char> data,
1088 int64_t dataEltSize,
1089 bool isInt, bool isSigned) {
1090 return DenseIntOrFPElementsAttr::getRawComplex(type, data, dataEltSize, isInt,
1091 isSigned);
1092}
1093DenseElementsAttr DenseElementsAttr::getRawIntOrFloat(ShapedType type,
1094 ArrayRef<char> data,
1095 int64_t dataEltSize,
1096 bool isInt,
1097 bool isSigned) {
1098 return DenseIntOrFPElementsAttr::getRawIntOrFloat(type, data, dataEltSize,
1099 isInt, isSigned);
1100}
1101
1102bool DenseElementsAttr::isValidIntOrFloat(int64_t dataEltSize, bool isInt,
1103 bool isSigned) const {
1104 return ::isValidIntOrFloat(getElementType(), dataEltSize, isInt, isSigned);
1105}
1106bool DenseElementsAttr::isValidComplex(int64_t dataEltSize, bool isInt,
1107 bool isSigned) const {
1108 return ::isValidIntOrFloat(
1109 getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2,
1110 isInt, isSigned);
1111}
1112
1113/// Returns true if this attribute corresponds to a splat, i.e. if all element
1114/// values are the same.
1115bool DenseElementsAttr::isSplat() const {
1116 return static_cast<DenseElementsAttributeStorage *>(impl)->isSplat;
1117}
1118
1119/// Return if the given complex type has an integer element type.
1120static bool isComplexOfIntType(Type type) {
1121 return type.cast<ComplexType>().getElementType().isa<IntegerType>();
1122}
1123
1124auto DenseElementsAttr::tryGetComplexIntValues() const
1125 -> FailureOr<iterator_range_impl<ComplexIntElementIterator>> {
1126 if (!isComplexOfIntType(getElementType()))
1127 return failure();
1128 return iterator_range_impl<ComplexIntElementIterator>(
1129 getType(), ComplexIntElementIterator(*this, 0),
1130 ComplexIntElementIterator(*this, getNumElements()));
1131}
1132
1133auto DenseElementsAttr::tryGetFloatValues() const
1134 -> FailureOr<iterator_range_impl<FloatElementIterator>> {
1135 auto eltTy = getElementType().dyn_cast<FloatType>();
1136 if (!eltTy)
1137 return failure();
1138 const auto &elementSemantics = eltTy.getFloatSemantics();
1139 return iterator_range_impl<FloatElementIterator>(
1140 getType(), FloatElementIterator(elementSemantics, raw_int_begin()),
1141 FloatElementIterator(elementSemantics, raw_int_end()));
1142}
1143
1144auto DenseElementsAttr::tryGetComplexFloatValues() const
1145 -> FailureOr<iterator_range_impl<ComplexFloatElementIterator>> {
1146 auto complexTy = getElementType().dyn_cast<ComplexType>();
1147 if (!complexTy)
1148 return failure();
1149 auto eltTy = complexTy.getElementType().dyn_cast<FloatType>();
1150 if (!eltTy)
1151 return failure();
1152 const auto &semantics = eltTy.getFloatSemantics();
1153 return iterator_range_impl<ComplexFloatElementIterator>(
1154 getType(), {semantics, {*this, 0}},
1155 {semantics, {*this, static_cast<size_t>(getNumElements())}});
1156}
1157
1158/// Return the raw storage data held by this attribute.
1159ArrayRef<char> DenseElementsAttr::getRawData() const {
1160 return static_cast<DenseIntOrFPElementsAttrStorage *>(impl)->data;
1161}
1162
1163ArrayRef<StringRef> DenseElementsAttr::getRawStringData() const {
1164 return static_cast<DenseStringElementsAttrStorage *>(impl)->data;
1165}
1166
1167/// Return a new DenseElementsAttr that has the same data as the current
1168/// attribute, but has been reshaped to 'newType'. The new type must have the
1169/// same total number of elements as well as element type.
1170DenseElementsAttr DenseElementsAttr::reshape(ShapedType newType) {
1171 ShapedType curType = getType();
1172 if (curType == newType)
1173 return *this;
1174
1175 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1176, __extension__ __PRETTY_FUNCTION__
))
1176 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1176, __extension__ __PRETTY_FUNCTION__
))
;
1177 assert(newType.getNumElements() == curType.getNumElements() &&(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1178, __extension__ __PRETTY_FUNCTION__
))
1178 "expected the same number of elements")(static_cast <bool> (newType.getNumElements() == curType
.getNumElements() && "expected the same number of elements"
) ? void (0) : __assert_fail ("newType.getNumElements() == curType.getNumElements() && \"expected the same number of elements\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1178, __extension__ __PRETTY_FUNCTION__
))
;
1179 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1180}
1181
1182DenseElementsAttr DenseElementsAttr::resizeSplat(ShapedType newType) {
1183 assert(isSplat() && "expected a splat type")(static_cast <bool> (isSplat() && "expected a splat type"
) ? void (0) : __assert_fail ("isSplat() && \"expected a splat type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1183, __extension__ __PRETTY_FUNCTION__
))
;
1184
1185 ShapedType curType = getType();
1186 if (curType == newType)
1187 return *this;
1188
1189 assert(newType.getElementType() == curType.getElementType() &&(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1190, __extension__ __PRETTY_FUNCTION__
))
1190 "expected the same element type")(static_cast <bool> (newType.getElementType() == curType
.getElementType() && "expected the same element type"
) ? void (0) : __assert_fail ("newType.getElementType() == curType.getElementType() && \"expected the same element type\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1190, __extension__ __PRETTY_FUNCTION__
))
;
1191 return DenseIntOrFPElementsAttr::getRaw(newType, getRawData());
1192}
1193
1194/// Return a new DenseElementsAttr that has the same data as the current
1195/// attribute, but has bitcast elements such that it is now 'newType'. The new
1196/// type must have the same shape and element types of the same bitwidth as the
1197/// current type.
1198DenseElementsAttr DenseElementsAttr::bitcast(Type newElType) {
1199 ShapedType curType = getType();
1200 Type curElType = curType.getElementType();
1201 if (curElType == newElType)
1202 return *this;
1203
1204 assert(getDenseElementBitWidth(newElType) ==(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__
))
1205 getDenseElementBitWidth(curElType) &&(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__
))
1206 "expected element types with the same bitwidth")(static_cast <bool> (getDenseElementBitWidth(newElType)
== getDenseElementBitWidth(curElType) && "expected element types with the same bitwidth"
) ? void (0) : __assert_fail ("getDenseElementBitWidth(newElType) == getDenseElementBitWidth(curElType) && \"expected element types with the same bitwidth\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1206, __extension__ __PRETTY_FUNCTION__
))
;
1207 return DenseIntOrFPElementsAttr::getRaw(curType.clone(newElType),
1208 getRawData());
1209}
1210
1211DenseElementsAttr
1212DenseElementsAttr::mapValues(Type newElementType,
1213 function_ref<APInt(const APInt &)> mapping) const {
1214 return cast<DenseIntElementsAttr>().mapValues(newElementType, mapping);
1215}
1216
1217DenseElementsAttr DenseElementsAttr::mapValues(
1218 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1219 return cast<DenseFPElementsAttr>().mapValues(newElementType, mapping);
1220}
1221
1222ShapedType DenseElementsAttr::getType() const {
1223 return static_cast<const DenseElementsAttributeStorage *>(impl)->type;
1224}
1225
1226Type DenseElementsAttr::getElementType() const {
1227 return getType().getElementType();
1228}
1229
1230int64_t DenseElementsAttr::getNumElements() const {
1231 return getType().getNumElements();
1232}
1233
1234//===----------------------------------------------------------------------===//
1235// DenseIntOrFPElementsAttr
1236//===----------------------------------------------------------------------===//
1237
1238/// Utility method to write a range of APInt values to a buffer.
1239template <typename APRangeT>
1240static void writeAPIntsToBuffer(size_t storageWidth, std::vector<char> &data,
1241 APRangeT &&values) {
1242 size_t numValues = llvm::size(values);
1243 data.resize(llvm::divideCeil(storageWidth * numValues, CHAR_BIT8));
1244 size_t offset = 0;
1245 for (auto it = values.begin(), e = values.end(); it != e;
1246 ++it, offset += storageWidth) {
1247 assert((*it).getBitWidth() <= storageWidth)(static_cast <bool> ((*it).getBitWidth() <= storageWidth
) ? void (0) : __assert_fail ("(*it).getBitWidth() <= storageWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1247, __extension__ __PRETTY_FUNCTION__
))
;
1248 writeBits(data.data(), offset, *it);
1249 }
1250
1251 // Handle the special encoding of splat of a boolean.
1252 if (numValues == 1 && (*values.begin()).getBitWidth() == 1)
1253 data[0] = data[0] ? -1 : 0;
1254}
1255
1256/// Constructs a dense elements attribute from an array of raw APFloat values.
1257/// Each APFloat value is expected to have the same bitwidth as the element
1258/// type of 'type'. 'type' must be a vector or tensor with static shape.
1259DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1260 size_t storageWidth,
1261 ArrayRef<APFloat> values) {
1262 std::vector<char> data;
1263 auto unwrapFloat = [](const APFloat &val) { return val.bitcastToAPInt(); };
1264 writeAPIntsToBuffer(storageWidth, data, llvm::map_range(values, unwrapFloat));
1265 return DenseIntOrFPElementsAttr::getRaw(type, data);
1266}
1267
1268/// Constructs a dense elements attribute from an array of raw APInt values.
1269/// Each APInt value is expected to have the same bitwidth as the element type
1270/// of 'type'.
1271DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1272 size_t storageWidth,
1273 ArrayRef<APInt> values) {
1274 std::vector<char> data;
1275 writeAPIntsToBuffer(storageWidth, data, values);
1276 return DenseIntOrFPElementsAttr::getRaw(type, data);
1277}
1278
1279DenseElementsAttr DenseIntOrFPElementsAttr::getRaw(ShapedType type,
1280 ArrayRef<char> data) {
1281 assert(type.hasStaticShape() && "type must have static shape")(static_cast <bool> (type.hasStaticShape() && "type must have static shape"
) ? void (0) : __assert_fail ("type.hasStaticShape() && \"type must have static shape\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1281, __extension__ __PRETTY_FUNCTION__
))
;
1282 bool isSplat = false;
1283 bool isValid = isValidRawBuffer(type, data, isSplat);
1284 assert(isValid)(static_cast <bool> (isValid) ? void (0) : __assert_fail
("isValid", "mlir/lib/IR/BuiltinAttributes.cpp", 1284, __extension__
__PRETTY_FUNCTION__))
;
1285 (void)isValid;
1286 return Base::get(type.getContext(), type, data, isSplat);
1287}
1288
1289/// Overload of the raw 'get' method that asserts that the given type is of
1290/// complex type. This method is used to verify type invariants that the
1291/// templatized 'get' method cannot.
1292DenseElementsAttr DenseIntOrFPElementsAttr::getRawComplex(ShapedType type,
1293 ArrayRef<char> data,
1294 int64_t dataEltSize,
1295 bool isInt,
1296 bool isSigned) {
1297 assert(::isValidIntOrFloat((static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__
))
1298 type.getElementType().cast<ComplexType>().getElementType(),(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__
))
1299 dataEltSize / 2, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat( type.getElementType
().cast<ComplexType>().getElementType(), dataEltSize / 2
, isInt, isSigned)) ? void (0) : __assert_fail ("::isValidIntOrFloat( type.getElementType().cast<ComplexType>().getElementType(), dataEltSize / 2, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1299, __extension__ __PRETTY_FUNCTION__
))
;
1300
1301 int64_t numElements = data.size() / dataEltSize;
1302 (void)numElements;
1303 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1303, __extension__ __PRETTY_FUNCTION__
))
;
1304 return getRaw(type, data);
1305}
1306
1307/// Overload of the 'getRaw' method that asserts that the given type is of
1308/// integer type. This method is used to verify type invariants that the
1309/// templatized 'get' method cannot.
1310DenseElementsAttr
1311DenseIntOrFPElementsAttr::getRawIntOrFloat(ShapedType type, ArrayRef<char> data,
1312 int64_t dataEltSize, bool isInt,
1313 bool isSigned) {
1314 assert((static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1315, __extension__ __PRETTY_FUNCTION__
))
1315 ::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned))(static_cast <bool> (::isValidIntOrFloat(type.getElementType
(), dataEltSize, isInt, isSigned)) ? void (0) : __assert_fail
("::isValidIntOrFloat(type.getElementType(), dataEltSize, isInt, isSigned)"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1315, __extension__ __PRETTY_FUNCTION__
))
;
1316
1317 int64_t numElements = data.size() / dataEltSize;
1318 assert(numElements == 1 || numElements == type.getNumElements())(static_cast <bool> (numElements == 1 || numElements ==
type.getNumElements()) ? void (0) : __assert_fail ("numElements == 1 || numElements == type.getNumElements()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1318, __extension__ __PRETTY_FUNCTION__
))
;
1319 (void)numElements;
1320 return getRaw(type, data);
1321}
1322
1323void DenseIntOrFPElementsAttr::convertEndianOfCharForBEmachine(
1324 const char *inRawData, char *outRawData, size_t elementBitWidth,
1325 size_t numElements) {
1326 using llvm::support::ulittle16_t;
1327 using llvm::support::ulittle32_t;
1328 using llvm::support::ulittle64_t;
1329
1330 assert(llvm::support::endian::system_endianness() == // NOLINT(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1331, __extension__ __PRETTY_FUNCTION__
))
1331 llvm::support::endianness::big)(static_cast <bool> (llvm::support::endian::system_endianness
() == llvm::support::endianness::big) ? void (0) : __assert_fail
("llvm::support::endian::system_endianness() == llvm::support::endianness::big"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1331, __extension__ __PRETTY_FUNCTION__
))
; // NOLINT
1332 // NOLINT to avoid warning message about replacing by static_assert()
1333
1334 // Following std::copy_n always converts endianness on BE machine.
1335 switch (elementBitWidth) {
1336 case 16: {
1337 const ulittle16_t *inRawDataPos =
1338 reinterpret_cast<const ulittle16_t *>(inRawData);
1339 uint16_t *outDataPos = reinterpret_cast<uint16_t *>(outRawData);
1340 std::copy_n(inRawDataPos, numElements, outDataPos);
1341 break;
1342 }
1343 case 32: {
1344 const ulittle32_t *inRawDataPos =
1345 reinterpret_cast<const ulittle32_t *>(inRawData);
1346 uint32_t *outDataPos = reinterpret_cast<uint32_t *>(outRawData);
1347 std::copy_n(inRawDataPos, numElements, outDataPos);
1348 break;
1349 }
1350 case 64: {
1351 const ulittle64_t *inRawDataPos =
1352 reinterpret_cast<const ulittle64_t *>(inRawData);
1353 uint64_t *outDataPos = reinterpret_cast<uint64_t *>(outRawData);
1354 std::copy_n(inRawDataPos, numElements, outDataPos);
1355 break;
1356 }
1357 default: {
1358 size_t nBytes = elementBitWidth / CHAR_BIT8;
1359 for (size_t i = 0; i < nBytes; i++)
1360 std::copy_n(inRawData + (nBytes - 1 - i), 1, outRawData + i);
1361 break;
1362 }
1363 }
1364}
1365
1366void DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
1367 ArrayRef<char> inRawData, MutableArrayRef<char> outRawData,
1368 ShapedType type) {
1369 size_t numElements = type.getNumElements();
1370 Type elementType = type.getElementType();
1371 if (ComplexType complexTy = elementType.dyn_cast<ComplexType>()) {
1372 elementType = complexTy.getElementType();
1373 numElements = numElements * 2;
1374 }
1375 size_t elementBitWidth = getDenseElementStorageWidth(elementType);
1376 assert(numElements * elementBitWidth == inRawData.size() * CHAR_BIT &&(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1377, __extension__ __PRETTY_FUNCTION__
))
1377 inRawData.size() <= outRawData.size())(static_cast <bool> (numElements * elementBitWidth == inRawData
.size() * 8 && inRawData.size() <= outRawData.size
()) ? void (0) : __assert_fail ("numElements * elementBitWidth == inRawData.size() * CHAR_BIT && inRawData.size() <= outRawData.size()"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1377, __extension__ __PRETTY_FUNCTION__
))
;
1378 if (elementBitWidth <= CHAR_BIT8)
1379 std::memcpy(outRawData.begin(), inRawData.begin(), inRawData.size());
1380 else
1381 convertEndianOfCharForBEmachine(inRawData.begin(), outRawData.begin(),
1382 elementBitWidth, numElements);
1383}
1384
1385//===----------------------------------------------------------------------===//
1386// DenseFPElementsAttr
1387//===----------------------------------------------------------------------===//
1388
1389template <typename Fn, typename Attr>
1390static ShapedType mappingHelper(Fn mapping, Attr &attr, ShapedType inType,
1391 Type newElementType,
1392 llvm::SmallVectorImpl<char> &data) {
1393 size_t bitWidth = getDenseElementBitWidth(newElementType);
1394 size_t storageBitWidth = getDenseElementStorageWidth(bitWidth);
1395
1396 ShapedType newArrayType = inType.cloneWith(inType.getShape(), newElementType);
1397
1398 size_t numRawElements = attr.isSplat() ? 1 : newArrayType.getNumElements();
1399 data.resize(llvm::divideCeil(storageBitWidth * numRawElements, CHAR_BIT8));
1400
1401 // Functor used to process a single element value of the attribute.
1402 auto processElt = [&](decltype(*attr.begin()) value, size_t index) {
1403 auto newInt = mapping(value);
1404 assert(newInt.getBitWidth() == bitWidth)(static_cast <bool> (newInt.getBitWidth() == bitWidth) ?
void (0) : __assert_fail ("newInt.getBitWidth() == bitWidth"
, "mlir/lib/IR/BuiltinAttributes.cpp", 1404, __extension__ __PRETTY_FUNCTION__
))
;
1405 writeBits(data.data(), index * storageBitWidth, newInt);
1406 };
1407
1408 // Check for the splat case.
1409 if (attr.isSplat()) {
1410 if (bitWidth == 1) {
1411 // Handle the special encoding of splat of bool.
1412 data[0] = mapping(*attr.begin()).isZero() ? 0 : -1;
1413 } else {
1414 processElt(*attr.begin(), /*index=*/0);
1415 }
1416 return newArrayType;
1417 }
1418
1419 // Otherwise, process all of the element values.
1420 uint64_t elementIdx = 0;
1421 for (auto value : attr)
1422 processElt(value, elementIdx++);
1423 return newArrayType;
1424}
1425
1426DenseElementsAttr DenseFPElementsAttr::mapValues(
1427 Type newElementType, function_ref<APInt(const APFloat &)> mapping) const {
1428 llvm::SmallVector<char, 8> elementData;
1429 auto newArrayType =
1430 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1431
1432 return getRaw(newArrayType, elementData);
1433}
1434
1435/// Method for supporting type inquiry through isa, cast and dyn_cast.
1436bool DenseFPElementsAttr::classof(Attribute attr) {
1437 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1438 return denseAttr.getType().getElementType().isa<FloatType>();
1439 return false;
1440}
1441
1442//===----------------------------------------------------------------------===//
1443// DenseIntElementsAttr
1444//===----------------------------------------------------------------------===//
1445
1446DenseElementsAttr DenseIntElementsAttr::mapValues(
1447 Type newElementType, function_ref<APInt(const APInt &)> mapping) const {
1448 llvm::SmallVector<char, 8> elementData;
1449 auto newArrayType =
1450 mappingHelper(mapping, *this, getType(), newElementType, elementData);
1451 return getRaw(newArrayType, elementData);
1452}
1453
1454/// Method for supporting type inquiry through isa, cast and dyn_cast.
1455bool DenseIntElementsAttr::classof(Attribute attr) {
1456 if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>())
1457 return denseAttr.getType().getElementType().isIntOrIndex();
1458 return false;
1459}
1460
1461//===----------------------------------------------------------------------===//
1462// DenseResourceElementsAttr
1463//===----------------------------------------------------------------------===//
1464
1465DenseResourceElementsAttr
1466DenseResourceElementsAttr::get(ShapedType type,
1467 DenseResourceElementsHandle handle) {
1468 return Base::get(type.getContext(), type, handle);
1469}
1470
1471DenseResourceElementsAttr DenseResourceElementsAttr::get(ShapedType type,
1472 StringRef blobName,
1473 AsmResourceBlob blob) {
1474 // Extract the builtin dialect resource manager from context and construct a
1475 // handle by inserting a new resource using the provided blob.
1476 auto &manager =
1477 DenseResourceElementsHandle::getManagerInterface(type.getContext());
1478 return get(type, manager.insert(blobName, std::move(blob)));
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// DenseResourceElementsAttrBase
1483
1484namespace {
1485/// Instantiations of this class provide utilities for interacting with native
1486/// data types in the context of DenseResourceElementsAttr.
1487template <typename T>
1488struct DenseResourceAttrUtil;
1489template <size_t width, bool isSigned>
1490struct DenseResourceElementsAttrIntUtil {
1491 static bool checkElementType(Type eltType) {
1492 IntegerType type = eltType.dyn_cast<IntegerType>();
1493 if (!type || type.getWidth() != width)
1494 return false;
1495 return isSigned ? !type.isUnsigned() : !type.isSigned();
1496 }
1497};
1498template <>
1499struct DenseResourceAttrUtil<bool> {
1500 static bool checkElementType(Type eltType) {
1501 return eltType.isSignlessInteger(1);
1502 }
1503};
1504template <>
1505struct DenseResourceAttrUtil<int8_t>
1506 : public DenseResourceElementsAttrIntUtil<8, true> {};
1507template <>
1508struct DenseResourceAttrUtil<uint8_t>
1509 : public DenseResourceElementsAttrIntUtil<8, false> {};
1510template <>
1511struct DenseResourceAttrUtil<int16_t>
1512 : public DenseResourceElementsAttrIntUtil<16, true> {};
1513template <>
1514struct DenseResourceAttrUtil<uint16_t>
1515 : public DenseResourceElementsAttrIntUtil<16, false> {};
1516template <>
1517struct DenseResourceAttrUtil<int32_t>
1518 : public DenseResourceElementsAttrIntUtil<32, true> {};
1519template <>
1520struct DenseResourceAttrUtil<uint32_t>
1521 : public DenseResourceElementsAttrIntUtil<32, false> {};
1522template <>
1523struct DenseResourceAttrUtil<int64_t>
1524 : public DenseResourceElementsAttrIntUtil<64, true> {};
1525template <>
1526struct DenseResourceAttrUtil<uint64_t>
1527 : public DenseResourceElementsAttrIntUtil<64, false> {};
1528template <>
1529struct DenseResourceAttrUtil<float> {
1530 static bool checkElementType(Type eltType) { return eltType.isF32(); }
1531};
1532template <>
1533struct DenseResourceAttrUtil<double> {
1534 static bool checkElementType(Type eltType) { return eltType.isF64(); }
1535};
1536} // namespace
1537
1538template <typename T>
1539DenseResourceElementsAttrBase<T>
1540DenseResourceElementsAttrBase<T>::get(ShapedType type, StringRef blobName,
1541 AsmResourceBlob blob) {
1542 // Check that the blob is in the form we were expecting.
1543 assert(blob.getDataAlignment() == alignof(T) &&(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1544, __extension__ __PRETTY_FUNCTION__
))
1544 "alignment mismatch between expected alignment and blob alignment")(static_cast <bool> (blob.getDataAlignment() == alignof
(T) && "alignment mismatch between expected alignment and blob alignment"
) ? void (0) : __assert_fail ("blob.getDataAlignment() == alignof(T) && \"alignment mismatch between expected alignment and blob alignment\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1544, __extension__ __PRETTY_FUNCTION__
))
;
1545 assert(((blob.getData().size() % sizeof(T)) == 0) &&(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1546, __extension__ __PRETTY_FUNCTION__
))
1546 "size mismatch between expected element width and blob size")(static_cast <bool> (((blob.getData().size() % sizeof(T
)) == 0) && "size mismatch between expected element width and blob size"
) ? void (0) : __assert_fail ("((blob.getData().size() % sizeof(T)) == 0) && \"size mismatch between expected element width and blob size\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1546, __extension__ __PRETTY_FUNCTION__
))
;
1547 assert(DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) &&(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1548, __extension__ __PRETTY_FUNCTION__
))
1548 "invalid shape element type for provided type `T`")(static_cast <bool> (DenseResourceAttrUtil<T>::checkElementType
(type.getElementType()) && "invalid shape element type for provided type `T`"
) ? void (0) : __assert_fail ("DenseResourceAttrUtil<T>::checkElementType(type.getElementType()) && \"invalid shape element type for provided type `T`\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1548, __extension__ __PRETTY_FUNCTION__
))
;
1549 return DenseResourceElementsAttr::get(type, blobName, std::move(blob))
1550 .template cast<DenseResourceElementsAttrBase<T>>();
1551}
1552
1553template <typename T>
1554Optional<ArrayRef<T>>
1555DenseResourceElementsAttrBase<T>::tryGetAsArrayRef() const {
1556 if (AsmResourceBlob *blob = this->getRawHandle().getBlob())
1557 return blob->template getDataAs<T>();
1558 return std::nullopt;
1559}
1560
1561template <typename T>
1562bool DenseResourceElementsAttrBase<T>::classof(Attribute attr) {
1563 auto resourceAttr = attr.dyn_cast<DenseResourceElementsAttr>();
1564 return resourceAttr && DenseResourceAttrUtil<T>::checkElementType(
1565 resourceAttr.getElementType());
1566}
1567
1568namespace mlir {
1569namespace detail {
1570// Explicit instantiation for all the supported DenseResourceElementsAttr.
1571template class DenseResourceElementsAttrBase<bool>;
1572template class DenseResourceElementsAttrBase<int8_t>;
1573template class DenseResourceElementsAttrBase<int16_t>;
1574template class DenseResourceElementsAttrBase<int32_t>;
1575template class DenseResourceElementsAttrBase<int64_t>;
1576template class DenseResourceElementsAttrBase<uint8_t>;
1577template class DenseResourceElementsAttrBase<uint16_t>;
1578template class DenseResourceElementsAttrBase<uint32_t>;
1579template class DenseResourceElementsAttrBase<uint64_t>;
1580template class DenseResourceElementsAttrBase<float>;
1581template class DenseResourceElementsAttrBase<double>;
1582} // namespace detail
1583} // namespace mlir
1584
1585//===----------------------------------------------------------------------===//
1586// SparseElementsAttr
1587//===----------------------------------------------------------------------===//
1588
1589/// Get a zero APFloat for the given sparse attribute.
1590APFloat SparseElementsAttr::getZeroAPFloat() const {
1591 auto eltType = getElementType().cast<FloatType>();
1592 return APFloat(eltType.getFloatSemantics());
1593}
1594
1595/// Get a zero APInt for the given sparse attribute.
1596APInt SparseElementsAttr::getZeroAPInt() const {
1597 auto eltType = getElementType().cast<IntegerType>();
1598 return APInt::getZero(eltType.getWidth());
1599}
1600
1601/// Get a zero attribute for the given attribute type.
1602Attribute SparseElementsAttr::getZeroAttr() const {
1603 auto eltType = getElementType();
1604
1605 // Handle floating point elements.
1606 if (eltType.isa<FloatType>())
1607 return FloatAttr::get(eltType, 0);
1608
1609 // Handle complex elements.
1610 if (auto complexTy = eltType.dyn_cast<ComplexType>()) {
1611 auto eltType = complexTy.getElementType();
1612 Attribute zero;
1613 if (eltType.isa<FloatType>())
1614 zero = FloatAttr::get(eltType, 0);
1615 else // must be integer
1616 zero = IntegerAttr::get(eltType, 0);
1617 return ArrayAttr::get(complexTy.getContext(),
1618 ArrayRef<Attribute>{zero, zero});
1619 }
1620
1621 // Handle string type.
1622 if (getValues().isa<DenseStringElementsAttr>())
1623 return StringAttr::get("", eltType);
1624
1625 // Otherwise, this is an integer.
1626 return IntegerAttr::get(eltType, 0);
1627}
1628
1629/// Flatten, and return, all of the sparse indices in this attribute in
1630/// row-major order.
1631std::vector<ptrdiff_t> SparseElementsAttr::getFlattenedSparseIndices() const {
1632 std::vector<ptrdiff_t> flatSparseIndices;
1633
1634 // The sparse indices are 64-bit integers, so we can reinterpret the raw data
1635 // as a 1-D index array.
1636 auto sparseIndices = getIndices();
1637 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1638 if (sparseIndices.isSplat()) {
1639 SmallVector<uint64_t, 8> indices(getType().getRank(),
1640 *sparseIndexValues.begin());
1641 flatSparseIndices.push_back(getFlattenedIndex(indices));
1642 return flatSparseIndices;
1643 }
1644
1645 // Otherwise, reinterpret each index as an ArrayRef when flattening.
1646 auto numSparseIndices = sparseIndices.getType().getDimSize(0);
1647 size_t rank = getType().getRank();
1648 for (size_t i = 0, e = numSparseIndices; i != e; ++i)
1649 flatSparseIndices.push_back(getFlattenedIndex(
1650 {&*std::next(sparseIndexValues.begin(), i * rank), rank}));
1651 return flatSparseIndices;
1652}
1653
1654LogicalResult
1655SparseElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
1656 ShapedType type, DenseIntElementsAttr sparseIndices,
1657 DenseElementsAttr values) {
1658 ShapedType valuesType = values.getType();
1659 if (valuesType.getRank() != 1)
1660 return emitError() << "expected 1-d tensor for sparse element values";
1661
1662 // Verify the indices and values shape.
1663 ShapedType indicesType = sparseIndices.getType();
1664 auto emitShapeError = [&]() {
1665 return emitError() << "expected shape ([" << type.getShape()
1666 << "]); inferred shape of indices literal (["
1667 << indicesType.getShape()
1668 << "]); inferred shape of values literal (["
1669 << valuesType.getShape() << "])";
1670 };
1671 // Verify indices shape.
1672 size_t rank = type.getRank(), indicesRank = indicesType.getRank();
1673 if (indicesRank == 2) {
1674 if (indicesType.getDimSize(1) != static_cast<int64_t>(rank))
1675 return emitShapeError();
1676 } else if (indicesRank != 1 || rank != 1) {
1677 return emitShapeError();
1678 }
1679 // Verify the values shape.
1680 int64_t numSparseIndices = indicesType.getDimSize(0);
1681 if (numSparseIndices != valuesType.getDimSize(0))
1682 return emitShapeError();
1683
1684 // Verify that the sparse indices are within the value shape.
1685 auto emitIndexError = [&](unsigned indexNum, ArrayRef<uint64_t> index) {
1686 return emitError()
1687 << "sparse index #" << indexNum
1688 << " is not contained within the value shape, with index=[" << index
1689 << "], and type=" << type;
1690 };
1691
1692 // Handle the case where the index values are a splat.
1693 auto sparseIndexValues = sparseIndices.getValues<uint64_t>();
1694 if (sparseIndices.isSplat()) {
1695 SmallVector<uint64_t> indices(rank, *sparseIndexValues.begin());
1696 if (!ElementsAttr::isValidIndex(type, indices))
1697 return emitIndexError(0, indices);
1698 return success();
1699 }
1700
1701 // Otherwise, reinterpret each index as an ArrayRef.
1702 for (size_t i = 0, e = numSparseIndices; i != e; ++i) {
1703 ArrayRef<uint64_t> index(&*std::next(sparseIndexValues.begin(), i * rank),
1704 rank);
1705 if (!ElementsAttr::isValidIndex(type, index))
1706 return emitIndexError(i, index);
1707 }
1708
1709 return success();
1710}
1711
1712//===----------------------------------------------------------------------===//
1713// Attribute Utilities
1714//===----------------------------------------------------------------------===//
1715
1716AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
1717 int64_t offset,
1718 MLIRContext *context) {
1719 AffineExpr expr;
1720 unsigned nSymbols = 0;
1721
1722 // AffineExpr for offset.
1723 // Static case.
1724 if (!ShapedType::isDynamic(offset)) {
1725 auto cst = getAffineConstantExpr(offset, context);
1726 expr = cst;
1727 } else {
1728 // Dynamic case, new symbol for the offset.
1729 auto sym = getAffineSymbolExpr(nSymbols++, context);
1730 expr = sym;
1731 }
1732
1733 // AffineExpr for strides.
1734 for (const auto &en : llvm::enumerate(strides)) {
1735 auto dim = en.index();
1736 auto stride = en.value();
1737 assert(stride != 0 && "Invalid stride specification")(static_cast <bool> (stride != 0 && "Invalid stride specification"
) ? void (0) : __assert_fail ("stride != 0 && \"Invalid stride specification\""
, "mlir/lib/IR/BuiltinAttributes.cpp", 1737, __extension__ __PRETTY_FUNCTION__
))
;
1738 auto d = getAffineDimExpr(dim, context);
1739 AffineExpr mult;
1740 // Static case.
1741 if (!ShapedType::isDynamic(stride))
1742 mult = getAffineConstantExpr(stride, context);
1743 else
1744 // Dynamic case, new symbol for each new stride.
1745 mult = getAffineSymbolExpr(nSymbols++, context);
1746 expr = expr + d * mult;
1747 }
1748
1749 return AffineMap::get(strides.size(), nSymbols, expr);
1750}

/build/source/mlir/include/mlir/IR/OpImplementation.h

1//===- OpImplementation.h - Classes for implementing Op types ---*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file contains the classes used by the implementation details of Op types.
10//
11//===----------------------------------------------------------------------===//
12
13#ifndef MLIR_IR_OPIMPLEMENTATION_H
14#define MLIR_IR_OPIMPLEMENTATION_H
15
16#include "mlir/IR/BuiltinTypes.h"
17#include "mlir/IR/DialectInterface.h"
18#include "mlir/IR/OpDefinition.h"
19#include "llvm/ADT/Twine.h"
20#include "llvm/Support/SMLoc.h"
21
22namespace mlir {
23class AsmParsedResourceEntry;
24class AsmResourceBuilder;
25class Builder;
26
27//===----------------------------------------------------------------------===//
28// AsmDialectResourceHandle
29//===----------------------------------------------------------------------===//
30
31/// This class represents an opaque handle to a dialect resource entry.
32class AsmDialectResourceHandle {
33public:
34 AsmDialectResourceHandle() = default;
35 AsmDialectResourceHandle(void *resource, TypeID resourceID, Dialect *dialect)
36 : resource(resource), opaqueID(resourceID), dialect(dialect) {}
37 bool operator==(const AsmDialectResourceHandle &other) const {
38 return resource == other.resource;
39 }
40
41 /// Return an opaque pointer to the referenced resource.
42 void *getResource() const { return resource; }
43
44 /// Return the type ID of the resource.
45 TypeID getTypeID() const { return opaqueID; }
46
47 /// Return the dialect that owns the resource.
48 Dialect *getDialect() const { return dialect; }
49
50private:
51 /// The opaque handle to the dialect resource.
52 void *resource = nullptr;
53 /// The type of the resource referenced.
54 TypeID opaqueID;
55 /// The dialect owning the given resource.
56 Dialect *dialect;
57};
58
59/// This class represents a CRTP base class for dialect resource handles. It
60/// abstracts away various utilities necessary for defined derived resource
61/// handles.
62template <typename DerivedT, typename ResourceT, typename DialectT>
63class AsmDialectResourceHandleBase : public AsmDialectResourceHandle {
64public:
65 using Dialect = DialectT;
66
67 /// Construct a handle from a pointer to the resource. The given pointer
68 /// should be guaranteed to live beyond the life of this handle.
69 AsmDialectResourceHandleBase(ResourceT *resource, DialectT *dialect)
70 : AsmDialectResourceHandle(resource, TypeID::get<DerivedT>(), dialect) {}
71 AsmDialectResourceHandleBase(AsmDialectResourceHandle handle)
72 : AsmDialectResourceHandle(handle) {
73 assert(handle.getTypeID() == TypeID::get<DerivedT>())(static_cast <bool> (handle.getTypeID() == TypeID::get<
DerivedT>()) ? void (0) : __assert_fail ("handle.getTypeID() == TypeID::get<DerivedT>()"
, "mlir/include/mlir/IR/OpImplementation.h", 73, __extension__
__PRETTY_FUNCTION__))
;
74 }
75
76 /// Return the resource referenced by this handle.
77 ResourceT *getResource() {
78 return static_cast<ResourceT *>(AsmDialectResourceHandle::getResource());
79 }
80 const ResourceT *getResource() const {
81 return const_cast<AsmDialectResourceHandleBase *>(this)->getResource();
82 }
83
84 /// Return the dialect that owns the resource.
85 DialectT *getDialect() const {
86 return static_cast<DialectT *>(AsmDialectResourceHandle::getDialect());
87 }
88
89 /// Support llvm style casting.
90 static bool classof(const AsmDialectResourceHandle *handle) {
91 return handle->getTypeID() == TypeID::get<DerivedT>();
92 }
93};
94
95inline llvm::hash_code hash_value(const AsmDialectResourceHandle &param) {
96 return llvm::hash_value(param.getResource());
97}
98
99//===----------------------------------------------------------------------===//
100// AsmPrinter
101//===----------------------------------------------------------------------===//
102
103/// This base class exposes generic asm printer hooks, usable across the various
104/// derived printers.
class AsmPrinter {
public:
  /// This class contains the internal default implementation of the base
  /// printer methods.
  class Impl;

  /// Initialize the printer with the given internal implementation. The
  /// printer stores a pointer to `impl`; it does not take ownership.
  AsmPrinter(Impl &impl) : impl(&impl) {}
  virtual ~AsmPrinter();

  /// Return the raw output stream used by this printer.
  virtual raw_ostream &getStream() const;

  /// Print the given floating point value in a stabilized form that can be
  /// roundtripped through the IR. This is the companion to the 'parseFloat'
  /// hook on the AsmParser.
  virtual void printFloat(const APFloat &value);

  /// Print a type / an attribute.
  virtual void printType(Type type);
  virtual void printAttribute(Attribute attr);

  /// Trait to check if `AttrType` provides a `print` method.
  /// `has_print_method` names the type of `attrOrType.print(AsmPrinter &)`;
  /// `detect_has_print_method` is true when that expression is well-formed.
  template <typename AttrOrType>
  using has_print_method =
      decltype(std::declval<AttrOrType>().print(std::declval<AsmPrinter &>()));
  template <typename AttrOrType>
  using detect_has_print_method =
      llvm::is_detected<has_print_method, AttrOrType>;

  /// Print the provided attribute in the context of an operation custom
  /// printer/parser: this will invoke directly the print method on the
  /// attribute class and skip the `#dialect.mnemonic` prefix in most cases.
  /// If an alias can be printed for `attrOrType`, only the alias is printed.
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    if (succeeded(printAlias(attrOrType)))
      return;
    attrOrType.print(*this);
  }

  /// Print the provided array of attributes or types in the context of an
  /// operation custom printer/parser: this will invoke directly the print
  /// method on the attribute class and skip the `#dialect.mnemonic` prefix in
  /// most cases. Elements are separated by commas.
  template <typename AttrOrType,
            std::enable_if_t<detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(ArrayRef<AttrOrType> attrOrTypes) {
    llvm::interleaveComma(
        attrOrTypes, getStream(),
        [this](AttrOrType attrOrType) { printStrippedAttrOrType(attrOrType); });
  }

  /// SFINAE for printing the provided attribute in the context of an operation
  /// custom printer in the case where the attribute does not define a print
  /// method. Falls back to the regular streaming operator.
  template <typename AttrOrType,
            std::enable_if_t<!detect_has_print_method<AttrOrType>::value>
                *sfinae = nullptr>
  void printStrippedAttrOrType(AttrOrType attrOrType) {
    *this << attrOrType;
  }

  /// Print the given attribute without its type. The corresponding parser must
  /// provide a valid type for the attribute.
  virtual void printAttributeWithoutType(Attribute attr);

  /// Print the given string as a keyword, or a quoted and escaped string if it
  /// has any special or non-printable characters in it.
  virtual void printKeywordOrString(StringRef keyword);

  /// Print the given string as a symbol reference, i.e. a form representable by
  /// a SymbolRefAttr. A symbol reference is represented as a string prefixed
  /// with '@'. The reference is surrounded with ""'s and escaped if it has any
  /// special or non-printable characters in it.
  virtual void printSymbolName(StringRef symbolRef);

  /// Print a handle to the given dialect resource.
  virtual void printResourceHandle(const AsmDialectResourceHandle &resource);

  /// Print an optional arrow followed by a type list. Prints nothing at all
  /// when `types` is empty.
  template <typename TypeRange>
  void printOptionalArrowTypeList(TypeRange &&types) {
    if (types.begin() != types.end())
      printArrowTypeList(types);
  }
  /// Print " -> " followed by `types`. The list is parenthesized unless it
  /// holds exactly one type that is not a FunctionType; an empty range
  /// therefore prints " -> ()".
  template <typename TypeRange>
  void printArrowTypeList(TypeRange &&types) {
    auto &os = getStream() << " -> ";

    bool wrapped = !llvm::hasSingleElement(types) ||
                   (*types.begin()).template isa<FunctionType>();
    if (wrapped)
      os << '(';
    llvm::interleaveComma(types, *this);
    if (wrapped)
      os << ')';
  }

  /// Print the two given type ranges in a functional form:
  /// `(inputs) -> results`, with results formatted as in printArrowTypeList.
  template <typename InputRangeT, typename ResultRangeT>
  void printFunctionalType(InputRangeT &&inputs, ResultRangeT &&results) {
    auto &os = getStream();
    os << '(';
    llvm::interleaveComma(inputs, *this);
    os << ')';
    printArrowTypeList(results);
  }

protected:
  /// Initialize the printer with no internal implementation. In this case, all
  /// virtual methods of this class must be overriden.
  AsmPrinter() = default;

private:
  // Printers are neither copyable nor assignable.
  AsmPrinter(const AsmPrinter &) = delete;
  void operator=(const AsmPrinter &) = delete;

  /// Print the alias for the given attribute, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Attribute attr);

  /// Print the alias for the given type, return failure if no alias could
  /// be printed.
  virtual LogicalResult printAlias(Type type);

  /// The internal implementation of the printer (null when constructed via the
  /// protected default constructor).
  Impl *impl{nullptr};
};
235
236template <typename AsmPrinterT>
237inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
238 AsmPrinterT &>
239operator<<(AsmPrinterT &p, Type type) {
240 p.printType(type);
241 return p;
242}
243
244template <typename AsmPrinterT>
245inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
246 AsmPrinterT &>
247operator<<(AsmPrinterT &p, Attribute attr) {
248 p.printAttribute(attr);
249 return p;
250}
251
252template <typename AsmPrinterT>
253inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
254 AsmPrinterT &>
255operator<<(AsmPrinterT &p, const APFloat &value) {
256 p.printFloat(value);
257 return p;
258}
259template <typename AsmPrinterT>
260inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
261 AsmPrinterT &>
262operator<<(AsmPrinterT &p, float value) {
263 return p << APFloat(value);
264}
265template <typename AsmPrinterT>
266inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
267 AsmPrinterT &>
268operator<<(AsmPrinterT &p, double value) {
269 return p << APFloat(value);
270}
271
// Support printing anything that isn't convertible to one of the other
// streamable types, even if it isn't exactly one of them. For example, we want
// to print FunctionType with the Type version above, not have it match this.
// Qualifying values are forwarded unchanged to the underlying raw_ostream.
// bool, float and double are excluded explicitly so that their dedicated
// overloads ("true"/"false", and printFloat via APFloat) are selected instead.
template <typename AsmPrinterT, typename T,
          std::enable_if_t<!std::is_convertible<T &, Value &>::value &&
                               !std::is_convertible<T &, Type &>::value &&
                               !std::is_convertible<T &, Attribute &>::value &&
                               !std::is_convertible<T &, ValueRange>::value &&
                               !std::is_convertible<T &, APFloat &>::value &&
                               !llvm::is_one_of<T, bool, float, double>::value,
                           T> * = nullptr>
inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
                        AsmPrinterT &>
operator<<(AsmPrinterT &p, const T &other) {
  p.getStream() << other;
  return p;
}
289
290template <typename AsmPrinterT>
291inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
292 AsmPrinterT &>
293operator<<(AsmPrinterT &p, bool value) {
294 return p << (value ? StringRef("true") : "false");
295}
296
297template <typename AsmPrinterT, typename ValueRangeT>
298inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
299 AsmPrinterT &>
300operator<<(AsmPrinterT &p, const ValueTypeRange<ValueRangeT> &types) {
301 llvm::interleaveComma(types, p);
302 return p;
303}
304template <typename AsmPrinterT>
305inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
306 AsmPrinterT &>
307operator<<(AsmPrinterT &p, const TypeRange &types) {
308 llvm::interleaveComma(types, p);
309 return p;
310}
311template <typename AsmPrinterT, typename ElementT>
312inline std::enable_if_t<std::is_base_of<AsmPrinter, AsmPrinterT>::value,
313 AsmPrinterT &>
314operator<<(AsmPrinterT &p, ArrayRef<ElementT> types) {
315 llvm::interleaveComma(types, p);
316 return p;
317}
318
319//===----------------------------------------------------------------------===//
320// OpAsmPrinter
321//===----------------------------------------------------------------------===//
322
323/// This is a pure-virtual base class that exposes the asmprinter hooks
324/// necessary to implement a custom print() method.
class OpAsmPrinter : public AsmPrinter {
public:
  using AsmPrinter::AsmPrinter;
  ~OpAsmPrinter() override;

  /// Print a loc(...) specifier if printing debug info is enabled.
  virtual void printOptionalLocationSpecifier(Location loc) = 0;

  /// Print a newline and indent the printer to the start of the current
  /// operation.
  virtual void printNewline() = 0;

  /// Increase indentation.
  virtual void increaseIndent() = 0;

  /// Decrease indentation.
  virtual void decreaseIndent() = 0;

  /// Print a block argument in the usual format of:
  ///   %ssaName : type {attr1=42} loc("here")
  /// where location printing is controlled by the standard internal option.
  /// You may pass omitType=true to not print a type, and pass an empty
  /// attribute list if you don't care for attributes.
  virtual void printRegionArgument(BlockArgument arg,
                                   ArrayRef<NamedAttribute> argAttrs = {},
                                   bool omitType = false) = 0;

  /// Print implementations for various things an operation contains.
  virtual void printOperand(Value value) = 0;
  virtual void printOperand(Value value, raw_ostream &os) = 0;

  /// Print a comma separated list of operands taken from any container that
  /// exposes begin()/end().
  template <typename ContainerType>
  void printOperands(const ContainerType &container) {
    printOperands(container.begin(), container.end());
  }

  /// Print a comma separated list of the operands in [it, end); each element
  /// is converted to a Value and printed with printOperand().
  template <typename IteratorType>
  void printOperands(IteratorType it, IteratorType end) {
    llvm::interleaveComma(llvm::make_range(it, end), getStream(),
                          [this](Value value) { printOperand(value); });
  }

  /// Print the given successor.
  virtual void printSuccessor(Block *successor) = 0;

  /// Print the successor and its operands.
  virtual void printSuccessorAndUseList(Block *successor,
                                        ValueRange succOperands) = 0;

  /// If the specified operation has attributes, print out an attribute
  /// dictionary with their values. elidedAttrs allows the client to ignore
  /// specific well known attributes, commonly used if the attribute value is
  /// printed some other way (like as a fixed operand).
  virtual void printOptionalAttrDict(ArrayRef<NamedAttribute> attrs,
                                     ArrayRef<StringRef> elidedAttrs = {}) = 0;

  /// If the specified operation has attributes, print out an attribute
  /// dictionary prefixed with 'attributes'.
  virtual void
  printOptionalAttrDictWithKeyword(ArrayRef<NamedAttribute> attrs,
                                   ArrayRef<StringRef> elidedAttrs = {}) = 0;

  /// Prints the entire operation with the custom assembly form, if available,
  /// or the generic assembly form, otherwise.
  virtual void printCustomOrGenericOp(Operation *op) = 0;

  /// Print the entire operation with the default generic assembly form.
  /// If `printOpName` is true, then the operation name is printed (the default)
  /// otherwise it is omitted and the print will start with the operand list.
  virtual void printGenericOp(Operation *op, bool printOpName = true) = 0;

  /// Prints a region.
  /// If 'printEntryBlockArgs' is false, the arguments of the
  /// block are not printed. If 'printBlockTerminator' is false, the terminator
  /// operation of the block is not printed. If printEmptyBlock is true, then
  /// the block header is printed even if the block is empty.
  virtual void printRegion(Region &blocks, bool printEntryBlockArgs = true,
                           bool printBlockTerminators = true,
                           bool printEmptyBlock = false) = 0;

  /// Renumber the arguments for the specified region to the same names as the
  /// SSA values in namesToUse. This may only be used for IsolatedFromAbove
  /// operations. If any entry in namesToUse is null, the corresponding
  /// argument name is left alone.
  virtual void shadowRegionArgs(Region &region, ValueRange namesToUse) = 0;

  /// Prints an affine map of SSA ids, where SSA id names are used in place
  /// of dims/symbols.
  /// Operand values must come from single-result sources, and be valid
  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
  virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr,
                                      ValueRange operands) = 0;

  /// Prints an affine expression of SSA ids with SSA id names used instead of
  /// dims and symbols.
  /// Operand values must come from single-result sources, and be valid
  /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
  virtual void printAffineExprOfSSAIds(AffineExpr expr, ValueRange dimOperands,
                                       ValueRange symOperands) = 0;

  /// Print the complete type of an operation in functional form.
  void printFunctionalType(Operation *op);
  // Keep the range-based AsmPrinter overloads visible alongside the one above.
  using AsmPrinter::printFunctionalType;
};
431
432// Make the implementations convenient to use.
433inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Value value) {
434 p.printOperand(value);
435 return p;
436}
437
438template <typename T,
439 std::enable_if_t<std::is_convertible<T &, ValueRange>::value &&
440 !std::is_convertible<T &, Value &>::value,
441 T> * = nullptr>
442inline OpAsmPrinter &operator<<(OpAsmPrinter &p, const T &values) {
443 p.printOperands(values);
444 return p;
445}
446
447inline OpAsmPrinter &operator<<(OpAsmPrinter &p, Block *value) {
448 p.printSuccessor(value);
449 return p;
450}
451
452//===----------------------------------------------------------------------===//
453// AsmParser
454//===----------------------------------------------------------------------===//
455
456/// This base class exposes generic asm parser hooks, usable across the various
457/// derived parsers.
458class AsmParser {
459public:
460 AsmParser() = default;
461 virtual ~AsmParser();
462
463 MLIRContext *getContext() const;
464
465 /// Return the location of the original name token.
466 virtual SMLoc getNameLoc() const = 0;
467
468 //===--------------------------------------------------------------------===//
469 // Utilities
470 //===--------------------------------------------------------------------===//
471
472 /// Emit a diagnostic at the specified location and return failure.
473 virtual InFlightDiagnostic emitError(SMLoc loc,
474 const Twine &message = {}) = 0;
475
476 /// Return a builder which provides useful access to MLIRContext, global
477 /// objects like types and attributes.
478 virtual Builder &getBuilder() const = 0;
479
480 /// Get the location of the next token and store it into the argument. This
481 /// always succeeds.
482 virtual SMLoc getCurrentLocation() = 0;
483 ParseResult getCurrentLocation(SMLoc *loc) {
484 *loc = getCurrentLocation();
485 return success();
486 }
487
488 /// Re-encode the given source location as an MLIR location and return it.
489 /// Note: This method should only be used when a `Location` is necessary, as
490 /// the encoding process is not efficient.
491 virtual Location getEncodedSourceLoc(SMLoc loc) = 0;
492
493 //===--------------------------------------------------------------------===//
494 // Token Parsing
495 //===--------------------------------------------------------------------===//
496
497 /// Parse a '->' token.
498 virtual ParseResult parseArrow() = 0;
499
500 /// Parse a '->' token if present
501 virtual ParseResult parseOptionalArrow() = 0;
502
503 /// Parse a `{` token.
504 virtual ParseResult parseLBrace() = 0;
505
506 /// Parse a `{` token if present.
507 virtual ParseResult parseOptionalLBrace() = 0;
508
509 /// Parse a `}` token.
510 virtual ParseResult parseRBrace() = 0;
511
512 /// Parse a `}` token if present.
513 virtual ParseResult parseOptionalRBrace() = 0;
514
515 /// Parse a `:` token.
516 virtual ParseResult parseColon() = 0;
517
518 /// Parse a `:` token if present.
519 virtual ParseResult parseOptionalColon() = 0;
520
521 /// Parse a `,` token.
522 virtual ParseResult parseComma() = 0;
523
524 /// Parse a `,` token if present.
525 virtual ParseResult parseOptionalComma() = 0;
526
527 /// Parse a `=` token.
528 virtual ParseResult parseEqual() = 0;
529
530 /// Parse a `=` token if present.
531 virtual ParseResult parseOptionalEqual() = 0;
532
533 /// Parse a '<' token.
534 virtual ParseResult parseLess() = 0;
535
536 /// Parse a '<' token if present.
537 virtual ParseResult parseOptionalLess() = 0;
538
539 /// Parse a '>' token.
540 virtual ParseResult parseGreater() = 0;
541
542 /// Parse a '>' token if present.
543 virtual ParseResult parseOptionalGreater() = 0;
544
545 /// Parse a '?' token.
546 virtual ParseResult parseQuestion() = 0;
547
548 /// Parse a '?' token if present.
549 virtual ParseResult parseOptionalQuestion() = 0;
550
551 /// Parse a '+' token.
552 virtual ParseResult parsePlus() = 0;
553
554 /// Parse a '+' token if present.
555 virtual ParseResult parseOptionalPlus() = 0;
556
557 /// Parse a '*' token.
558 virtual ParseResult parseStar() = 0;
559
560 /// Parse a '*' token if present.
561 virtual ParseResult parseOptionalStar() = 0;
562
563 /// Parse a '|' token.
564 virtual ParseResult parseVerticalBar() = 0;
565
566 /// Parse a '|' token if present.
567 virtual ParseResult parseOptionalVerticalBar() = 0;
568
569 /// Parse a quoted string token.
570 ParseResult parseString(std::string *string) {
571 auto loc = getCurrentLocation();
572 if (parseOptionalString(string))
573 return emitError(loc, "expected string");
574 return success();
575 }
576
577 /// Parse a quoted string token if present.
578 virtual ParseResult parseOptionalString(std::string *string) = 0;
579
580 /// Parses a Base64 encoded string of bytes.
581 virtual ParseResult parseBase64Bytes(std::vector<char> *bytes) = 0;
582
583 /// Parse a `(` token.
584 virtual ParseResult parseLParen() = 0;
585
586 /// Parse a `(` token if present.
587 virtual ParseResult parseOptionalLParen() = 0;
588
589 /// Parse a `)` token.
590 virtual ParseResult parseRParen() = 0;
591
592 /// Parse a `)` token if present.
593 virtual ParseResult parseOptionalRParen() = 0;
594
595 /// Parse a `[` token.
596 virtual ParseResult parseLSquare() = 0;
597
598 /// Parse a `[` token if present.
599 virtual ParseResult parseOptionalLSquare() = 0;
600
601 /// Parse a `]` token.
602 virtual ParseResult parseRSquare() = 0;
603
604 /// Parse a `]` token if present.
605 virtual ParseResult parseOptionalRSquare() = 0;
606
607 /// Parse a `...` token.
608 virtual ParseResult parseEllipsis() = 0;
609
610 /// Parse a `...` token if present;
611 virtual ParseResult parseOptionalEllipsis() = 0;
612
613 /// Parse a floating point value from the stream.
614 virtual ParseResult parseFloat(double &result) = 0;
615
616 /// Parse an integer value from the stream.
617 template <typename IntT>
618 ParseResult parseInteger(IntT &result) {
619 auto loc = getCurrentLocation();
620 OptionalParseResult parseResult = parseOptionalInteger(result);
4
Calling 'AsmParser::parseOptionalInteger'
8
Returning from 'AsmParser::parseOptionalInteger'
621 if (!parseResult.has_value())
9
Taking true branch
622 return emitError(loc, "expected integer value");
10
Returning without writing to 'result'
623 return *parseResult;
624 }
625
626 /// Parse an optional integer value from the stream.
627 virtual OptionalParseResult parseOptionalInteger(APInt &result) = 0;
628
629 template <typename IntT>
630 OptionalParseResult parseOptionalInteger(IntT &result) {
631 auto loc = getCurrentLocation();
632
633 // Parse the unsigned variant.
634 APInt uintResult;
635 OptionalParseResult parseResult = parseOptionalInteger(uintResult);
636 if (!parseResult.has_value() || failed(*parseResult))
5
Assuming the condition is true
6
Taking true branch
637 return parseResult;
7
Returning without writing to 'result'
638
639 // Try to convert to the provided integer type. sextOrTrunc is correct even
640 // for unsigned types because parseOptionalInteger ensures the sign bit is
641 // zero for non-negated integers.
642 result =
643 (IntT)uintResult.sextOrTrunc(sizeof(IntT) * CHAR_BIT8).getLimitedValue();
644 if (APInt(uintResult.getBitWidth(), result) != uintResult)
645 return emitError(loc, "integer value too large");
646 return success();
647 }
648
649 /// These are the supported delimiters around operand lists and region
650 /// argument lists, used by parseOperandList.
  // The Optional* variants differ from their plain counterparts in that the
  // whole delimited list may be absent ("or nothing").
  enum class Delimiter {
    /// Zero or more operands with no delimiters.
    None,
    /// Parens surrounding zero or more operands.
    Paren,
    /// Square brackets surrounding zero or more operands.
    Square,
    /// <> brackets surrounding zero or more operands.
    LessGreater,
    /// {} brackets surrounding zero or more operands.
    Braces,
    /// Parens supporting zero or more operands, or nothing.
    OptionalParen,
    /// Square brackets supporting zero or more ops, or nothing.
    OptionalSquare,
    /// <> brackets supporting zero or more ops, or nothing.
    OptionalLessGreater,
    /// {} brackets surrounding zero or more operands, or nothing.
    OptionalBraces,
  };
671
672 /// Parse a list of comma-separated items with an optional delimiter. If a
673 /// delimiter is provided, then an empty list is allowed. If not, then at
674 /// least one element will be parsed.
675 ///
676 /// contextMessage is an optional message appended to "expected '('" sorts of
677 /// diagnostics when parsing the delimeters.
678 virtual ParseResult
679 parseCommaSeparatedList(Delimiter delimiter,
680 function_ref<ParseResult()> parseElementFn,
681 StringRef contextMessage = StringRef()) = 0;
682
683 /// Parse a comma separated list of elements that must have at least one entry
684 /// in it.
685 ParseResult
686 parseCommaSeparatedList(function_ref<ParseResult()> parseElementFn) {
687 return parseCommaSeparatedList(Delimiter::None, parseElementFn);
688 }
689
690 //===--------------------------------------------------------------------===//
691 // Keyword Parsing
692 //===--------------------------------------------------------------------===//
693
694 /// This class represents a StringSwitch like class that is useful for parsing
695 /// expected keywords. On construction, it invokes `parseKeyword` and
696 /// processes each of the provided cases statements until a match is hit. The
697 /// provided `ResultT` must be assignable from `failure()`.
698 template <typename ResultT = ParseResult>
699 class KeywordSwitch {
700 public:
701 KeywordSwitch(AsmParser &parser)
702 : parser(parser), loc(parser.getCurrentLocation()) {
703 if (failed(parser.parseKeywordOrCompletion(&keyword)))
704 result = failure();
705 }
706
707 /// Case that uses the provided value when true.
708 KeywordSwitch &Case(StringLiteral str, ResultT value) {
709 return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
710 }
711 KeywordSwitch &Default(ResultT value) {
712 return Default([&](StringRef, SMLoc) { return std::move(value); });
713 }
714 /// Case that invokes the provided functor when true. The parameters passed
715 /// to the functor are the keyword, and the location of the keyword (in case
716 /// any errors need to be emitted).
717 template <typename FnT>
718 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
719 Case(StringLiteral str, FnT &&fn) {
720 if (result)
721 return *this;
722
723 // If the word was empty, record this as a completion.
724 if (keyword.empty())
725 parser.codeCompleteExpectedTokens(str);
726 else if (keyword == str)
727 result.emplace(std::move(fn(keyword, loc)));
728 return *this;
729 }
730 template <typename FnT>
731 std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
732 Default(FnT &&fn) {
733 if (!result)
734 result.emplace(fn(keyword, loc));
735 return *this;
736 }
737
738 /// Returns true if this switch has a value yet.
739 bool hasValue() const { return result.has_value(); }
740
741 /// Return the result of the switch.
742 [[nodiscard]] operator ResultT() {
743 if (!result)
744 return parser.emitError(loc, "unexpected keyword: ") << keyword;
745 return std::move(*result);
746 }
747
748 private:
749 /// The parser used to construct this switch.
750 AsmParser &parser;
751
752 /// The location of the keyword, used to emit errors as necessary.
753 SMLoc loc;
754
755 /// The parsed keyword itself.
756 StringRef keyword;
757
758 /// The result of the switch statement or none if currently unknown.
759 Optional<ResultT> result;
760 };
761
762 /// Parse a given keyword.
763 ParseResult parseKeyword(StringRef keyword) {
764 return parseKeyword(keyword, "");
765 }
766 virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
767
768 /// Parse a keyword into 'keyword'.
769 ParseResult parseKeyword(StringRef *keyword) {
770 auto loc = getCurrentLocation();
771 if (parseOptionalKeyword(keyword))
772 return emitError(loc, "expected valid keyword");
773 return success();
774 }
775
776 /// Parse the given keyword if present.
777 virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
778
779 /// Parse a keyword, if present, into 'keyword'.
780 virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
781
782 /// Parse a keyword, if present, and if one of the 'allowedValues',
783 /// into 'keyword'
784 virtual ParseResult
785 parseOptionalKeyword(StringRef *keyword,
786 ArrayRef<StringRef> allowedValues) = 0;
787
788 /// Parse a keyword or a quoted string.
789 ParseResult parseKeywordOrString(std::string *result) {
790 if (failed(parseOptionalKeywordOrString(result)))
791 return emitError(getCurrentLocation())
792 << "expected valid keyword or string";
793 return success();
794 }
795
796 /// Parse an optional keyword or string.
797 virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
798
799 //===--------------------------------------------------------------------===//
800 // Attribute/Type Parsing
801 //===--------------------------------------------------------------------===//
802
803 /// Invoke the `getChecked` method of the given Attribute or Type class, using
804 /// the provided location to emit errors in the case of failure. Note that
805 /// unlike `OpBuilder::getType`, this method does not implicitly insert a
806 /// context parameter.
807 template <typename T, typename... ParamsT>
808 auto getChecked(SMLoc loc, ParamsT &&...params) {
809 return T::getChecked([&] { return emitError(loc); },
810 std::forward<ParamsT>(params)...);
811 }
812 /// A variant of `getChecked` that uses the result of `getNameLoc` to emit
813 /// errors.
814 template <typename T, typename... ParamsT>
815 auto getChecked(ParamsT &&...params) {
816 return T::getChecked([&] { return emitError(getNameLoc()); },
817 std::forward<ParamsT>(params)...);
818 }
819
820 //===--------------------------------------------------------------------===//
821 // Attribute Parsing
822 //===--------------------------------------------------------------------===//
823
824 /// Parse an arbitrary attribute of a given type and return it in result.
825 virtual ParseResult parseAttribute(Attribute &result, Type type = {}) = 0;
826
827 /// Parse a custom attribute with the provided callback, unless the next
828 /// token is `#`, in which case the generic parser is invoked.
829 virtual ParseResult parseCustomAttributeWithFallback(
830 Attribute &result, Type type,
831 function_ref<ParseResult(Attribute &result, Type type)>
832 parseAttribute) = 0;
833
834 /// Parse an attribute of a specific kind and type.
835 template <typename AttrType>
836 ParseResult parseAttribute(AttrType &result, Type type = {}) {
837 SMLoc loc = getCurrentLocation();
838
839 // Parse any kind of attribute.
840 Attribute attr;
841 if (parseAttribute(attr, type))
842 return failure();
843
844 // Check for the right kind of attribute.
845 if (!(result = attr.dyn_cast<AttrType>()))
846 return emitError(loc, "invalid kind of attribute specified");
847
848 return success();
849 }
850
851 /// Parse an arbitrary attribute and return it in result. This also adds the
852 /// attribute to the specified attribute list with the specified name.
853 ParseResult parseAttribute(Attribute &result, StringRef attrName,
854 NamedAttrList &attrs) {
855 return parseAttribute(result, Type(), attrName, attrs);
856 }
857
858 /// Parse an attribute of a specific kind and type.
859 template <typename AttrType>
860 ParseResult parseAttribute(AttrType &result, StringRef attrName,
861 NamedAttrList &attrs) {
862 return parseAttribute(result, Type(), attrName, attrs);
863 }
864
865 /// Parse an arbitrary attribute of a given type and populate it in `result`.
866 /// This also adds the attribute to the specified attribute list with the
867 /// specified name.
868 template <typename AttrType>
869 ParseResult parseAttribute(AttrType &result, Type type, StringRef attrName,
870 NamedAttrList &attrs) {
871 SMLoc loc = getCurrentLocation();
872
873 // Parse any kind of attribute.
874 Attribute attr;
875 if (parseAttribute(attr, type))
876 return failure();
877
878 // Check for the right kind of attribute.
879 result = attr.dyn_cast<AttrType>();
880 if (!result)
881 return emitError(loc, "invalid kind of attribute specified");
882
883 attrs.append(attrName, result);
884 return success();
885 }
886
887 /// Trait to check if `AttrType` provides a `parse` method.
888 template <typename AttrType>
889 using has_parse_method = decltype(AttrType::parse(std::declval<AsmParser &>(),
890 std::declval<Type>()));
891 template <typename AttrType>
892 using detect_has_parse_method = llvm::is_detected<has_parse_method, AttrType>;
893
894 /// Parse a custom attribute of a given type unless the next token is `#`, in
895 /// which case the generic parser is invoked. The parsed attribute is
896 /// populated in `result` and also added to the specified attribute list with
897 /// the specified name.
898 template <typename AttrType>
899 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
900 parseCustomAttributeWithFallback(AttrType &result, Type type,
901 StringRef attrName, NamedAttrList &attrs) {
902 SMLoc loc = getCurrentLocation();
903
904 // Parse any kind of attribute.
905 Attribute attr;
906 if (parseCustomAttributeWithFallback(
907 attr, type, [&](Attribute &result, Type type) -> ParseResult {
908 result = AttrType::parse(*this, type);
909 if (!result)
910 return failure();
911 return success();
912 }))
913 return failure();
914
915 // Check for the right kind of attribute.
916 result = attr.dyn_cast<AttrType>();
917 if (!result)
918 return emitError(loc, "invalid kind of attribute specified");
919
920 attrs.append(attrName, result);
921 return success();
922 }
923
924 /// SFINAE parsing method for Attribute that don't implement a parse method.
925 template <typename AttrType>
926 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
927 parseCustomAttributeWithFallback(AttrType &result, Type type,
928 StringRef attrName, NamedAttrList &attrs) {
929 return parseAttribute(result, type, attrName, attrs);
930 }
931
932 /// Parse a custom attribute of a given type unless the next token is `#`, in
933 /// which case the generic parser is invoked. The parsed attribute is
934 /// populated in `result`.
935 template <typename AttrType>
936 std::enable_if_t<detect_has_parse_method<AttrType>::value, ParseResult>
937 parseCustomAttributeWithFallback(AttrType &result) {
938 SMLoc loc = getCurrentLocation();
939
940 // Parse any kind of attribute.
941 Attribute attr;
942 if (parseCustomAttributeWithFallback(
943 attr, {}, [&](Attribute &result, Type type) -> ParseResult {
944 result = AttrType::parse(*this, type);
945 return success(!!result);
946 }))
947 return failure();
948
949 // Check for the right kind of attribute.
950 result = attr.dyn_cast<AttrType>();
951 if (!result)
952 return emitError(loc, "invalid kind of attribute specified");
953 return success();
954 }
955
956 /// SFINAE parsing method for Attribute that don't implement a parse method.
957 template <typename AttrType>
958 std::enable_if_t<!detect_has_parse_method<AttrType>::value, ParseResult>
959 parseCustomAttributeWithFallback(AttrType &result) {
960 return parseAttribute(result);
961 }
962
963 /// Parse an arbitrary optional attribute of a given type and return it in
964 /// result.
965 virtual OptionalParseResult parseOptionalAttribute(Attribute &result,
966 Type type = {}) = 0;
967
968 /// Parse an optional array attribute and return it in result.
969 virtual OptionalParseResult parseOptionalAttribute(ArrayAttr &result,
970 Type type = {}) = 0;
971
972 /// Parse an optional string attribute and return it in result.
973 virtual OptionalParseResult parseOptionalAttribute(StringAttr &result,
974 Type type = {}) = 0;
975
976 /// Parse an optional symbol ref attribute and return it in result.
977 virtual OptionalParseResult parseOptionalAttribute(SymbolRefAttr &result,
978 Type type = {}) = 0;
979
980 /// Parse an optional attribute of a specific type and add it to the list with
981 /// the specified name.
982 template <typename AttrType>
983 OptionalParseResult parseOptionalAttribute(AttrType &result,
984 StringRef attrName,
985 NamedAttrList &attrs) {
986 return parseOptionalAttribute(result, Type(), attrName, attrs);
987 }
988
989 /// Parse an optional attribute of a specific type and add it to the list with
990 /// the specified name.
991 template <typename AttrType>
992 OptionalParseResult parseOptionalAttribute(AttrType &result, Type type,
993 StringRef attrName,
994 NamedAttrList &attrs) {
995 OptionalParseResult parseResult = parseOptionalAttribute(result, type);
996 if (parseResult.has_value() && succeeded(*parseResult))
997 attrs.append(attrName, result);
998 return parseResult;
999 }
1000
1001 /// Parse a named dictionary into 'result' if it is present.
1002 virtual ParseResult parseOptionalAttrDict(NamedAttrList &result) = 0;
1003
1004 /// Parse a named dictionary into 'result' if the `attributes` keyword is
1005 /// present.
1006 virtual ParseResult
1007 parseOptionalAttrDictWithKeyword(NamedAttrList &result) = 0;
1008
1009 /// Parse an affine map instance into 'map'.
1010 virtual ParseResult parseAffineMap(AffineMap &map) = 0;
1011
1012 /// Parse an integer set instance into 'set'.
1013 virtual ParseResult printIntegerSet(IntegerSet &set) = 0;
1014
1015 //===--------------------------------------------------------------------===//
1016 // Identifier Parsing
1017 //===--------------------------------------------------------------------===//
1018
1019 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1020 /// attribute.
1021 ParseResult parseSymbolName(StringAttr &result) {
1022 if (failed(parseOptionalSymbolName(result)))
1023 return emitError(getCurrentLocation())
1024 << "expected valid '@'-identifier for symbol name";
1025 return success();
1026 }
1027
1028 /// Parse an @-identifier and store it (without the '@' symbol) in a string
1029 /// attribute named 'attrName'.
1030 ParseResult parseSymbolName(StringAttr &result, StringRef attrName,
1031 NamedAttrList &attrs) {
1032 if (parseSymbolName(result))
1033 return failure();
1034 attrs.append(attrName, result);
1035 return success();
1036 }
1037
1038 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1039 /// string attribute.
1040 virtual ParseResult parseOptionalSymbolName(StringAttr &result) = 0;
1041
1042 /// Parse an optional @-identifier and store it (without the '@' symbol) in a
1043 /// string attribute named 'attrName'.
1044 ParseResult parseOptionalSymbolName(StringAttr &result, StringRef attrName,
1045 NamedAttrList &attrs) {
1046 if (succeeded(parseOptionalSymbolName(result))) {
1047 attrs.append(attrName, result);
1048 return success();
1049 }
1050 return failure();
1051 }
1052
1053 //===--------------------------------------------------------------------===//
1054 // Resource Parsing
1055 //===--------------------------------------------------------------------===//
1056
1057 /// Parse a handle to a resource within the assembly format.
1058 template <typename ResourceT>
1059 FailureOr<ResourceT> parseResourceHandle() {
1060 SMLoc handleLoc = getCurrentLocation();
1061
1062 // Try to load the dialect that owns the handle.
1063 auto *dialect =
1064 getContext()->getOrLoadDialect<typename ResourceT::Dialect>();
1065 if (!dialect) {
1066 return emitError(handleLoc)
1067 << "dialect '" << ResourceT::Dialect::getDialectNamespace()
1068 << "' is unknown";
1069 }
1070
1071 FailureOr<AsmDialectResourceHandle> handle = parseResourceHandle(dialect);
1072 if (failed(handle))
1073 return failure();
1074 if (auto *result = dyn_cast<ResourceT>(&*handle))
1075 return std::move(*result);
1076 return emitError(handleLoc) << "provided resource handle differs from the "
1077 "expected resource type";
1078 }
1079
1080 //===--------------------------------------------------------------------===//
1081 // Type Parsing
1082 //===--------------------------------------------------------------------===//
1083
1084 /// Parse a type.
1085 virtual ParseResult parseType(Type &result) = 0;
1086
1087 /// Parse a custom type with the provided callback, unless the next
1088 /// token is `#`, in which case the generic parser is invoked.
1089 virtual ParseResult parseCustomTypeWithFallback(
1090 Type &result, function_ref<ParseResult(Type &result)> parseType) = 0;
1091
1092 /// Parse an optional type.
1093 virtual OptionalParseResult parseOptionalType(Type &result) = 0;
1094
1095 /// Parse a type of a specific type.
1096 template <typename TypeT>
1097 ParseResult parseType(TypeT &result) {
1098 SMLoc loc = getCurrentLocation();
1099
1100 // Parse any kind of type.
1101 Type type;
1102 if (parseType(type))
1103 return failure();
1104
1105 // Check for the right kind of type.
1106 result = type.dyn_cast<TypeT>();
1107 if (!result)
1108 return emitError(loc, "invalid kind of type specified");
1109
1110 return success();
1111 }
1112
1113 /// Trait to check if `TypeT` provides a `parse` method.
1114 template <typename TypeT>
1115 using type_has_parse_method =
1116 decltype(TypeT::parse(std::declval<AsmParser &>()));
1117 template <typename TypeT>
1118 using detect_type_has_parse_method =
1119 llvm::is_detected<type_has_parse_method, TypeT>;
1120
1121 /// Parse a custom Type of a given type unless the next token is `#`, in
1122 /// which case the generic parser is invoked. The parsed Type is
1123 /// populated in `result`.
1124 template <typename TypeT>
1125 std::enable_if_t<detect_type_has_parse_method<TypeT>::value, ParseResult>
1126 parseCustomTypeWithFallback(TypeT &result) {
1127 SMLoc loc = getCurrentLocation();
1128
1129 // Parse any kind of Type.
1130 Type type;
1131 if (parseCustomTypeWithFallback(type, [&](Type &result) -> ParseResult {
1132 result = TypeT::parse(*this);
1133 return success(!!result);
1134 }))
1135 return failure();
1136
1137 // Check for the right kind of Type.
1138 result = type.dyn_cast<TypeT>();
1139 if (!result)
1140 return emitError(loc, "invalid kind of Type specified");
1141 return success();
1142 }
1143
1144 /// SFINAE parsing method for Type that don't implement a parse method.
1145 template <typename TypeT>
1146 std::enable_if_t<!detect_type_has_parse_method<TypeT>::value, ParseResult>
1147 parseCustomTypeWithFallback(TypeT &result) {
1148 return parseType(result);
1149 }
1150
1151 /// Parse a type list.
1152 ParseResult parseTypeList(SmallVectorImpl<Type> &result) {
1153 return parseCommaSeparatedList(
1154 [&]() { return parseType(result.emplace_back()); });
1155 }
1156
1157 /// Parse an arrow followed by a type list.
1158 virtual ParseResult parseArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1159
1160 /// Parse an optional arrow followed by a type list.
1161 virtual ParseResult
1162 parseOptionalArrowTypeList(SmallVectorImpl<Type> &result) = 0;
1163
1164 /// Parse a colon followed by a type.
1165 virtual ParseResult parseColonType(Type &result) = 0;
1166
1167 /// Parse a colon followed by a type of a specific kind, e.g. a FunctionType.
1168 template <typename TypeType>
1169 ParseResult parseColonType(TypeType &result) {
1170 SMLoc loc = getCurrentLocation();
1171
1172 // Parse any kind of type.
1173 Type type;
1174 if (parseColonType(type))
1175 return failure();
1176
1177 // Check for the right kind of type.
1178 result = type.dyn_cast<TypeType>();
1179 if (!result)
1180 return emitError(loc, "invalid kind of type specified");
1181
1182 return success();
1183 }
1184
1185 /// Parse a colon followed by a type list, which must have at least one type.
1186 virtual ParseResult parseColonTypeList(SmallVectorImpl<Type> &result) = 0;
1187
1188 /// Parse an optional colon followed by a type list, which if present must
1189 /// have at least one type.
1190 virtual ParseResult
1191 parseOptionalColonTypeList(SmallVectorImpl<Type> &result) = 0;
1192
1193 /// Parse a keyword followed by a type.
1194 ParseResult parseKeywordType(const char *keyword, Type &result) {
1195 return failure(parseKeyword(keyword) || parseType(result));
1196 }
1197
1198 /// Add the specified type to the end of the specified type list and return
1199 /// success. This is a helper designed to allow parse methods to be simple
1200 /// and chain through || operators.
1201 ParseResult addTypeToList(Type type, SmallVectorImpl<Type> &result) {
1202 result.push_back(type);
1203 return success();
1204 }
1205
1206 /// Add the specified types to the end of the specified type list and return
1207 /// success. This is a helper designed to allow parse methods to be simple
1208 /// and chain through || operators.
1209 ParseResult addTypesToList(ArrayRef<Type> types,
1210 SmallVectorImpl<Type> &result) {
1211 result.append(types.begin(), types.end());
1212 return success();
1213 }
1214
1215 /// Parse a dimension list of a tensor or memref type. This populates the
1216 /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set
1217 /// and errors out on `?` otherwise. Parsing the trailing `x` is configurable.
1218 ///
1219 /// dimension-list ::= eps | dimension (`x` dimension)*
1220 /// dimension-list-with-trailing-x ::= (dimension `x`)*
1221 /// dimension ::= `?` | decimal-literal
1222 ///
1223 /// When `allowDynamic` is not set, this is used to parse:
1224 ///
1225 /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)*
1226 /// static-dimension-list-with-trailing-x ::= (dimension `x`)*
1227 virtual ParseResult parseDimensionList(SmallVectorImpl<int64_t> &dimensions,
1228 bool allowDynamic = true,
1229 bool withTrailingX = true) = 0;
1230
1231 /// Parse an 'x' token in a dimension list, handling the case where the x is
1232 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the
1233 /// next token.
1234 virtual ParseResult parseXInDimensionList() = 0;
1235
1236protected:
1237 /// Parse a handle to a resource within the assembly format for the given
1238 /// dialect.
1239 virtual FailureOr<AsmDialectResourceHandle>
1240 parseResourceHandle(Dialect *dialect) = 0;
1241
1242 //===--------------------------------------------------------------------===//
1243 // Code Completion
1244 //===--------------------------------------------------------------------===//
1245
1246 /// Parse a keyword, or an empty string if the current location signals a code
1247 /// completion.
1248 virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
1249
1250 /// Signal the code completion of a set of expected tokens.
1251 virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
1252
1253private:
1254 AsmParser(const AsmParser &) = delete;
1255 void operator=(const AsmParser &) = delete;
1256};
1257
1258//===----------------------------------------------------------------------===//
1259// OpAsmParser
1260//===----------------------------------------------------------------------===//
1261
1262/// The OpAsmParser has methods for interacting with the asm parser: parsing
1263/// things from it, emitting errors etc. It has an intentionally high-level API
1264/// that is designed to reduce/constrain syntax innovation in individual
1265/// operations.
1266///
1267/// For example, consider an op like this:
1268///
1269/// %x = load %p[%1, %2] : memref<...>
1270///
1271/// The "%x = load" tokens are already parsed and therefore invisible to the
1272/// custom op parser. This can be supported by calling `parseOperandList` to
1273/// parse the %p, then calling `parseOperandList` with a `SquareDelimiter` to
1274/// parse the indices, then calling `parseColonTypeList` to parse the result
1275/// type.
1276///
1277class OpAsmParser : public AsmParser {
1278public:
1279 using AsmParser::AsmParser;
1280 ~OpAsmParser() override;
1281
1282 /// Parse a loc(...) specifier if present, filling in result if so.
1283 /// Location for BlockArgument and Operation may be deferred with an alias, in
1284 /// which case an OpaqueLoc is set and will be resolved when parsing
1285 /// completes.
1286 virtual ParseResult
1287 parseOptionalLocationSpecifier(Optional<Location> &result) = 0;
1288
1289 /// Return the name of the specified result in the specified syntax, as well
1290 /// as the sub-element in the name. It returns an empty string and ~0U for
1291 /// invalid result numbers. For example, in this operation:
1292 ///
1293 /// %x, %y:2, %z = foo.op
1294 ///
1295 /// getResultName(0) == {"x", 0 }
1296 /// getResultName(1) == {"y", 0 }
1297 /// getResultName(2) == {"y", 1 }
1298 /// getResultName(3) == {"z", 0 }
1299 /// getResultName(4) == {"", ~0U }
1300 virtual std::pair<StringRef, unsigned>
1301 getResultName(unsigned resultNo) const = 0;
1302
1303 /// Return the number of declared SSA results. This returns 4 for the foo.op
1304 /// example in the comment for `getResultName`.
1305 virtual size_t getNumResults() const = 0;
1306
1307 // These methods emit an error and return failure or success. This allows
1308 // these to be chained together into a linear sequence of || expressions in
1309 // many cases.
1310
1311 /// Parse an operation in its generic form.
1312 /// The parsed operation is parsed in the current context and inserted in the
1313 /// provided block and insertion point. The results produced by this operation
1314 /// aren't mapped to any named value in the parser. Returns nullptr on
1315 /// failure.
1316 virtual Operation *parseGenericOperation(Block *insertBlock,
1317 Block::iterator insertPt) = 0;
1318
1319 /// Parse the name of an operation, in the custom form. On success, return a
1320 /// an object of type 'OperationName'. Otherwise, failure is returned.
1321 virtual FailureOr<OperationName> parseCustomOperationName() = 0;
1322
1323 //===--------------------------------------------------------------------===//
1324 // Operand Parsing
1325 //===--------------------------------------------------------------------===//
1326
1327 /// This is the representation of an operand reference.
1328 struct UnresolvedOperand {
1329 SMLoc location; // Location of the token.
1330 StringRef name; // Value name, e.g. %42 or %abc
1331 unsigned number; // Number, e.g. 12 for an operand like %xyz#12
1332 };
1333
1334 /// Parse different components, viz., use-info of operand(s), successor(s),
1335 /// region(s), attribute(s) and function-type, of the generic form of an
1336 /// operation instance and populate the input operation-state 'result' with
1337 /// those components. If any of the components is explicitly provided, then
1338 /// skip parsing that component.
1339 virtual ParseResult parseGenericOperationAfterOpName(
1340 OperationState &result,
1341 Optional<ArrayRef<UnresolvedOperand>> parsedOperandType = std::nullopt,
1342 Optional<ArrayRef<Block *>> parsedSuccessors = std::nullopt,
1343 Optional<MutableArrayRef<std::unique_ptr<Region>>> parsedRegions =
1344 std::nullopt,
1345 Optional<ArrayRef<NamedAttribute>> parsedAttributes = std::nullopt,
1346 Optional<FunctionType> parsedFnType = std::nullopt) = 0;
1347
1348 /// Parse a single SSA value operand name along with a result number if
1349 /// `allowResultNumber` is true.
1350 virtual ParseResult parseOperand(UnresolvedOperand &result,
1351 bool allowResultNumber = true) = 0;
1352
1353 /// Parse a single operand if present.
1354 virtual OptionalParseResult
1355 parseOptionalOperand(UnresolvedOperand &result,
1356 bool allowResultNumber = true) = 0;
1357
1358 /// Parse zero or more SSA comma-separated operand references with a specified
1359 /// surrounding delimiter, and an optional required operand count.
1360 virtual ParseResult
1361 parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1362 Delimiter delimiter = Delimiter::None,
1363 bool allowResultNumber = true,
1364 int requiredOperandCount = -1) = 0;
1365
1366 /// Parse a specified number of comma separated operands.
1367 ParseResult parseOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1368 int requiredOperandCount,
1369 Delimiter delimiter = Delimiter::None) {
1370 return parseOperandList(result, delimiter,
1371 /*allowResultNumber=*/true, requiredOperandCount);
1372 }
1373
1374 /// Parse zero or more trailing SSA comma-separated trailing operand
1375 /// references with a specified surrounding delimiter, and an optional
1376 /// required operand count. A leading comma is expected before the
1377 /// operands.
1378 ParseResult
1379 parseTrailingOperandList(SmallVectorImpl<UnresolvedOperand> &result,
1380 Delimiter delimiter = Delimiter::None) {
1381 if (failed(parseOptionalComma()))
1382 return success(); // The comma is optional.
1383 return parseOperandList(result, delimiter);
1384 }
1385
1386 /// Resolve an operand to an SSA value, emitting an error on failure.
1387 virtual ParseResult resolveOperand(const UnresolvedOperand &operand,
1388 Type type,
1389 SmallVectorImpl<Value> &result) = 0;
1390
1391 /// Resolve a list of operands to SSA values, emitting an error on failure, or
1392 /// appending the results to the list on success. This method should be used
1393 /// when all operands have the same type.
1394 template <typename Operands = ArrayRef<UnresolvedOperand>>
1395 ParseResult resolveOperands(Operands &&operands, Type type,
1396 SmallVectorImpl<Value> &result) {
1397 for (const UnresolvedOperand &operand : operands)
1398 if (resolveOperand(operand, type, result))
1399 return failure();
1400 return success();
1401 }
1402 template <typename Operands = ArrayRef<UnresolvedOperand>>
1403 ParseResult resolveOperands(Operands &&operands, Type type, SMLoc loc,
1404 SmallVectorImpl<Value> &result) {
1405 return resolveOperands(std::forward<Operands>(operands), type, result);
1406 }
1407
1408 /// Resolve a list of operands and a list of operand types to SSA values,
1409 /// emitting an error and returning failure, or appending the results
1410 /// to the list on success.
1411 template <typename Operands = ArrayRef<UnresolvedOperand>,
1412 typename Types = ArrayRef<Type>>
1413 std::enable_if_t<!std::is_convertible<Types, Type>::value, ParseResult>
1414 resolveOperands(Operands &&operands, Types &&types, SMLoc loc,
1415 SmallVectorImpl<Value> &result) {
1416 size_t operandSize = std::distance(operands.begin(), operands.end());
1417 size_t typeSize = std::distance(types.begin(), types.end());
1418 if (operandSize != typeSize)
1419 return emitError(loc)
1420 << operandSize << " operands present, but expected " << typeSize;
1421
1422 for (auto [operand, type] : llvm::zip(operands, types))
1423 if (resolveOperand(operand, type, result))
1424 return failure();
1425 return success();
1426 }
1427
1428 /// Parses an affine map attribute where dims and symbols are SSA operands.
1429 /// Operand values must come from single-result sources, and be valid
1430 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1431 virtual ParseResult
1432 parseAffineMapOfSSAIds(SmallVectorImpl<UnresolvedOperand> &operands,
1433 Attribute &map, StringRef attrName,
1434 NamedAttrList &attrs,
1435 Delimiter delimiter = Delimiter::Square) = 0;
1436
1437 /// Parses an affine expression where dims and symbols are SSA operands.
1438 /// Operand values must come from single-result sources, and be valid
1439 /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol.
1440 virtual ParseResult
1441 parseAffineExprOfSSAIds(SmallVectorImpl<UnresolvedOperand> &dimOperands,
1442 SmallVectorImpl<UnresolvedOperand> &symbOperands,
1443 AffineExpr &expr) = 0;
1444
1445 //===--------------------------------------------------------------------===//
1446 // Argument Parsing
1447 //===--------------------------------------------------------------------===//
1448
1449 struct Argument {
1450 UnresolvedOperand ssaName; // SourceLoc, SSA name, result #.
1451 Type type; // Type.
1452 DictionaryAttr attrs; // Attributes if present.
1453 Optional<Location> sourceLoc; // Source location specifier if present.
1454 };
1455
1456 /// Parse a single argument with the following syntax:
1457 ///
1458 /// `%ssaName : !type { optionalAttrDict} loc(optionalSourceLoc)`
1459 ///
1460 /// If `allowType` is false or `allowAttrs` are false then the respective
1461 /// parts of the grammar are not parsed.
1462 virtual ParseResult parseArgument(Argument &result, bool allowType = false,
1463 bool allowAttrs = false) = 0;
1464
1465 /// Parse a single argument if present.
1466 virtual OptionalParseResult
1467 parseOptionalArgument(Argument &result, bool allowType = false,
1468 bool allowAttrs = false) = 0;
1469
1470 /// Parse zero or more arguments with a specified surrounding delimiter.
1471 virtual ParseResult parseArgumentList(SmallVectorImpl<Argument> &result,
1472 Delimiter delimiter = Delimiter::None,
1473 bool allowType = false,
1474 bool allowAttrs = false) = 0;
1475
1476 //===--------------------------------------------------------------------===//
1477 // Region Parsing
1478 //===--------------------------------------------------------------------===//
1479
1480 /// Parses a region. Any parsed blocks are appended to 'region' and must be
1481 /// moved to the op regions after the op is created. The first block of the
1482 /// region takes 'arguments'.
1483 ///
1484 /// If 'enableNameShadowing' is set to true, the argument names are allowed to
1485 /// shadow the names of other existing SSA values defined above the region
1486 /// scope. 'enableNameShadowing' can only be set to true for regions attached
1487 /// to operations that are 'IsolatedFromAbove'.
1488 virtual ParseResult parseRegion(Region &region,
1489 ArrayRef<Argument> arguments = {},
1490 bool enableNameShadowing = false) = 0;
1491
1492 /// Parses a region if present.
1493 virtual OptionalParseResult
1494 parseOptionalRegion(Region &region, ArrayRef<Argument> arguments = {},
1495 bool enableNameShadowing = false) = 0;
1496
1497 /// Parses a region if present. If the region is present, a new region is
1498 /// allocated and placed in `region`. If no region is present or on failure,
1499 /// `region` remains untouched.
1500 virtual OptionalParseResult
1501 parseOptionalRegion(std::unique_ptr<Region> &region,
1502 ArrayRef<Argument> arguments = {},
1503 bool enableNameShadowing = false) = 0;
1504
1505 //===--------------------------------------------------------------------===//
1506 // Successor Parsing
1507 //===--------------------------------------------------------------------===//
1508
1509 /// Parse a single operation successor.
1510 virtual ParseResult parseSuccessor(Block *&dest) = 0;
1511
1512 /// Parse an optional operation successor.
1513 virtual OptionalParseResult parseOptionalSuccessor(Block *&dest) = 0;
1514
1515 /// Parse a single operation successor and its operand list.
1516 virtual ParseResult
1517 parseSuccessorAndUseList(Block *&dest, SmallVectorImpl<Value> &operands) = 0;
1518
1519 //===--------------------------------------------------------------------===//
1520 // Type Parsing
1521 //===--------------------------------------------------------------------===//
1522
1523 /// Parse a list of assignments of the form
1524 /// (%x1 = %y1, %x2 = %y2, ...)
1525 ParseResult parseAssignmentList(SmallVectorImpl<Argument> &lhs,
1526 SmallVectorImpl<UnresolvedOperand> &rhs) {
1527 OptionalParseResult result = parseOptionalAssignmentList(lhs, rhs);
1528 if (!result.has_value())
1529 return emitError(getCurrentLocation(), "expected '('");
1530 return result.value();
1531 }
1532
1533 virtual OptionalParseResult
1534 parseOptionalAssignmentList(SmallVectorImpl<Argument> &lhs,
1535 SmallVectorImpl<UnresolvedOperand> &rhs) = 0;
1536};
1537
1538//===--------------------------------------------------------------------===//
1539// Dialect OpAsm interface.
1540//===--------------------------------------------------------------------===//
1541
1542/// A functor used to set the name of the start of a result group of an
1543/// operation. See 'getAsmResultNames' below for more details.
1544using OpAsmSetValueNameFn = function_ref<void(Value, StringRef)>;
1545
1546/// A functor used to set the name of blocks in regions directly nested under
1547/// an operation.
1548using OpAsmSetBlockNameFn = function_ref<void(Block *, StringRef)>;
1549
1550class OpAsmDialectInterface
1551 : public DialectInterface::Base<OpAsmDialectInterface> {
1552public:
1553 OpAsmDialectInterface(Dialect *dialect) : Base(dialect) {}
1554
1555 //===------------------------------------------------------------------===//
1556 // Aliases
1557 //===------------------------------------------------------------------===//
1558
1559 /// Holds the result of `getAlias` hook call.
1560 enum class AliasResult {
1561 /// The object (type or attribute) is not supported by the hook
1562 /// and an alias was not provided.
1563 NoAlias,
1564 /// An alias was provided, but it might be overriden by other hook.
1565 OverridableAlias,
1566 /// An alias was provided and it should be used
1567 /// (no other hooks will be checked).
1568 FinalAlias
1569 };
1570
1571 /// Hooks for getting an alias identifier alias for a given symbol, that is
1572 /// not necessarily a part of this dialect. The identifier is used in place of
1573 /// the symbol when printing textual IR. These aliases must not contain `.` or
1574 /// end with a numeric digit([0-9]+).
1575 virtual AliasResult getAlias(Attribute attr, raw_ostream &os) const {
1576 return AliasResult::NoAlias;
1577 }
1578 virtual AliasResult getAlias(Type type, raw_ostream &os) const {
1579 return AliasResult::NoAlias;
1580 }
1581
1582 //===--------------------------------------------------------------------===//
1583 // Resources
1584 //===--------------------------------------------------------------------===//
1585
1586 /// Declare a resource with the given key, returning a handle to use for any
1587 /// references of this resource key within the IR during parsing. The result
1588 /// of `getResourceKey` on the returned handle is permitted to be different
1589 /// than `key`.
1590 virtual FailureOr<AsmDialectResourceHandle>
1591 declareResource(StringRef key) const {
1592 return failure();
1593 }
1594
1595 /// Return a key to use for the given resource. This key should uniquely
1596 /// identify this resource within the dialect.
1597 virtual std::string
1598 getResourceKey(const AsmDialectResourceHandle &handle) const {
1599 llvm_unreachable(::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1600)
1600 "Dialect must implement `getResourceKey` when defining resources")::llvm::llvm_unreachable_internal("Dialect must implement `getResourceKey` when defining resources"
, "mlir/include/mlir/IR/OpImplementation.h", 1600)
;
1601 }
1602
1603 /// Hook for parsing resource entries. Returns failure if the entry was not
1604 /// valid, or could otherwise not be processed correctly. Any necessary errors
1605 /// can be emitted via the provided entry.
1606 virtual LogicalResult parseResource(AsmParsedResourceEntry &entry) const;
1607
1608 /// Hook for building resources to use during printing. The given `op` may be
1609 /// inspected to help determine what information to include.
1610 /// `referencedResources` contains all of the resources detected when printing
1611 /// 'op'.
1612 virtual void
1613 buildResources(Operation *op,
1614 const SetVector<AsmDialectResourceHandle> &referencedResources,
1615 AsmResourceBuilder &builder) const {}
1616};
1617} // namespace mlir
1618
1619//===--------------------------------------------------------------------===//
1620// Operation OpAsm interface.
1621//===--------------------------------------------------------------------===//
1622
1623/// The OpAsmOpInterface, see OpAsmInterface.td for more details.
1624#include "mlir/IR/OpAsmInterface.h.inc"
1625
1626namespace llvm {
1627template <>
1628struct DenseMapInfo<mlir::AsmDialectResourceHandle> {
1629 static inline mlir::AsmDialectResourceHandle getEmptyKey() {
1630 return {DenseMapInfo<void *>::getEmptyKey(),
1631 DenseMapInfo<mlir::TypeID>::getEmptyKey(), nullptr};
1632 }
1633 static inline mlir::AsmDialectResourceHandle getTombstoneKey() {
1634 return {DenseMapInfo<void *>::getTombstoneKey(),
1635 DenseMapInfo<mlir::TypeID>::getTombstoneKey(), nullptr};
1636 }
1637 static unsigned getHashValue(const mlir::AsmDialectResourceHandle &handle) {
1638 return DenseMapInfo<void *>::getHashValue(handle.getResource());
1639 }
1640 static bool isEqual(const mlir::AsmDialectResourceHandle &lhs,
1641 const mlir::AsmDialectResourceHandle &rhs) {
1642 return lhs.getResource() == rhs.getResource();
1643 }
1644};
1645} // namespace llvm
1646
1647#endif