File: | build/source/llvm/include/llvm/ADT/Optional.h |
Warning: | line 179, column 13 Assigned value is garbage or undefined |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | // TODO: Support for big-endian architectures. | |||
10 | // TODO: Properly preserve use lists of values. | |||
11 | ||||
12 | #include "mlir/Bytecode/BytecodeReader.h" | |||
13 | #include "../Encoding.h" | |||
14 | #include "mlir/AsmParser/AsmParser.h" | |||
15 | #include "mlir/Bytecode/BytecodeImplementation.h" | |||
16 | #include "mlir/IR/BuiltinDialect.h" | |||
17 | #include "mlir/IR/BuiltinOps.h" | |||
18 | #include "mlir/IR/OpImplementation.h" | |||
19 | #include "mlir/IR/Verifier.h" | |||
20 | #include "llvm/ADT/MapVector.h" | |||
21 | #include "llvm/ADT/ScopeExit.h" | |||
22 | #include "llvm/ADT/SmallString.h" | |||
23 | #include "llvm/ADT/StringExtras.h" | |||
24 | #include "llvm/Support/MemoryBufferRef.h" | |||
25 | #include "llvm/Support/SaveAndRestore.h" | |||
26 | #include "llvm/Support/SourceMgr.h" | |||
27 | #include <optional> | |||
28 | ||||
#define DEBUG_TYPE "mlir-bytecode-reader"
30 | ||||
31 | using namespace mlir; | |||
32 | ||||
33 | /// Stringify the given section ID. | |||
34 | static std::string toString(bytecode::Section::ID sectionID) { | |||
35 | switch (sectionID) { | |||
36 | case bytecode::Section::kString: | |||
37 | return "String (0)"; | |||
38 | case bytecode::Section::kDialect: | |||
39 | return "Dialect (1)"; | |||
40 | case bytecode::Section::kAttrType: | |||
41 | return "AttrType (2)"; | |||
42 | case bytecode::Section::kAttrTypeOffset: | |||
43 | return "AttrTypeOffset (3)"; | |||
44 | case bytecode::Section::kIR: | |||
45 | return "IR (4)"; | |||
46 | case bytecode::Section::kResource: | |||
47 | return "Resource (5)"; | |||
48 | case bytecode::Section::kResourceOffset: | |||
49 | return "ResourceOffset (6)"; | |||
50 | default: | |||
51 | return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); | |||
52 | } | |||
53 | } | |||
54 | ||||
55 | /// Returns true if the given top-level section ID is optional. | |||
56 | static bool isSectionOptional(bytecode::Section::ID sectionID) { | |||
57 | switch (sectionID) { | |||
58 | case bytecode::Section::kString: | |||
59 | case bytecode::Section::kDialect: | |||
60 | case bytecode::Section::kAttrType: | |||
61 | case bytecode::Section::kAttrTypeOffset: | |||
62 | case bytecode::Section::kIR: | |||
63 | return false; | |||
64 | case bytecode::Section::kResource: | |||
65 | case bytecode::Section::kResourceOffset: | |||
66 | return true; | |||
67 | default: | |||
68 | llvm_unreachable("unknown section ID")::llvm::llvm_unreachable_internal("unknown section ID", "mlir/lib/Bytecode/Reader/BytecodeReader.cpp" , 68); | |||
69 | } | |||
70 | } | |||
71 | ||||
72 | //===----------------------------------------------------------------------===// | |||
73 | // EncodingReader | |||
74 | //===----------------------------------------------------------------------===// | |||
75 | ||||
76 | namespace { | |||
77 | class EncodingReader { | |||
78 | public: | |||
79 | explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) | |||
80 | : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} | |||
81 | explicit EncodingReader(StringRef contents, Location fileLoc) | |||
82 | : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), | |||
83 | contents.size()}, | |||
84 | fileLoc) {} | |||
85 | ||||
86 | /// Returns true if the entire section has been read. | |||
87 | bool empty() const { return dataIt == dataEnd; } | |||
88 | ||||
89 | /// Returns the remaining size of the bytecode. | |||
90 | size_t size() const { return dataEnd - dataIt; } | |||
91 | ||||
92 | /// Align the current reader position to the specified alignment. | |||
93 | LogicalResult alignTo(unsigned alignment) { | |||
94 | if (!llvm::isPowerOf2_32(alignment)) | |||
95 | return emitError("expected alignment to be a power-of-two"); | |||
96 | ||||
97 | // Shift the reader position to the next alignment boundary. | |||
98 | while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { | |||
99 | uint8_t padding; | |||
100 | if (failed(parseByte(padding))) | |||
101 | return failure(); | |||
102 | if (padding != bytecode::kAlignmentByte) { | |||
103 | return emitError("expected alignment byte (0xCB), but got: '0x" + | |||
104 | llvm::utohexstr(padding) + "'"); | |||
105 | } | |||
106 | } | |||
107 | ||||
108 | // Ensure the data iterator is now aligned. This case is unlikely because we | |||
109 | // *just* went through the effort to align the data iterator. | |||
110 | if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))__builtin_expect((bool)(!llvm::isAddrAligned(llvm::Align(alignment ), dataIt)), false)) { | |||
111 | return emitError("expected data iterator aligned to ", alignment, | |||
112 | ", but got pointer: '0x" + | |||
113 | llvm::utohexstr((uintptr_t)dataIt) + "'"); | |||
114 | } | |||
115 | ||||
116 | return success(); | |||
117 | } | |||
118 | ||||
119 | /// Emit an error using the given arguments. | |||
120 | template <typename... Args> | |||
121 | InFlightDiagnostic emitError(Args &&...args) const { | |||
122 | return ::emitError(fileLoc).append(std::forward<Args>(args)...); | |||
123 | } | |||
124 | InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } | |||
125 | ||||
126 | /// Parse a single byte from the stream. | |||
127 | template <typename T> | |||
128 | LogicalResult parseByte(T &value) { | |||
129 | if (empty()) | |||
130 | return emitError("attempting to parse a byte at the end of the bytecode"); | |||
131 | value = static_cast<T>(*dataIt++); | |||
132 | return success(); | |||
133 | } | |||
134 | /// Parse a range of bytes of 'length' into the given result. | |||
135 | LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { | |||
136 | if (length > size()) { | |||
137 | return emitError("attempting to parse ", length, " bytes when only ", | |||
138 | size(), " remain"); | |||
139 | } | |||
140 | result = {dataIt, length}; | |||
141 | dataIt += length; | |||
142 | return success(); | |||
143 | } | |||
144 | /// Parse a range of bytes of 'length' into the given result, which can be | |||
145 | /// assumed to be large enough to hold `length`. | |||
146 | LogicalResult parseBytes(size_t length, uint8_t *result) { | |||
147 | if (length > size()) { | |||
148 | return emitError("attempting to parse ", length, " bytes when only ", | |||
149 | size(), " remain"); | |||
150 | } | |||
151 | memcpy(result, dataIt, length); | |||
152 | dataIt += length; | |||
153 | return success(); | |||
154 | } | |||
155 | ||||
156 | /// Parse an aligned blob of data, where the alignment was encoded alongside | |||
157 | /// the data. | |||
158 | LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, | |||
159 | uint64_t &alignment) { | |||
160 | uint64_t dataSize; | |||
161 | if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || | |||
162 | failed(alignTo(alignment))) | |||
163 | return failure(); | |||
164 | return parseBytes(dataSize, data); | |||
165 | } | |||
166 | ||||
167 | /// Parse a variable length encoded integer from the byte stream. The first | |||
168 | /// encoded byte contains a prefix in the low bits indicating the encoded | |||
169 | /// length of the value. This length prefix is a bit sequence of '0's followed | |||
170 | /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes | |||
171 | /// (not including the prefix byte). All remaining bits in the first byte, | |||
172 | /// along with all of the bits in additional bytes, provide the value of the | |||
173 | /// integer encoded in little-endian order. | |||
174 | LogicalResult parseVarInt(uint64_t &result) { | |||
175 | // Parse the first byte of the encoding, which contains the length prefix. | |||
176 | if (failed(parseByte(result))) | |||
177 | return failure(); | |||
178 | ||||
179 | // Handle the overwhelmingly common case where the value is stored in a | |||
180 | // single byte. In this case, the first bit is the `1` marker bit. | |||
181 | if (LLVM_LIKELY(result & 1)__builtin_expect((bool)(result & 1), true)) { | |||
182 | result >>= 1; | |||
183 | return success(); | |||
184 | } | |||
185 | ||||
186 | // Handle the overwhelming uncommon case where the value required all 8 | |||
187 | // bytes (i.e. a really really big number). In this case, the marker byte is | |||
188 | // all zeros: `00000000`. | |||
189 | if (LLVM_UNLIKELY(result == 0)__builtin_expect((bool)(result == 0), false)) | |||
190 | return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); | |||
191 | return parseMultiByteVarInt(result); | |||
192 | } | |||
193 | ||||
194 | /// Parse a signed variable length encoded integer from the byte stream. A | |||
195 | /// signed varint is encoded as a normal varint with zigzag encoding applied, | |||
196 | /// i.e. the low bit of the value is used to indicate the sign. | |||
197 | LogicalResult parseSignedVarInt(uint64_t &result) { | |||
198 | if (failed(parseVarInt(result))) | |||
199 | return failure(); | |||
200 | // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) | |||
201 | result = (result >> 1) ^ (~(result & 1) + 1); | |||
202 | return success(); | |||
203 | } | |||
204 | ||||
205 | /// Parse a variable length encoded integer whose low bit is used to encode an | |||
206 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. | |||
207 | LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { | |||
208 | if (failed(parseVarInt(result))) | |||
209 | return failure(); | |||
210 | flag = result & 1; | |||
211 | result >>= 1; | |||
212 | return success(); | |||
213 | } | |||
214 | ||||
215 | /// Skip the first `length` bytes within the reader. | |||
216 | LogicalResult skipBytes(size_t length) { | |||
217 | if (length > size()) { | |||
218 | return emitError("attempting to skip ", length, " bytes when only ", | |||
219 | size(), " remain"); | |||
220 | } | |||
221 | dataIt += length; | |||
222 | return success(); | |||
223 | } | |||
224 | ||||
225 | /// Parse a null-terminated string into `result` (without including the NUL | |||
226 | /// terminator). | |||
227 | LogicalResult parseNullTerminatedString(StringRef &result) { | |||
228 | const char *startIt = (const char *)dataIt; | |||
229 | const char *nulIt = (const char *)memchr(startIt, 0, size()); | |||
230 | if (!nulIt) | |||
231 | return emitError( | |||
232 | "malformed null-terminated string, no null character found"); | |||
233 | ||||
234 | result = StringRef(startIt, nulIt - startIt); | |||
235 | dataIt = (const uint8_t *)nulIt + 1; | |||
236 | return success(); | |||
237 | } | |||
238 | ||||
239 | /// Parse a section header, placing the kind of section in `sectionID` and the | |||
240 | /// contents of the section in `sectionData`. | |||
241 | LogicalResult parseSection(bytecode::Section::ID §ionID, | |||
242 | ArrayRef<uint8_t> §ionData) { | |||
243 | uint8_t sectionIDAndHasAlignment; | |||
244 | uint64_t length; | |||
245 | if (failed(parseByte(sectionIDAndHasAlignment)) || | |||
246 | failed(parseVarInt(length))) | |||
247 | return failure(); | |||
248 | ||||
249 | // Extract the section ID and whether the section is aligned. The high bit | |||
250 | // of the ID is the alignment flag. | |||
251 | sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & | |||
252 | 0b01111111); | |||
253 | bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; | |||
254 | ||||
255 | // Check that the section is actually valid before trying to process its | |||
256 | // data. | |||
257 | if (sectionID >= bytecode::Section::kNumSections) | |||
258 | return emitError("invalid section ID: ", unsigned(sectionID)); | |||
259 | ||||
260 | // Process the section alignment if present. | |||
261 | if (hasAlignment) { | |||
262 | uint64_t alignment; | |||
263 | if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) | |||
264 | return failure(); | |||
265 | } | |||
266 | ||||
267 | // Parse the actual section data. | |||
268 | return parseBytes(static_cast<size_t>(length), sectionData); | |||
269 | } | |||
270 | ||||
271 | private: | |||
272 | /// Parse a variable length encoded integer from the byte stream. This method | |||
273 | /// is a fallback when the number of bytes used to encode the value is greater | |||
274 | /// than 1, but less than the max (9). The provided `result` value can be | |||
275 | /// assumed to already contain the first byte of the value. | |||
276 | /// NOTE: This method is marked noinline to avoid pessimizing the common case | |||
277 | /// of single byte encoding. | |||
278 | LLVM_ATTRIBUTE_NOINLINE__attribute__((noinline)) LogicalResult parseMultiByteVarInt(uint64_t &result) { | |||
279 | // Count the number of trailing zeros in the marker byte, this indicates the | |||
280 | // number of trailing bytes that are part of the value. We use `uint32_t` | |||
281 | // here because we only care about the first byte, and so that be actually | |||
282 | // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop | |||
283 | // implementation). | |||
284 | uint32_t numBytes = | |||
285 | llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined); | |||
286 | assert(numBytes > 0 && numBytes <= 7 &&(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 287, __extension__ __PRETTY_FUNCTION__)) | |||
287 | "unexpected number of trailing zeros in varint encoding")(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 287, __extension__ __PRETTY_FUNCTION__)); | |||
288 | ||||
289 | // Parse in the remaining bytes of the value. | |||
290 | if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) | |||
291 | return failure(); | |||
292 | ||||
293 | // Shift out the low-order bits that were used to mark how the value was | |||
294 | // encoded. | |||
295 | result >>= (numBytes + 1); | |||
296 | return success(); | |||
297 | } | |||
298 | ||||
299 | /// The current data iterator, and an iterator to the end of the buffer. | |||
300 | const uint8_t *dataIt, *dataEnd; | |||
301 | ||||
302 | /// A location for the bytecode used to report errors. | |||
303 | Location fileLoc; | |||
304 | }; | |||
305 | } // namespace | |||
306 | ||||
307 | /// Resolve an index into the given entry list. `entry` may either be a | |||
308 | /// reference, in which case it is assigned to the corresponding value in | |||
309 | /// `entries`, or a pointer, in which case it is assigned to the address of the | |||
310 | /// element in `entries`. | |||
311 | template <typename RangeT, typename T> | |||
312 | static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, | |||
313 | uint64_t index, T &entry, | |||
314 | StringRef entryStr) { | |||
315 | if (index >= entries.size()) | |||
316 | return reader.emitError("invalid ", entryStr, " index: ", index); | |||
317 | ||||
318 | // If the provided entry is a pointer, resolve to the address of the entry. | |||
319 | if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) | |||
320 | entry = entries[index]; | |||
321 | else | |||
322 | entry = &entries[index]; | |||
323 | return success(); | |||
324 | } | |||
325 | ||||
326 | /// Parse and resolve an index into the given entry list. | |||
327 | template <typename RangeT, typename T> | |||
328 | static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, | |||
329 | T &entry, StringRef entryStr) { | |||
330 | uint64_t entryIdx; | |||
331 | if (failed(reader.parseVarInt(entryIdx))) | |||
332 | return failure(); | |||
333 | return resolveEntry(reader, entries, entryIdx, entry, entryStr); | |||
334 | } | |||
335 | ||||
336 | //===----------------------------------------------------------------------===// | |||
337 | // StringSectionReader | |||
338 | //===----------------------------------------------------------------------===// | |||
339 | ||||
340 | namespace { | |||
341 | /// This class is used to read references to the string section from the | |||
342 | /// bytecode. | |||
343 | class StringSectionReader { | |||
344 | public: | |||
345 | /// Initialize the string section reader with the given section data. | |||
346 | LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); | |||
347 | ||||
348 | /// Parse a shared string from the string section. The shared string is | |||
349 | /// encoded using an index to a corresponding string in the string section. | |||
350 | LogicalResult parseString(EncodingReader &reader, StringRef &result) { | |||
351 | return parseEntry(reader, strings, result, "string"); | |||
352 | } | |||
353 | ||||
354 | private: | |||
355 | /// The table of strings referenced within the bytecode file. | |||
356 | SmallVector<StringRef> strings; | |||
357 | }; | |||
358 | } // namespace | |||
359 | ||||
360 | LogicalResult StringSectionReader::initialize(Location fileLoc, | |||
361 | ArrayRef<uint8_t> sectionData) { | |||
362 | EncodingReader stringReader(sectionData, fileLoc); | |||
363 | ||||
364 | // Parse the number of strings in the section. | |||
365 | uint64_t numStrings; | |||
366 | if (failed(stringReader.parseVarInt(numStrings))) | |||
367 | return failure(); | |||
368 | strings.resize(numStrings); | |||
369 | ||||
370 | // Parse each of the strings. The sizes of the strings are encoded in reverse | |||
371 | // order, so that's the order we populate the table. | |||
372 | size_t stringDataEndOffset = sectionData.size(); | |||
373 | for (StringRef &string : llvm::reverse(strings)) { | |||
374 | uint64_t stringSize; | |||
375 | if (failed(stringReader.parseVarInt(stringSize))) | |||
376 | return failure(); | |||
377 | if (stringDataEndOffset < stringSize) { | |||
378 | return stringReader.emitError( | |||
379 | "string size exceeds the available data size"); | |||
380 | } | |||
381 | ||||
382 | // Extract the string from the data, dropping the null character. | |||
383 | size_t stringOffset = stringDataEndOffset - stringSize; | |||
384 | string = StringRef( | |||
385 | reinterpret_cast<const char *>(sectionData.data() + stringOffset), | |||
386 | stringSize - 1); | |||
387 | stringDataEndOffset = stringOffset; | |||
388 | } | |||
389 | ||||
390 | // Check that the only remaining data was for the strings, i.e. the reader | |||
391 | // should be at the same offset as the first string. | |||
392 | if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { | |||
393 | return stringReader.emitError("unexpected trailing data between the " | |||
394 | "offsets for strings and their data"); | |||
395 | } | |||
396 | return success(); | |||
397 | } | |||
398 | ||||
399 | //===----------------------------------------------------------------------===// | |||
400 | // BytecodeDialect | |||
401 | //===----------------------------------------------------------------------===// | |||
402 | ||||
403 | namespace { | |||
404 | /// This struct represents a dialect entry within the bytecode. | |||
405 | struct BytecodeDialect { | |||
406 | /// Load the dialect into the provided context if it hasn't been loaded yet. | |||
407 | /// Returns failure if the dialect couldn't be loaded *and* the provided | |||
408 | /// context does not allow unregistered dialects. The provided reader is used | |||
409 | /// for error emission if necessary. | |||
410 | LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { | |||
411 | if (dialect) | |||
412 | return success(); | |||
413 | Dialect *loadedDialect = ctx->getOrLoadDialect(name); | |||
414 | if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { | |||
415 | return reader.emitError( | |||
416 | "dialect '", name, | |||
417 | "' is unknown. If this is intended, please call " | |||
418 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
419 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
420 | } | |||
421 | dialect = loadedDialect; | |||
422 | ||||
423 | // If the dialect was actually loaded, check to see if it has a bytecode | |||
424 | // interface. | |||
425 | if (loadedDialect) | |||
426 | interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); | |||
427 | return success(); | |||
428 | } | |||
429 | ||||
430 | /// Return the loaded dialect, or nullptr if the dialect is unknown. This can | |||
431 | /// only be called after `load`. | |||
432 | Dialect *getLoadedDialect() const { | |||
433 | assert(dialect &&(static_cast <bool> (dialect && "expected `load` to be invoked before `getLoadedDialect`" ) ? void (0) : __assert_fail ("dialect && \"expected `load` to be invoked before `getLoadedDialect`\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 434, __extension__ __PRETTY_FUNCTION__)) | |||
434 | "expected `load` to be invoked before `getLoadedDialect`")(static_cast <bool> (dialect && "expected `load` to be invoked before `getLoadedDialect`" ) ? void (0) : __assert_fail ("dialect && \"expected `load` to be invoked before `getLoadedDialect`\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 434, __extension__ __PRETTY_FUNCTION__)); | |||
435 | return *dialect; | |||
436 | } | |||
437 | ||||
438 | /// The loaded dialect entry. This field is std::nullopt if we haven't | |||
439 | /// attempted to load, nullptr if we failed to load, otherwise the loaded | |||
440 | /// dialect. | |||
441 | std::optional<Dialect *> dialect; | |||
442 | ||||
443 | /// The bytecode interface of the dialect, or nullptr if the dialect does not | |||
444 | /// implement the bytecode interface. This field should only be checked if the | |||
445 | /// `dialect` field is not std::nullopt. | |||
446 | const BytecodeDialectInterface *interface = nullptr; | |||
447 | ||||
448 | /// The name of the dialect. | |||
449 | StringRef name; | |||
450 | }; | |||
451 | ||||
452 | /// This struct represents an operation name entry within the bytecode. | |||
453 | struct BytecodeOperationName { | |||
454 | BytecodeOperationName(BytecodeDialect *dialect, StringRef name) | |||
455 | : dialect(dialect), name(name) {} | |||
456 | ||||
457 | /// The loaded operation name, or std::nullopt if it hasn't been processed | |||
458 | /// yet. | |||
459 | std::optional<OperationName> opName; | |||
460 | ||||
461 | /// The dialect that owns this operation name. | |||
462 | BytecodeDialect *dialect; | |||
463 | ||||
464 | /// The name of the operation, without the dialect prefix. | |||
465 | StringRef name; | |||
466 | }; | |||
467 | } // namespace | |||
468 | ||||
469 | /// Parse a single dialect group encoded in the byte stream. | |||
470 | static LogicalResult parseDialectGrouping( | |||
471 | EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, | |||
472 | function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { | |||
473 | // Parse the dialect and the number of entries in the group. | |||
474 | BytecodeDialect *dialect; | |||
475 | if (failed(parseEntry(reader, dialects, dialect, "dialect"))) | |||
476 | return failure(); | |||
477 | uint64_t numEntries; | |||
478 | if (failed(reader.parseVarInt(numEntries))) | |||
479 | return failure(); | |||
480 | ||||
481 | for (uint64_t i = 0; i < numEntries; ++i) | |||
482 | if (failed(entryCallback(dialect))) | |||
483 | return failure(); | |||
484 | return success(); | |||
485 | } | |||
486 | ||||
487 | //===----------------------------------------------------------------------===// | |||
488 | // ResourceSectionReader | |||
489 | //===----------------------------------------------------------------------===// | |||
490 | ||||
491 | namespace { | |||
492 | /// This class is used to read the resource section from the bytecode. | |||
493 | class ResourceSectionReader { | |||
494 | public: | |||
495 | /// Initialize the resource section reader with the given section data. | |||
496 | LogicalResult | |||
497 | initialize(Location fileLoc, const ParserConfig &config, | |||
498 | MutableArrayRef<BytecodeDialect> dialects, | |||
499 | StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, | |||
500 | ArrayRef<uint8_t> offsetSectionData, | |||
501 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); | |||
502 | ||||
503 | /// Parse a dialect resource handle from the resource section. | |||
504 | LogicalResult parseResourceHandle(EncodingReader &reader, | |||
505 | AsmDialectResourceHandle &result) { | |||
506 | return parseEntry(reader, dialectResources, result, "resource handle"); | |||
507 | } | |||
508 | ||||
509 | private: | |||
510 | /// The table of dialect resources within the bytecode file. | |||
511 | SmallVector<AsmDialectResourceHandle> dialectResources; | |||
512 | }; | |||
513 | ||||
514 | class ParsedResourceEntry : public AsmParsedResourceEntry { | |||
515 | public: | |||
516 | ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, | |||
517 | EncodingReader &reader, StringSectionReader &stringReader, | |||
518 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) | |||
519 | : key(key), kind(kind), reader(reader), stringReader(stringReader), | |||
520 | bufferOwnerRef(bufferOwnerRef) {} | |||
521 | ~ParsedResourceEntry() override = default; | |||
522 | ||||
523 | StringRef getKey() const final { return key; } | |||
524 | ||||
525 | InFlightDiagnostic emitError() const final { return reader.emitError(); } | |||
526 | ||||
527 | AsmResourceEntryKind getKind() const final { return kind; } | |||
528 | ||||
529 | FailureOr<bool> parseAsBool() const final { | |||
530 | if (kind != AsmResourceEntryKind::Bool) | |||
| ||||
531 | return emitError() << "expected a bool resource entry, but found a " | |||
532 | << toString(kind) << " entry instead"; | |||
533 | ||||
534 | bool value; | |||
535 | if (failed(reader.parseByte(value))) | |||
536 | return failure(); | |||
537 | return value; | |||
538 | } | |||
539 | FailureOr<std::string> parseAsString() const final { | |||
540 | if (kind != AsmResourceEntryKind::String) | |||
541 | return emitError() << "expected a string resource entry, but found a " | |||
542 | << toString(kind) << " entry instead"; | |||
543 | ||||
544 | StringRef string; | |||
545 | if (failed(stringReader.parseString(reader, string))) | |||
546 | return failure(); | |||
547 | return string.str(); | |||
548 | } | |||
549 | ||||
550 | FailureOr<AsmResourceBlob> | |||
551 | parseAsBlob(BlobAllocatorFn allocator) const final { | |||
552 | if (kind != AsmResourceEntryKind::Blob) | |||
553 | return emitError() << "expected a blob resource entry, but found a " | |||
554 | << toString(kind) << " entry instead"; | |||
555 | ||||
556 | ArrayRef<uint8_t> data; | |||
557 | uint64_t alignment; | |||
558 | if (failed(reader.parseBlobAndAlignment(data, alignment))) | |||
559 | return failure(); | |||
560 | ||||
561 | // If we have an extendable reference to the buffer owner, we don't need to | |||
562 | // allocate a new buffer for the data, and can use the data directly. | |||
563 | if (bufferOwnerRef) { | |||
564 | ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), | |||
565 | data.size()); | |||
566 | ||||
567 | // Allocate an unmanager buffer which captures a reference to the owner. | |||
568 | // For now we just mark this as immutable, but in the future we should | |||
569 | // explore marking this as mutable when desired. | |||
570 | return UnmanagedAsmResourceBlob::allocateWithAlign( | |||
571 | charData, alignment, | |||
572 | [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); | |||
573 | } | |||
574 | ||||
575 | // Allocate memory for the blob using the provided allocator and copy the | |||
576 | // data into it. | |||
577 | AsmResourceBlob blob = allocator(data.size(), alignment); | |||
578 | assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 580, __extension__ __PRETTY_FUNCTION__)) | |||
579 | blob.isMutable() &&(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 580, __extension__ __PRETTY_FUNCTION__)) | |||
580 | "blob allocator did not return a properly aligned address")(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 580, __extension__ __PRETTY_FUNCTION__)); | |||
581 | memcpy(blob.getMutableData().data(), data.data(), data.size()); | |||
582 | return blob; | |||
583 | } | |||
584 | ||||
585 | private: | |||
586 | StringRef key; | |||
587 | AsmResourceEntryKind kind; | |||
588 | EncodingReader &reader; | |||
589 | StringSectionReader &stringReader; | |||
590 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; | |||
591 | }; | |||
592 | } // namespace | |||
593 | ||||
594 | template <typename T> | |||
595 | static LogicalResult | |||
596 | parseResourceGroup(Location fileLoc, bool allowEmpty, | |||
597 | EncodingReader &offsetReader, EncodingReader &resourceReader, | |||
598 | StringSectionReader &stringReader, T *handler, | |||
599 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, | |||
600 | function_ref<LogicalResult(StringRef)> processKeyFn = {}) { | |||
601 | uint64_t numResources; | |||
602 | if (failed(offsetReader.parseVarInt(numResources))) | |||
603 | return failure(); | |||
604 | ||||
605 | for (uint64_t i = 0; i < numResources; ++i) { | |||
606 | StringRef key; | |||
607 | AsmResourceEntryKind kind; | |||
608 | uint64_t resourceOffset; | |||
609 | ArrayRef<uint8_t> data; | |||
610 | if (failed(stringReader.parseString(offsetReader, key)) || | |||
611 | failed(offsetReader.parseVarInt(resourceOffset)) || | |||
612 | failed(offsetReader.parseByte(kind)) || | |||
613 | failed(resourceReader.parseBytes(resourceOffset, data))) | |||
614 | return failure(); | |||
615 | ||||
616 | // Process the resource key. | |||
617 | if ((processKeyFn && failed(processKeyFn(key)))) | |||
618 | return failure(); | |||
619 | ||||
620 | // If the resource data is empty and we allow it, don't error out when | |||
621 | // parsing below, just skip it. | |||
622 | if (allowEmpty && data.empty()) | |||
623 | continue; | |||
624 | ||||
625 | // Ignore the entry if we don't have a valid handler. | |||
626 | if (!handler) | |||
627 | continue; | |||
628 | ||||
629 | // Otherwise, parse the resource value. | |||
630 | EncodingReader entryReader(data, fileLoc); | |||
631 | ParsedResourceEntry entry(key, kind, entryReader, stringReader, | |||
632 | bufferOwnerRef); | |||
633 | if (failed(handler->parseResource(entry))) | |||
634 | return failure(); | |||
635 | if (!entryReader.empty()) { | |||
636 | return entryReader.emitError( | |||
637 | "unexpected trailing bytes in resource entry '", key, "'"); | |||
638 | } | |||
639 | } | |||
640 | return success(); | |||
641 | } | |||
642 | ||||
643 | LogicalResult ResourceSectionReader::initialize( | |||
644 | Location fileLoc, const ParserConfig &config, | |||
645 | MutableArrayRef<BytecodeDialect> dialects, | |||
646 | StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, | |||
647 | ArrayRef<uint8_t> offsetSectionData, | |||
648 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { | |||
649 | EncodingReader resourceReader(sectionData, fileLoc); | |||
650 | EncodingReader offsetReader(offsetSectionData, fileLoc); | |||
651 | ||||
652 | // Read the number of external resource providers. | |||
653 | uint64_t numExternalResourceGroups; | |||
654 | if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) | |||
655 | return failure(); | |||
656 | ||||
657 | // Utility functor that dispatches to `parseResourceGroup`, but implicitly | |||
658 | // provides most of the arguments. | |||
659 | auto parseGroup = [&](auto *handler, bool allowEmpty = false, | |||
660 | function_ref<LogicalResult(StringRef)> keyFn = {}) { | |||
661 | return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, | |||
662 | stringReader, handler, bufferOwnerRef, keyFn); | |||
663 | }; | |||
664 | ||||
665 | // Read the external resources from the bytecode. | |||
666 | for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { | |||
667 | StringRef key; | |||
668 | if (failed(stringReader.parseString(offsetReader, key))) | |||
669 | return failure(); | |||
670 | ||||
671 | // Get the handler for these resources. | |||
672 | // TODO: Should we require handling external resources in some scenarios? | |||
673 | AsmResourceParser *handler = config.getResourceParser(key); | |||
674 | if (!handler) { | |||
675 | emitWarning(fileLoc) << "ignoring unknown external resources for '" << key | |||
676 | << "'"; | |||
677 | } | |||
678 | ||||
679 | if (failed(parseGroup(handler))) | |||
680 | return failure(); | |||
681 | } | |||
682 | ||||
683 | // Read the dialect resources from the bytecode. | |||
684 | MLIRContext *ctx = fileLoc->getContext(); | |||
685 | while (!offsetReader.empty()) { | |||
686 | BytecodeDialect *dialect; | |||
687 | if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || | |||
688 | failed(dialect->load(resourceReader, ctx))) | |||
689 | return failure(); | |||
690 | Dialect *loadedDialect = dialect->getLoadedDialect(); | |||
691 | if (!loadedDialect) { | |||
692 | return resourceReader.emitError() | |||
693 | << "dialect '" << dialect->name << "' is unknown"; | |||
694 | } | |||
695 | const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); | |||
696 | if (!handler) { | |||
697 | return resourceReader.emitError() | |||
698 | << "unexpected resources for dialect '" << dialect->name << "'"; | |||
699 | } | |||
700 | ||||
701 | // Ensure that each resource is declared before being processed. | |||
702 | auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { | |||
703 | FailureOr<AsmDialectResourceHandle> handle = | |||
704 | handler->declareResource(key); | |||
705 | if (failed(handle)) { | |||
706 | return resourceReader.emitError() | |||
707 | << "unknown 'resource' key '" << key << "' for dialect '" | |||
708 | << dialect->name << "'"; | |||
709 | } | |||
710 | dialectResources.push_back(*handle); | |||
711 | return success(); | |||
712 | }; | |||
713 | ||||
714 | // Parse the resources for this dialect. We allow empty resources because we | |||
715 | // just treat these as declarations. | |||
716 | if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) | |||
717 | return failure(); | |||
718 | } | |||
719 | ||||
720 | return success(); | |||
721 | } | |||
722 | ||||
723 | //===----------------------------------------------------------------------===// | |||
724 | // Attribute/Type Reader | |||
725 | //===----------------------------------------------------------------------===// | |||
726 | ||||
727 | namespace { | |||
728 | /// This class provides support for reading attribute and type entries from the | |||
729 | /// bytecode. Attribute and Type entries are read lazily on demand, so we use | |||
730 | /// this reader to manage when to actually parse them from the bytecode. | |||
731 | class AttrTypeReader { | |||
732 | /// This class represents a single attribute or type entry. | |||
733 | template <typename T> | |||
734 | struct Entry { | |||
735 | /// The entry, or null if it hasn't been resolved yet. | |||
736 | T entry = {}; | |||
737 | /// The parent dialect of this entry. | |||
738 | BytecodeDialect *dialect = nullptr; | |||
739 | /// A flag indicating if the entry was encoded using a custom encoding, | |||
740 | /// instead of using the textual assembly format. | |||
741 | bool hasCustomEncoding = false; | |||
742 | /// The raw data of this entry in the bytecode. | |||
743 | ArrayRef<uint8_t> data; | |||
744 | }; | |||
745 | using AttrEntry = Entry<Attribute>; | |||
746 | using TypeEntry = Entry<Type>; | |||
747 | ||||
748 | public: | |||
749 | AttrTypeReader(StringSectionReader &stringReader, | |||
750 | ResourceSectionReader &resourceReader, Location fileLoc) | |||
751 | : stringReader(stringReader), resourceReader(resourceReader), | |||
752 | fileLoc(fileLoc) {} | |||
753 | ||||
754 | /// Initialize the attribute and type information within the reader. | |||
755 | LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
756 | ArrayRef<uint8_t> sectionData, | |||
757 | ArrayRef<uint8_t> offsetSectionData); | |||
758 | ||||
759 | /// Resolve the attribute or type at the given index. Returns nullptr on | |||
760 | /// failure. | |||
761 | Attribute resolveAttribute(size_t index) { | |||
762 | return resolveEntry(attributes, index, "Attribute"); | |||
763 | } | |||
764 | Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } | |||
765 | ||||
766 | /// Parse a reference to an attribute or type using the given reader. | |||
767 | LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { | |||
768 | uint64_t attrIdx; | |||
769 | if (failed(reader.parseVarInt(attrIdx))) | |||
770 | return failure(); | |||
771 | result = resolveAttribute(attrIdx); | |||
772 | return success(!!result); | |||
773 | } | |||
774 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
775 | uint64_t typeIdx; | |||
776 | if (failed(reader.parseVarInt(typeIdx))) | |||
777 | return failure(); | |||
778 | result = resolveType(typeIdx); | |||
779 | return success(!!result); | |||
780 | } | |||
781 | ||||
782 | template <typename T> | |||
783 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
784 | Attribute baseResult; | |||
785 | if (failed(parseAttribute(reader, baseResult))) | |||
786 | return failure(); | |||
787 | if ((result = baseResult.dyn_cast<T>())) | |||
788 | return success(); | |||
789 | return reader.emitError("expected attribute of type: ", | |||
790 | llvm::getTypeName<T>(), ", but got: ", baseResult); | |||
791 | } | |||
792 | ||||
793 | private: | |||
794 | /// Resolve the given entry at `index`. | |||
795 | template <typename T> | |||
796 | T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
797 | StringRef entryType); | |||
798 | ||||
799 | /// Parse an entry using the given reader that was encoded using the textual | |||
800 | /// assembly format. | |||
801 | template <typename T> | |||
802 | LogicalResult parseAsmEntry(T &result, EncodingReader &reader, | |||
803 | StringRef entryType); | |||
804 | ||||
805 | /// Parse an entry using the given reader that was encoded using a custom | |||
806 | /// bytecode format. | |||
807 | template <typename T> | |||
808 | LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, | |||
809 | StringRef entryType); | |||
810 | ||||
811 | /// The string section reader used to resolve string references when parsing | |||
812 | /// custom encoded attribute/type entries. | |||
813 | StringSectionReader &stringReader; | |||
814 | ||||
815 | /// The resource section reader used to resolve resource references when | |||
816 | /// parsing custom encoded attribute/type entries. | |||
817 | ResourceSectionReader &resourceReader; | |||
818 | ||||
819 | /// The set of attribute and type entries. | |||
820 | SmallVector<AttrEntry> attributes; | |||
821 | SmallVector<TypeEntry> types; | |||
822 | ||||
823 | /// A location used for error emission. | |||
824 | Location fileLoc; | |||
825 | }; | |||
826 | ||||
827 | class DialectReader : public DialectBytecodeReader { | |||
828 | public: | |||
829 | DialectReader(AttrTypeReader &attrTypeReader, | |||
830 | StringSectionReader &stringReader, | |||
831 | ResourceSectionReader &resourceReader, EncodingReader &reader) | |||
832 | : attrTypeReader(attrTypeReader), stringReader(stringReader), | |||
833 | resourceReader(resourceReader), reader(reader) {} | |||
834 | ||||
835 | InFlightDiagnostic emitError(const Twine &msg) override { | |||
836 | return reader.emitError(msg); | |||
837 | } | |||
838 | ||||
839 | //===--------------------------------------------------------------------===// | |||
840 | // IR | |||
841 | //===--------------------------------------------------------------------===// | |||
842 | ||||
843 | LogicalResult readAttribute(Attribute &result) override { | |||
844 | return attrTypeReader.parseAttribute(reader, result); | |||
845 | } | |||
846 | ||||
847 | LogicalResult readType(Type &result) override { | |||
848 | return attrTypeReader.parseType(reader, result); | |||
849 | } | |||
850 | ||||
851 | FailureOr<AsmDialectResourceHandle> readResourceHandle() override { | |||
852 | AsmDialectResourceHandle handle; | |||
853 | if (failed(resourceReader.parseResourceHandle(reader, handle))) | |||
854 | return failure(); | |||
855 | return handle; | |||
856 | } | |||
857 | ||||
858 | //===--------------------------------------------------------------------===// | |||
859 | // Primitives | |||
860 | //===--------------------------------------------------------------------===// | |||
861 | ||||
862 | LogicalResult readVarInt(uint64_t &result) override { | |||
863 | return reader.parseVarInt(result); | |||
864 | } | |||
865 | ||||
866 | LogicalResult readSignedVarInt(int64_t &result) override { | |||
867 | uint64_t unsignedResult; | |||
868 | if (failed(reader.parseSignedVarInt(unsignedResult))) | |||
869 | return failure(); | |||
870 | result = static_cast<int64_t>(unsignedResult); | |||
871 | return success(); | |||
872 | } | |||
873 | ||||
874 | FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { | |||
875 | // Small values are encoded using a single byte. | |||
876 | if (bitWidth <= 8) { | |||
877 | uint8_t value; | |||
878 | if (failed(reader.parseByte(value))) | |||
879 | return failure(); | |||
880 | return APInt(bitWidth, value); | |||
881 | } | |||
882 | ||||
883 | // Large values up to 64 bits are encoded using a single varint. | |||
884 | if (bitWidth <= 64) { | |||
885 | uint64_t value; | |||
886 | if (failed(reader.parseSignedVarInt(value))) | |||
887 | return failure(); | |||
888 | return APInt(bitWidth, value); | |||
889 | } | |||
890 | ||||
891 | // Otherwise, for really big values we encode the array of active words in | |||
892 | // the value. | |||
893 | uint64_t numActiveWords; | |||
894 | if (failed(reader.parseVarInt(numActiveWords))) | |||
895 | return failure(); | |||
896 | SmallVector<uint64_t, 4> words(numActiveWords); | |||
897 | for (uint64_t i = 0; i < numActiveWords; ++i) | |||
898 | if (failed(reader.parseSignedVarInt(words[i]))) | |||
899 | return failure(); | |||
900 | return APInt(bitWidth, words); | |||
901 | } | |||
902 | ||||
903 | FailureOr<APFloat> | |||
904 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { | |||
905 | FailureOr<APInt> intVal = | |||
906 | readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); | |||
907 | if (failed(intVal)) | |||
908 | return failure(); | |||
909 | return APFloat(semantics, *intVal); | |||
910 | } | |||
911 | ||||
912 | LogicalResult readString(StringRef &result) override { | |||
913 | return stringReader.parseString(reader, result); | |||
914 | } | |||
915 | ||||
916 | LogicalResult readBlob(ArrayRef<char> &result) override { | |||
917 | uint64_t dataSize; | |||
918 | ArrayRef<uint8_t> data; | |||
919 | if (failed(reader.parseVarInt(dataSize)) || | |||
920 | failed(reader.parseBytes(dataSize, data))) | |||
921 | return failure(); | |||
922 | result = llvm::makeArrayRef(reinterpret_cast<const char *>(data.data()), | |||
923 | data.size()); | |||
924 | return success(); | |||
925 | } | |||
926 | ||||
927 | private: | |||
928 | AttrTypeReader &attrTypeReader; | |||
929 | StringSectionReader &stringReader; | |||
930 | ResourceSectionReader &resourceReader; | |||
931 | EncodingReader &reader; | |||
932 | }; | |||
933 | } // namespace | |||
934 | ||||
935 | LogicalResult | |||
936 | AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
937 | ArrayRef<uint8_t> sectionData, | |||
938 | ArrayRef<uint8_t> offsetSectionData) { | |||
939 | EncodingReader offsetReader(offsetSectionData, fileLoc); | |||
940 | ||||
941 | // Parse the number of attribute and type entries. | |||
942 | uint64_t numAttributes, numTypes; | |||
943 | if (failed(offsetReader.parseVarInt(numAttributes)) || | |||
944 | failed(offsetReader.parseVarInt(numTypes))) | |||
945 | return failure(); | |||
946 | attributes.resize(numAttributes); | |||
947 | types.resize(numTypes); | |||
948 | ||||
949 | // A functor used to accumulate the offsets for the entries in the given | |||
950 | // range. | |||
951 | uint64_t currentOffset = 0; | |||
952 | auto parseEntries = [&](auto &&range) { | |||
953 | size_t currentIndex = 0, endIndex = range.size(); | |||
954 | ||||
955 | // Parse an individual entry. | |||
956 | auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { | |||
957 | auto &entry = range[currentIndex++]; | |||
958 | ||||
959 | uint64_t entrySize; | |||
960 | if (failed(offsetReader.parseVarIntWithFlag(entrySize, | |||
961 | entry.hasCustomEncoding))) | |||
962 | return failure(); | |||
963 | ||||
964 | // Verify that the offset is actually valid. | |||
965 | if (currentOffset + entrySize > sectionData.size()) { | |||
966 | return offsetReader.emitError( | |||
967 | "Attribute or Type entry offset points past the end of section"); | |||
968 | } | |||
969 | ||||
970 | entry.data = sectionData.slice(currentOffset, entrySize); | |||
971 | entry.dialect = dialect; | |||
972 | currentOffset += entrySize; | |||
973 | return success(); | |||
974 | }; | |||
975 | while (currentIndex != endIndex) | |||
976 | if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) | |||
977 | return failure(); | |||
978 | return success(); | |||
979 | }; | |||
980 | ||||
981 | // Process each of the attributes, and then the types. | |||
982 | if (failed(parseEntries(attributes)) || failed(parseEntries(types))) | |||
983 | return failure(); | |||
984 | ||||
985 | // Ensure that we read everything from the section. | |||
986 | if (!offsetReader.empty()) { | |||
987 | return offsetReader.emitError( | |||
988 | "unexpected trailing data in the Attribute/Type offset section"); | |||
989 | } | |||
990 | return success(); | |||
991 | } | |||
992 | ||||
993 | template <typename T> | |||
994 | T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
995 | StringRef entryType) { | |||
996 | if (index >= entries.size()) { | |||
997 | emitError(fileLoc) << "invalid " << entryType << " index: " << index; | |||
998 | return {}; | |||
999 | } | |||
1000 | ||||
1001 | // If the entry has already been resolved, there is nothing left to do. | |||
1002 | Entry<T> &entry = entries[index]; | |||
1003 | if (entry.entry) | |||
1004 | return entry.entry; | |||
1005 | ||||
1006 | // Parse the entry. | |||
1007 | EncodingReader reader(entry.data, fileLoc); | |||
1008 | ||||
1009 | // Parse based on how the entry was encoded. | |||
1010 | if (entry.hasCustomEncoding) { | |||
1011 | if (failed(parseCustomEntry(entry, reader, entryType))) | |||
1012 | return T(); | |||
1013 | } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { | |||
1014 | return T(); | |||
1015 | } | |||
1016 | ||||
1017 | if (!reader.empty()) { | |||
1018 | reader.emitError("unexpected trailing bytes after " + entryType + " entry"); | |||
1019 | return T(); | |||
1020 | } | |||
1021 | return entry.entry; | |||
1022 | } | |||
1023 | ||||
1024 | template <typename T> | |||
1025 | LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, | |||
1026 | StringRef entryType) { | |||
1027 | StringRef asmStr; | |||
1028 | if (failed(reader.parseNullTerminatedString(asmStr))) | |||
1029 | return failure(); | |||
1030 | ||||
1031 | // Invoke the MLIR assembly parser to parse the entry text. | |||
1032 | size_t numRead = 0; | |||
1033 | MLIRContext *context = fileLoc->getContext(); | |||
1034 | if constexpr (std::is_same_v<T, Type>) | |||
1035 | result = ::parseType(asmStr, context, numRead); | |||
1036 | else | |||
1037 | result = ::parseAttribute(asmStr, context, numRead); | |||
1038 | if (!result) | |||
1039 | return failure(); | |||
1040 | ||||
1041 | // Ensure there weren't dangling characters after the entry. | |||
1042 | if (numRead != asmStr.size()) { | |||
1043 | return reader.emitError("trailing characters found after ", entryType, | |||
1044 | " assembly format: ", asmStr.drop_front(numRead)); | |||
1045 | } | |||
1046 | return success(); | |||
1047 | } | |||
1048 | ||||
1049 | template <typename T> | |||
1050 | LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, | |||
1051 | EncodingReader &reader, | |||
1052 | StringRef entryType) { | |||
1053 | if (failed(entry.dialect->load(reader, fileLoc.getContext()))) | |||
1054 | return failure(); | |||
1055 | ||||
1056 | // Ensure that the dialect implements the bytecode interface. | |||
1057 | if (!entry.dialect->interface) { | |||
1058 | return reader.emitError("dialect '", entry.dialect->name, | |||
1059 | "' does not implement the bytecode interface"); | |||
1060 | } | |||
1061 | ||||
1062 | // Ask the dialect to parse the entry. | |||
1063 | DialectReader dialectReader(*this, stringReader, resourceReader, reader); | |||
1064 | if constexpr (std::is_same_v<T, Type>) | |||
1065 | entry.entry = entry.dialect->interface->readType(dialectReader); | |||
1066 | else | |||
1067 | entry.entry = entry.dialect->interface->readAttribute(dialectReader); | |||
1068 | return success(!!entry.entry); | |||
1069 | } | |||
1070 | ||||
1071 | //===----------------------------------------------------------------------===// | |||
1072 | // Bytecode Reader | |||
1073 | //===----------------------------------------------------------------------===// | |||
1074 | ||||
1075 | namespace { | |||
1076 | /// This class is used to read a bytecode buffer and translate it into MLIR. | |||
1077 | class BytecodeReader { | |||
1078 | public: | |||
1079 | BytecodeReader(Location fileLoc, const ParserConfig &config, | |||
1080 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) | |||
1081 | : config(config), fileLoc(fileLoc), | |||
1082 | attrTypeReader(stringReader, resourceReader, fileLoc), | |||
1083 | // Use the builtin unrealized conversion cast operation to represent | |||
1084 | // forward references to values that aren't yet defined. | |||
1085 | forwardRefOpState(UnknownLoc::get(config.getContext()), | |||
1086 | "builtin.unrealized_conversion_cast", ValueRange(), | |||
1087 | NoneType::get(config.getContext())), | |||
1088 | bufferOwnerRef(bufferOwnerRef) {} | |||
1089 | ||||
1090 | /// Read the bytecode defined within `buffer` into the given block. | |||
1091 | LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); | |||
1092 | ||||
1093 | private: | |||
1094 | /// Return the context for this config. | |||
1095 | MLIRContext *getContext() const { return config.getContext(); } | |||
1096 | ||||
1097 | /// Parse the bytecode version. | |||
1098 | LogicalResult parseVersion(EncodingReader &reader); | |||
1099 | ||||
1100 | //===--------------------------------------------------------------------===// | |||
1101 | // Dialect Section | |||
1102 | ||||
1103 | LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); | |||
1104 | ||||
1105 | /// Parse an operation name reference using the given reader. | |||
1106 | FailureOr<OperationName> parseOpName(EncodingReader &reader); | |||
1107 | ||||
1108 | //===--------------------------------------------------------------------===// | |||
1109 | // Attribute/Type Section | |||
1110 | ||||
1111 | /// Parse an attribute or type using the given reader. | |||
1112 | template <typename T> | |||
1113 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
1114 | return attrTypeReader.parseAttribute(reader, result); | |||
1115 | } | |||
1116 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
1117 | return attrTypeReader.parseType(reader, result); | |||
1118 | } | |||
1119 | ||||
1120 | //===--------------------------------------------------------------------===// | |||
1121 | // Resource Section | |||
1122 | ||||
1123 | LogicalResult | |||
1124 | parseResourceSection(Optional<ArrayRef<uint8_t>> resourceData, | |||
1125 | Optional<ArrayRef<uint8_t>> resourceOffsetData); | |||
1126 | ||||
1127 | //===--------------------------------------------------------------------===// | |||
1128 | // IR Section | |||
1129 | ||||
1130 | /// This struct represents the current read state of a range of regions. This | |||
1131 | /// struct is used to enable iterative parsing of regions. | |||
1132 | struct RegionReadState { | |||
1133 | RegionReadState(Operation *op, bool isIsolatedFromAbove) | |||
1134 | : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} | |||
1135 | RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove) | |||
1136 | : curRegion(regions.begin()), endRegion(regions.end()), | |||
1137 | isIsolatedFromAbove(isIsolatedFromAbove) {} | |||
1138 | ||||
1139 | /// The current regions being read. | |||
1140 | MutableArrayRef<Region>::iterator curRegion, endRegion; | |||
1141 | ||||
1142 | /// The number of values defined immediately within this region. | |||
1143 | unsigned numValues = 0; | |||
1144 | ||||
1145 | /// The current blocks of the region being read. | |||
1146 | SmallVector<Block *> curBlocks; | |||
1147 | Region::iterator curBlock = {}; | |||
1148 | ||||
1149 | /// The number of operations remaining to be read from the current block | |||
1150 | /// being read. | |||
1151 | uint64_t numOpsRemaining = 0; | |||
1152 | ||||
1153 | /// A flag indicating if the regions being read are isolated from above. | |||
1154 | bool isIsolatedFromAbove = false; | |||
1155 | }; | |||
1156 | ||||
1157 | LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); | |||
1158 | LogicalResult parseRegions(EncodingReader &reader, | |||
1159 | std::vector<RegionReadState> ®ionStack, | |||
1160 | RegionReadState &readState); | |||
1161 | FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, | |||
1162 | RegionReadState &readState, | |||
1163 | bool &isIsolatedFromAbove); | |||
1164 | ||||
1165 | LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); | |||
1166 | LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); | |||
1167 | LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); | |||
1168 | ||||
1169 | //===--------------------------------------------------------------------===// | |||
1170 | // Value Processing | |||
1171 | ||||
1172 | /// Parse an operand reference using the given reader. Returns nullptr in the | |||
1173 | /// case of failure. | |||
1174 | Value parseOperand(EncodingReader &reader); | |||
1175 | ||||
1176 | /// Sequentially define the given value range. | |||
1177 | LogicalResult defineValues(EncodingReader &reader, ValueRange values); | |||
1178 | ||||
1179 | /// Create a value to use for a forward reference. | |||
1180 | Value createForwardRef(); | |||
1181 | ||||
1182 | //===--------------------------------------------------------------------===// | |||
1183 | // Fields | |||
1184 | ||||
1185 | /// This class represents a single value scope, in which a value scope is | |||
1186 | /// delimited by isolated from above regions. | |||
1187 | struct ValueScope { | |||
1188 | /// Push a new region state onto this scope, reserving enough values for | |||
1189 | /// those defined within the current region of the provided state. | |||
1190 | void push(RegionReadState &readState) { | |||
1191 | nextValueIDs.push_back(values.size()); | |||
1192 | values.resize(values.size() + readState.numValues); | |||
1193 | } | |||
1194 | ||||
1195 | /// Pop the values defined for the current region within the provided region | |||
1196 | /// state. | |||
1197 | void pop(RegionReadState &readState) { | |||
1198 | values.resize(values.size() - readState.numValues); | |||
1199 | nextValueIDs.pop_back(); | |||
1200 | } | |||
1201 | ||||
1202 | /// The set of values defined in this scope. | |||
1203 | std::vector<Value> values; | |||
1204 | ||||
1205 | /// The ID for the next defined value for each region current being | |||
1206 | /// processed in this scope. | |||
1207 | SmallVector<unsigned, 4> nextValueIDs; | |||
1208 | }; | |||
1209 | ||||
1210 | /// The configuration of the parser. | |||
1211 | const ParserConfig &config; | |||
1212 | ||||
1213 | /// A location to use when emitting errors. | |||
1214 | Location fileLoc; | |||
1215 | ||||
1216 | /// The reader used to process attribute and types within the bytecode. | |||
1217 | AttrTypeReader attrTypeReader; | |||
1218 | ||||
1219 | /// The version of the bytecode being read. | |||
1220 | uint64_t version = 0; | |||
1221 | ||||
1222 | /// The producer of the bytecode being read. | |||
1223 | StringRef producer; | |||
1224 | ||||
1225 | /// The table of IR units referenced within the bytecode file. | |||
1226 | SmallVector<BytecodeDialect> dialects; | |||
1227 | SmallVector<BytecodeOperationName> opNames; | |||
1228 | ||||
1229 | /// The reader used to process resources within the bytecode. | |||
1230 | ResourceSectionReader resourceReader; | |||
1231 | ||||
1232 | /// The table of strings referenced within the bytecode file. | |||
1233 | StringSectionReader stringReader; | |||
1234 | ||||
1235 | /// The current set of available IR value scopes. | |||
1236 | std::vector<ValueScope> valueScopes; | |||
1237 | /// A block containing the set of operations defined to create forward | |||
1238 | /// references. | |||
1239 | Block forwardRefOps; | |||
1240 | /// A block containing previously created, and no longer used, forward | |||
1241 | /// reference operations. | |||
1242 | Block openForwardRefOps; | |||
1243 | /// An operation state used when instantiating forward references. | |||
1244 | OperationState forwardRefOpState; | |||
1245 | ||||
1246 | /// The optional owning source manager, which when present may be used to | |||
1247 | /// extend the lifetime of the input buffer. | |||
1248 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; | |||
1249 | }; | |||
1250 | } // namespace | |||
1251 | ||||
1252 | LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { | |||
1253 | EncodingReader reader(buffer.getBuffer(), fileLoc); | |||
1254 | ||||
1255 | // Skip over the bytecode header, this should have already been checked. | |||
1256 | if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) | |||
1257 | return failure(); | |||
1258 | // Parse the bytecode version and producer. | |||
1259 | if (failed(parseVersion(reader)) || | |||
1260 | failed(reader.parseNullTerminatedString(producer))) | |||
1261 | return failure(); | |||
1262 | ||||
1263 | // Add a diagnostic handler that attaches a note that includes the original | |||
1264 | // producer of the bytecode. | |||
1265 | ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { | |||
1266 | diag.attachNote() << "in bytecode version " << version | |||
1267 | << " produced by: " << producer; | |||
1268 | return failure(); | |||
1269 | }); | |||
1270 | ||||
1271 | // Parse the raw data for each of the top-level sections of the bytecode. | |||
1272 | Optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections]; | |||
1273 | while (!reader.empty()) { | |||
1274 | // Read the next section from the bytecode. | |||
1275 | bytecode::Section::ID sectionID; | |||
1276 | ArrayRef<uint8_t> sectionData; | |||
1277 | if (failed(reader.parseSection(sectionID, sectionData))) | |||
1278 | return failure(); | |||
1279 | ||||
1280 | // Check for duplicate sections, we only expect one instance of each. | |||
1281 | if (sectionDatas[sectionID]) { | |||
1282 | return reader.emitError("duplicate top-level section: ", | |||
1283 | toString(sectionID)); | |||
1284 | } | |||
1285 | sectionDatas[sectionID] = sectionData; | |||
1286 | } | |||
1287 | // Check that all of the required sections were found. | |||
1288 | for (int i = 0; i < bytecode::Section::kNumSections; ++i) { | |||
1289 | bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); | |||
1290 | if (!sectionDatas[i] && !isSectionOptional(sectionID)) { | |||
1291 | return reader.emitError("missing data for top-level section: ", | |||
1292 | toString(sectionID)); | |||
1293 | } | |||
1294 | } | |||
1295 | ||||
1296 | // Process the string section first. | |||
1297 | if (failed(stringReader.initialize( | |||
1298 | fileLoc, *sectionDatas[bytecode::Section::kString]))) | |||
1299 | return failure(); | |||
1300 | ||||
1301 | // Process the dialect section. | |||
1302 | if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) | |||
1303 | return failure(); | |||
1304 | ||||
1305 | // Process the resource section if present. | |||
1306 | if (failed(parseResourceSection( | |||
1307 | sectionDatas[bytecode::Section::kResource], | |||
1308 | sectionDatas[bytecode::Section::kResourceOffset]))) | |||
1309 | return failure(); | |||
1310 | ||||
1311 | // Process the attribute and type section. | |||
1312 | if (failed(attrTypeReader.initialize( | |||
1313 | dialects, *sectionDatas[bytecode::Section::kAttrType], | |||
1314 | *sectionDatas[bytecode::Section::kAttrTypeOffset]))) | |||
1315 | return failure(); | |||
1316 | ||||
1317 | // Finally, process the IR section. | |||
1318 | return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); | |||
1319 | } | |||
1320 | ||||
1321 | LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { | |||
1322 | if (failed(reader.parseVarInt(version))) | |||
1323 | return failure(); | |||
1324 | ||||
1325 | // Validate the bytecode version. | |||
1326 | uint64_t currentVersion = bytecode::kVersion; | |||
1327 | if (version < currentVersion) { | |||
1328 | return reader.emitError("bytecode version ", version, | |||
1329 | " is older than the current version of ", | |||
1330 | currentVersion, ", and upgrade is not supported"); | |||
1331 | } | |||
1332 | if (version > currentVersion) { | |||
1333 | return reader.emitError("bytecode version ", version, | |||
1334 | " is newer than the current version ", | |||
1335 | currentVersion); | |||
1336 | } | |||
1337 | return success(); | |||
1338 | } | |||
1339 | ||||
1340 | //===----------------------------------------------------------------------===// | |||
1341 | // Dialect Section | |||
1342 | ||||
1343 | LogicalResult | |||
1344 | BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { | |||
1345 | EncodingReader sectionReader(sectionData, fileLoc); | |||
1346 | ||||
1347 | // Parse the number of dialects in the section. | |||
1348 | uint64_t numDialects; | |||
1349 | if (failed(sectionReader.parseVarInt(numDialects))) | |||
1350 | return failure(); | |||
1351 | dialects.resize(numDialects); | |||
1352 | ||||
1353 | // Parse each of the dialects. | |||
1354 | for (uint64_t i = 0; i < numDialects; ++i) | |||
1355 | if (failed(stringReader.parseString(sectionReader, dialects[i].name))) | |||
1356 | return failure(); | |||
1357 | ||||
1358 | // Parse the operation names, which are grouped by dialect. | |||
1359 | auto parseOpName = [&](BytecodeDialect *dialect) { | |||
1360 | StringRef opName; | |||
1361 | if (failed(stringReader.parseString(sectionReader, opName))) | |||
1362 | return failure(); | |||
1363 | opNames.emplace_back(dialect, opName); | |||
1364 | return success(); | |||
1365 | }; | |||
1366 | while (!sectionReader.empty()) | |||
1367 | if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) | |||
1368 | return failure(); | |||
1369 | return success(); | |||
1370 | } | |||
1371 | ||||
1372 | FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { | |||
1373 | BytecodeOperationName *opName = nullptr; | |||
1374 | if (failed(parseEntry(reader, opNames, opName, "operation name"))) | |||
1375 | return failure(); | |||
1376 | ||||
1377 | // Check to see if this operation name has already been resolved. If we | |||
1378 | // haven't, load the dialect and build the operation name. | |||
1379 | if (!opName->opName) { | |||
1380 | if (failed(opName->dialect->load(reader, getContext()))) | |||
1381 | return failure(); | |||
1382 | opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), | |||
1383 | getContext()); | |||
1384 | } | |||
1385 | return *opName->opName; | |||
1386 | } | |||
1387 | ||||
1388 | //===----------------------------------------------------------------------===// | |||
1389 | // Resource Section | |||
1390 | ||||
1391 | LogicalResult BytecodeReader::parseResourceSection( | |||
1392 | Optional<ArrayRef<uint8_t>> resourceData, | |||
1393 | Optional<ArrayRef<uint8_t>> resourceOffsetData) { | |||
1394 | // Ensure both sections are either present or not. | |||
1395 | if (resourceData.has_value() != resourceOffsetData.has_value()) { | |||
1396 | if (resourceOffsetData) | |||
1397 | return emitError(fileLoc, "unexpected resource offset section when " | |||
1398 | "resource section is not present"); | |||
1399 | return emitError( | |||
1400 | fileLoc, | |||
1401 | "expected resource offset section when resource section is present"); | |||
1402 | } | |||
1403 | ||||
1404 | // If the resource sections are absent, there is nothing to do. | |||
1405 | if (!resourceData) | |||
1406 | return success(); | |||
1407 | ||||
1408 | // Initialize the resource reader with the resource sections. | |||
1409 | return resourceReader.initialize(fileLoc, config, dialects, stringReader, | |||
1410 | *resourceData, *resourceOffsetData, | |||
1411 | bufferOwnerRef); | |||
1412 | } | |||
1413 | ||||
1414 | //===----------------------------------------------------------------------===// | |||
1415 | // IR Section | |||
1416 | ||||
1417 | LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, | |||
1418 | Block *block) { | |||
1419 | EncodingReader reader(sectionData, fileLoc); | |||
1420 | ||||
1421 | // A stack of operation regions currently being read from the bytecode. | |||
1422 | std::vector<RegionReadState> regionStack; | |||
1423 | ||||
1424 | // Parse the top-level block using a temporary module operation. | |||
1425 | OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); | |||
1426 | regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); | |||
1427 | regionStack.back().curBlocks.push_back(moduleOp->getBody()); | |||
1428 | regionStack.back().curBlock = regionStack.back().curRegion->begin(); | |||
1429 | if (failed(parseBlock(reader, regionStack.back()))) | |||
1430 | return failure(); | |||
1431 | valueScopes.emplace_back(); | |||
1432 | valueScopes.back().push(regionStack.back()); | |||
1433 | ||||
1434 | // Iteratively parse regions until everything has been resolved. | |||
1435 | while (!regionStack.empty()) | |||
1436 | if (failed(parseRegions(reader, regionStack, regionStack.back()))) | |||
1437 | return failure(); | |||
1438 | if (!forwardRefOps.empty()) { | |||
1439 | return reader.emitError( | |||
1440 | "not all forward unresolved forward operand references"); | |||
1441 | } | |||
1442 | ||||
1443 | // Verify that the parsed operations are valid. | |||
1444 | if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) | |||
1445 | return failure(); | |||
1446 | ||||
1447 | // Splice the parsed operations over to the provided top-level block. | |||
1448 | auto &parsedOps = moduleOp->getBody()->getOperations(); | |||
1449 | auto &destOps = block->getOperations(); | |||
1450 | destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); | |||
1451 | return success(); | |||
1452 | } | |||
1453 | ||||
/// Continue reading the regions of the operation described by `readState`.
/// The caller in parseIRSection passes `regionStack.back()`.
///
/// Progress is kept in `readState` (current region, current block, remaining
/// op count). When an operation with regions is read, it is pushed onto
/// `regionStack` and this function returns early so the nested regions are
/// handled first. A later call with this `readState` resumes where it
/// stopped. Once every region has been read, the entry is popped from
/// `regionStack`, together with its value scope if it was isolated from above.
LogicalResult
BytecodeReader::parseRegions(EncodingReader &reader,
                             std::vector<RegionReadState> &regionStack,
                             RegionReadState &readState) {
  // Read the regions of this operation.
  for (; readState.curRegion != readState.endRegion; ++readState.curRegion) {
    // If the current block hasn't been setup yet, parse the header for this
    // region. A default-constructed iterator marks "header not read yet".
    if (readState.curBlock == Region::iterator()) {
      if (failed(parseRegion(reader, readState)))
        return failure();

      // If the region is empty, there is nothing more to do.
      if (readState.curRegion->empty())
        continue;
    }

    // Parse the blocks within the region.
    do {
      while (readState.numOpsRemaining--) {
        // Read in the next operation. We don't read its regions directly, we
        // handle those afterwards as necessary.
        bool isIsolatedFromAbove = false;
        FailureOr<Operation *> op =
            parseOpWithoutRegions(reader, readState, isIsolatedFromAbove);
        if (failed(op))
          return failure();

        // If the op has regions, add it to the stack for processing.
        if ((*op)->getNumRegions()) {
          // NOTE: emplace_back may reallocate `regionStack`, which can leave
          // `readState` (an element of it) dangling; it is not touched again
          // before returning below.
          regionStack.emplace_back(*op, isIsolatedFromAbove);

          // If the op is isolated from above, push a new value scope.
          if (isIsolatedFromAbove)
            valueScopes.emplace_back();
          // Return so the nested regions are processed first; this region is
          // resumed from the saved `readState` on a later call.
          return success();
        }
      }

      // Move to the next block of the region.
      if (++readState.curBlock == readState.curRegion->end())
        break;
      if (failed(parseBlock(reader, readState)))
        return failure();
    } while (true);

    // Reset the current block and any values reserved for this region.
    readState.curBlock = {};
    valueScopes.back().pop(readState);
  }

  // When the regions have been fully parsed, pop them off of the read stack. If
  // the regions were isolated from above, we also pop the last value scope.
  if (readState.isIsolatedFromAbove)
    valueScopes.pop_back();
  regionStack.pop_back();
  return success();
}
1512 | ||||
1513 | FailureOr<Operation *> | |||
1514 | BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, | |||
1515 | RegionReadState &readState, | |||
1516 | bool &isIsolatedFromAbove) { | |||
1517 | // Parse the name of the operation. | |||
1518 | FailureOr<OperationName> opName = parseOpName(reader); | |||
1519 | if (failed(opName)) | |||
1520 | return failure(); | |||
1521 | ||||
1522 | // Parse the operation mask, which indicates which components of the operation | |||
1523 | // are present. | |||
1524 | uint8_t opMask; | |||
1525 | if (failed(reader.parseByte(opMask))) | |||
1526 | return failure(); | |||
1527 | ||||
1528 | /// Parse the location. | |||
1529 | LocationAttr opLoc; | |||
1530 | if (failed(parseAttribute(reader, opLoc))) | |||
1531 | return failure(); | |||
1532 | ||||
1533 | // With the location and name resolved, we can start building the operation | |||
1534 | // state. | |||
1535 | OperationState opState(opLoc, *opName); | |||
1536 | ||||
1537 | // Parse the attributes of the operation. | |||
1538 | if (opMask & bytecode::OpEncodingMask::kHasAttrs) { | |||
1539 | DictionaryAttr dictAttr; | |||
1540 | if (failed(parseAttribute(reader, dictAttr))) | |||
1541 | return failure(); | |||
1542 | opState.attributes = dictAttr; | |||
1543 | } | |||
1544 | ||||
1545 | /// Parse the results of the operation. | |||
1546 | if (opMask & bytecode::OpEncodingMask::kHasResults) { | |||
1547 | uint64_t numResults; | |||
1548 | if (failed(reader.parseVarInt(numResults))) | |||
1549 | return failure(); | |||
1550 | opState.types.resize(numResults); | |||
1551 | for (int i = 0, e = numResults; i < e; ++i) | |||
1552 | if (failed(parseType(reader, opState.types[i]))) | |||
1553 | return failure(); | |||
1554 | } | |||
1555 | ||||
1556 | /// Parse the operands of the operation. | |||
1557 | if (opMask & bytecode::OpEncodingMask::kHasOperands) { | |||
1558 | uint64_t numOperands; | |||
1559 | if (failed(reader.parseVarInt(numOperands))) | |||
1560 | return failure(); | |||
1561 | opState.operands.resize(numOperands); | |||
1562 | for (int i = 0, e = numOperands; i < e; ++i) | |||
1563 | if (!(opState.operands[i] = parseOperand(reader))) | |||
1564 | return failure(); | |||
1565 | } | |||
1566 | ||||
1567 | /// Parse the successors of the operation. | |||
1568 | if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { | |||
1569 | uint64_t numSuccs; | |||
1570 | if (failed(reader.parseVarInt(numSuccs))) | |||
1571 | return failure(); | |||
1572 | opState.successors.resize(numSuccs); | |||
1573 | for (int i = 0, e = numSuccs; i < e; ++i) { | |||
1574 | if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], | |||
1575 | "successor"))) | |||
1576 | return failure(); | |||
1577 | } | |||
1578 | } | |||
1579 | ||||
1580 | /// Parse the regions of the operation. | |||
1581 | if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { | |||
1582 | uint64_t numRegions; | |||
1583 | if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) | |||
1584 | return failure(); | |||
1585 | ||||
1586 | opState.regions.reserve(numRegions); | |||
1587 | for (int i = 0, e = numRegions; i < e; ++i) | |||
1588 | opState.regions.push_back(std::make_unique<Region>()); | |||
1589 | } | |||
1590 | ||||
1591 | // Create the operation at the back of the current block. | |||
1592 | Operation *op = Operation::create(opState); | |||
1593 | readState.curBlock->push_back(op); | |||
1594 | ||||
1595 | // If the operation had results, update the value references. | |||
1596 | if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) | |||
1597 | return failure(); | |||
1598 | ||||
1599 | return op; | |||
1600 | } | |||
1601 | ||||
1602 | LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, | |||
1603 | RegionReadState &readState) { | |||
1604 | // Parse the number of blocks in the region. | |||
1605 | uint64_t numBlocks; | |||
1606 | if (failed(reader.parseVarInt(numBlocks))) | |||
1607 | return failure(); | |||
1608 | ||||
1609 | // If the region is empty, there is nothing else to do. | |||
1610 | if (numBlocks == 0) | |||
1611 | return success(); | |||
1612 | ||||
1613 | // Parse the number of values defined in this region. | |||
1614 | uint64_t numValues; | |||
1615 | if (failed(reader.parseVarInt(numValues))) | |||
1616 | return failure(); | |||
1617 | readState.numValues = numValues; | |||
1618 | ||||
1619 | // Create the blocks within this region. We do this before processing so that | |||
1620 | // we can rely on the blocks existing when creating operations. | |||
1621 | readState.curBlocks.clear(); | |||
1622 | readState.curBlocks.reserve(numBlocks); | |||
1623 | for (uint64_t i = 0; i < numBlocks; ++i) { | |||
1624 | readState.curBlocks.push_back(new Block()); | |||
1625 | readState.curRegion->push_back(readState.curBlocks.back()); | |||
1626 | } | |||
1627 | ||||
1628 | // Prepare the current value scope for this region. | |||
1629 | valueScopes.back().push(readState); | |||
1630 | ||||
1631 | // Parse the entry block of the region. | |||
1632 | readState.curBlock = readState.curRegion->begin(); | |||
1633 | return parseBlock(reader, readState); | |||
1634 | } | |||
1635 | ||||
1636 | LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, | |||
1637 | RegionReadState &readState) { | |||
1638 | bool hasArgs; | |||
1639 | if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) | |||
1640 | return failure(); | |||
1641 | ||||
1642 | // Parse the arguments of the block. | |||
1643 | if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) | |||
1644 | return failure(); | |||
1645 | ||||
1646 | // We don't parse the operations of the block here, that's done elsewhere. | |||
1647 | return success(); | |||
1648 | } | |||
1649 | ||||
1650 | LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, | |||
1651 | Block *block) { | |||
1652 | // Parse the value ID for the first argument, and the number of arguments. | |||
1653 | uint64_t numArgs; | |||
1654 | if (failed(reader.parseVarInt(numArgs))) | |||
1655 | return failure(); | |||
1656 | ||||
1657 | SmallVector<Type> argTypes; | |||
1658 | SmallVector<Location> argLocs; | |||
1659 | argTypes.reserve(numArgs); | |||
1660 | argLocs.reserve(numArgs); | |||
1661 | ||||
1662 | while (numArgs--) { | |||
1663 | Type argType; | |||
1664 | LocationAttr argLoc; | |||
1665 | if (failed(parseType(reader, argType)) || | |||
1666 | failed(parseAttribute(reader, argLoc))) | |||
1667 | return failure(); | |||
1668 | ||||
1669 | argTypes.push_back(argType); | |||
1670 | argLocs.push_back(argLoc); | |||
1671 | } | |||
1672 | block->addArguments(argTypes, argLocs); | |||
1673 | return defineValues(reader, block->getArguments()); | |||
1674 | } | |||
1675 | ||||
1676 | //===----------------------------------------------------------------------===// | |||
1677 | // Value Processing | |||
1678 | ||||
1679 | Value BytecodeReader::parseOperand(EncodingReader &reader) { | |||
1680 | std::vector<Value> &values = valueScopes.back().values; | |||
1681 | Value *value = nullptr; | |||
1682 | if (failed(parseEntry(reader, values, value, "value"))) | |||
1683 | return Value(); | |||
1684 | ||||
1685 | // Create a new forward reference if necessary. | |||
1686 | if (!*value) | |||
1687 | *value = createForwardRef(); | |||
1688 | return *value; | |||
1689 | } | |||
1690 | ||||
1691 | LogicalResult BytecodeReader::defineValues(EncodingReader &reader, | |||
1692 | ValueRange newValues) { | |||
1693 | ValueScope &valueScope = valueScopes.back(); | |||
1694 | std::vector<Value> &values = valueScope.values; | |||
1695 | ||||
1696 | unsigned &valueID = valueScope.nextValueIDs.back(); | |||
1697 | unsigned valueIDEnd = valueID + newValues.size(); | |||
1698 | if (valueIDEnd > values.size()) { | |||
1699 | return reader.emitError( | |||
1700 | "value index range was outside of the expected range for " | |||
1701 | "the parent region, got [", | |||
1702 | valueID, ", ", valueIDEnd, "), but the maximum index was ", | |||
1703 | values.size() - 1); | |||
1704 | } | |||
1705 | ||||
1706 | // Assign the values and update any forward references. | |||
1707 | for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { | |||
1708 | Value newValue = newValues[i]; | |||
1709 | ||||
1710 | // Check to see if a definition for this value already exists. | |||
1711 | if (Value oldValue = std::exchange(values[valueID], newValue)) { | |||
1712 | Operation *forwardRefOp = oldValue.getDefiningOp(); | |||
1713 | ||||
1714 | // Assert that this is a forward reference operation. Given how we compute | |||
1715 | // definition ids (incrementally as we parse), it shouldn't be possible | |||
1716 | // for the value to be defined any other way. | |||
1717 | assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1718, __extension__ __PRETTY_FUNCTION__)) | |||
1718 | "value index was already defined?")(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1718, __extension__ __PRETTY_FUNCTION__)); | |||
1719 | ||||
1720 | oldValue.replaceAllUsesWith(newValue); | |||
1721 | forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); | |||
1722 | } | |||
1723 | } | |||
1724 | return success(); | |||
1725 | } | |||
1726 | ||||
1727 | Value BytecodeReader::createForwardRef() { | |||
1728 | // Check for an avaliable existing operation to use. Otherwise, create a new | |||
1729 | // fake operation to use for the reference. | |||
1730 | if (!openForwardRefOps.empty()) { | |||
1731 | Operation *op = &openForwardRefOps.back(); | |||
1732 | op->moveBefore(&forwardRefOps, forwardRefOps.end()); | |||
1733 | } else { | |||
1734 | forwardRefOps.push_back(Operation::create(forwardRefOpState)); | |||
1735 | } | |||
1736 | return forwardRefOps.back().getResult(0); | |||
1737 | } | |||
1738 | ||||
1739 | //===----------------------------------------------------------------------===// | |||
1740 | // Entry Points | |||
1741 | //===----------------------------------------------------------------------===// | |||
1742 | ||||
1743 | bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { | |||
1744 | return buffer.getBuffer().startswith("ML\xefR"); | |||
1745 | } | |||
1746 | ||||
1747 | /// Read the bytecode from the provided memory buffer reference. | |||
1748 | /// `bufferOwnerRef` if provided is the owning source manager for the buffer, | |||
1749 | /// and may be used to extend the lifetime of the buffer. | |||
1750 | static LogicalResult | |||
1751 | readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, | |||
1752 | const ParserConfig &config, | |||
1753 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { | |||
1754 | Location sourceFileLoc = | |||
1755 | FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), | |||
1756 | /*line=*/0, /*column=*/0); | |||
1757 | if (!isBytecode(buffer)) { | |||
1758 | return emitError(sourceFileLoc, | |||
1759 | "input buffer is not an MLIR bytecode file"); | |||
1760 | } | |||
1761 | ||||
1762 | BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); | |||
1763 | return reader.read(buffer, block); | |||
1764 | } | |||
1765 | ||||
1766 | LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, | |||
1767 | const ParserConfig &config) { | |||
1768 | return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); | |||
1769 | } | |||
1770 | LogicalResult | |||
1771 | mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, | |||
1772 | Block *block, const ParserConfig &config) { | |||
1773 | return readBytecodeFileImpl( | |||
1774 | *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, | |||
1775 | sourceMgr); | |||
1776 | } |
1 | //===- LogicalResult.h - Utilities for handling success/failure -*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
9 | #ifndef MLIR_SUPPORT_LOGICALRESULT_H |
10 | #define MLIR_SUPPORT_LOGICALRESULT_H |
11 | |
12 | #include "mlir/Support/LLVM.h" |
13 | #include "llvm/ADT/Optional.h" |
14 | |
15 | namespace mlir { |
16 | |
/// A lightweight success/failure flag. Prefer it over `bool` wherever the
/// meaning of a boolean return would be ambiguous. The type is
/// `[[nodiscard]]`, so discarding a result requires an explicit cast, e.g.
/// `(void)functionThatReturnsALogicalResult();`. It is meant to be used with
/// the free-function helpers below. It is a poor fit for functions whose
/// results are routinely ignored.
struct [[nodiscard]] LogicalResult {
public:
  /// Return a success result when `isSuccess` is true, otherwise a failure.
  static LogicalResult success(bool isSuccess = true) {
    return LogicalResult(isSuccess);
  }

  /// Return a failure result when `isFailure` is true, otherwise a success.
  static LogicalResult failure(bool isFailure = true) {
    return success(!isFailure);
  }

  /// True iff this result represents success.
  bool succeeded() const { return isSuccess; }

  /// True iff this result represents failure.
  bool failed() const { return !succeeded(); }

private:
  LogicalResult(bool isSuccess) : isSuccess(isSuccess) {}

  /// True for a success result, false for a failure result.
  bool isSuccess;
};
53 | |
54 | /// Utility function to generate a LogicalResult. If isSuccess is true a |
55 | /// `success` result is generated, otherwise a 'failure' result is generated. |
56 | inline LogicalResult success(bool isSuccess = true) { |
57 | return LogicalResult::success(isSuccess); |
58 | } |
59 | |
60 | /// Utility function to generate a LogicalResult. If isFailure is true a |
61 | /// `failure` result is generated, otherwise a 'success' result is generated. |
62 | inline LogicalResult failure(bool isFailure = true) { |
63 | return LogicalResult::failure(isFailure); |
64 | } |
65 | |
66 | /// Utility function that returns true if the provided LogicalResult corresponds |
67 | /// to a success value. |
68 | inline bool succeeded(LogicalResult result) { return result.succeeded(); } |
69 | |
70 | /// Utility function that returns true if the provided LogicalResult corresponds |
71 | /// to a failure value. |
72 | inline bool failed(LogicalResult result) { return result.failed(); } |
73 | |
74 | /// This class provides support for representing a failure result, or a valid |
75 | /// value of type `T`. This allows for integrating with LogicalResult, while |
76 | /// also providing a value on the success path. |
77 | template <typename T> |
78 | class [[nodiscard]] FailureOr : public Optional<T> { |
79 | public: |
80 | /// Allow constructing from a LogicalResult. The result *must* be a failure. |
81 | /// Success results should use a proper instance of type `T`. |
82 | FailureOr(LogicalResult result) { |
83 | assert(failed(result) &&(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'" ) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\"" , "mlir/include/mlir/Support/LogicalResult.h", 84, __extension__ __PRETTY_FUNCTION__)) |
84 | "success should be constructed with an instance of 'T'")(static_cast <bool> (failed(result) && "success should be constructed with an instance of 'T'" ) ? void (0) : __assert_fail ("failed(result) && \"success should be constructed with an instance of 'T'\"" , "mlir/include/mlir/Support/LogicalResult.h", 84, __extension__ __PRETTY_FUNCTION__)); |
85 | } |
86 | FailureOr() : FailureOr(failure()) {} |
87 | FailureOr(T &&y) : Optional<T>(std::forward<T>(y)) {} |
88 | FailureOr(const T &y) : Optional<T>(y) {} |
89 | template <typename U, |
90 | std::enable_if_t<std::is_constructible<T, U>::value> * = nullptr> |
91 | FailureOr(const FailureOr<U> &other) |
92 | : Optional<T>(failed(other) ? Optional<T>() : Optional<T>(*other)) {} |
93 | |
94 | operator LogicalResult() const { return success(this->has_value()); } |
95 | |
96 | private: |
97 | /// Hide the bool conversion as it easily creates confusion. |
98 | using Optional<T>::operator bool; |
99 | using Optional<T>::has_value; |
100 | }; |
101 | |
102 | /// Wrap a value on the success path in a FailureOr of the same value type. |
103 | template <typename T, |
104 | typename = std::enable_if_t<!std::is_convertible_v<T, bool>>> |
105 | inline auto success(T &&t) { |
106 | return FailureOr<std::decay_t<T>>(std::forward<T>(t)); |
107 | } |
108 | |
109 | /// This class represents success/failure for parsing-like operations that find |
110 | /// it important to chain together failable operations with `||`. This is an |
111 | /// extended version of `LogicalResult` that allows for explicit conversion to |
112 | /// bool. |
113 | /// |
114 | /// This class should not be used for general error handling cases - we prefer |
115 | /// to keep the logic explicit with the `succeeded`/`failed` predicates. |
116 | /// However, traditional monadic-style parsing logic can sometimes get |
117 | /// swallowed up in boilerplate without this, so we provide this for narrow |
118 | /// cases where it is important. |
119 | /// |
120 | class [[nodiscard]] ParseResult : public LogicalResult { |
121 | public: |
122 | ParseResult(LogicalResult result = success()) : LogicalResult(result) {} |
123 | |
124 | /// Failure is true in a boolean context. |
125 | explicit operator bool() const { return failed(); } |
126 | }; |
127 | |
128 | } // namespace mlir |
129 | |
130 | #endif // MLIR_SUPPORT_LOGICALRESULT_H |
1 | //===- Optional.h - Simple variant for passing optional values --*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | /// | |||
9 | /// \file | |||
10 | /// This file provides Optional, a template class modeled in the spirit of | |||
11 | /// OCaml's 'opt' variant. The idea is to strongly type whether or not | |||
12 | /// a value can be optional. | |||
13 | /// | |||
14 | //===----------------------------------------------------------------------===// | |||
15 | ||||
16 | #ifndef LLVM_ADT_OPTIONAL_H | |||
17 | #define LLVM_ADT_OPTIONAL_H | |||
18 | ||||
19 | #include "llvm/ADT/Hashing.h" | |||
20 | #include "llvm/Support/Compiler.h" | |||
21 | #include "llvm/Support/type_traits.h" | |||
22 | #include <cassert> | |||
23 | #include <new> | |||
24 | #include <utility> | |||
25 | ||||
26 | namespace llvm { | |||
27 | ||||
28 | namespace optional_detail { | |||
29 | ||||
/// Storage for any type.
//
// The specialization condition intentionally uses
// llvm::is_trivially_{copy/move}_constructible instead of
// std::is_trivially_{copy/move}_constructible. GCC versions prior to 7.4 may
// instantiate the copy/move constructor of `T` when
// std::is_trivially_{copy/move}_constructible is instantiated. This causes
// compilation to fail if we query the trivially copy/move constructible
// property of a class which is not copy/move constructible.
//
// The current implementation of OptionalStorage insists that in order to use
// the trivial specialization, the value_type must be trivially copy
// constructible and trivially copy assignable due to =default implementations
// of the copy/move constructor/assignment. It does not follow that this is
// necessarily the case std::is_trivially_copyable is true (hence the expanded
// specialization condition).
//
// The move constructible / assignable conditions emulate the remaining behavior
// of std::is_trivially_copyable.
//
// This primary template is the non-trivial case: it manages the lifetime of
// the contained `T` manually with placement new and explicit destructor calls.
template <typename T,
          bool = (llvm::is_trivially_copy_constructible<T>::value &&
                  std::is_trivially_copy_assignable<T>::value &&
                  (llvm::is_trivially_move_constructible<T>::value ||
                   !std::is_move_constructible<T>::value) &&
                  (std::is_trivially_move_assignable<T>::value ||
                   !std::is_move_assignable<T>::value))>
class OptionalStorage {
  // Exactly one member is active: `empty` when no value is held, `val` when
  // one is. `hasVal` records which. The default constructor initializes
  // `empty`, so no `T` is constructed for an empty storage.
  union {
    char empty;
    T val;
  };
  bool hasVal = false;

public:
  // Destroy the contained value, if any.
  ~OptionalStorage() { reset(); }

  // Construct an empty storage.
  constexpr OptionalStorage() noexcept : empty() {}

  // Copy/move construction: start empty, then construct a `T` in place from
  // the other storage's value when it holds one.
  constexpr OptionalStorage(OptionalStorage const &other) : OptionalStorage() {
    if (other.has_value()) {
      emplace(other.val);
    }
  }
  constexpr OptionalStorage(OptionalStorage &&other) : OptionalStorage() {
    if (other.has_value()) {
      emplace(std::move(other.val));
    }
  }

  // Construct the contained value directly from `args`.
  template <class... Args>
  constexpr explicit OptionalStorage(std::in_place_t, Args &&...args)
      : val(std::forward<Args>(args)...), hasVal(true) {}

  // Destroy the contained value (if any) and mark the storage empty.
  void reset() noexcept {
    if (hasVal) {
      val.~T();
      hasVal = false;
    }
  }

  constexpr bool has_value() const noexcept { return hasVal; }

  // Accessors; asserting (in debug builds) that a value is present.
  T &value() &noexcept {
    assert(hasVal);
    return val;
  }
  constexpr T const &value() const &noexcept {
    assert(hasVal);
    return val;
  }
  T &&value() &&noexcept {
    assert(hasVal);
    return std::move(val);
  }

  // Destroy any current value, then construct a new one in place.
  template <class... Args> void emplace(Args &&...args) {
    reset();
    ::new ((void *)std::addressof(val)) T(std::forward<Args>(args)...);
    hasVal = true;
  }

  // Assign from a `T`: assign into the existing value if present, otherwise
  // construct one in place.
  OptionalStorage &operator=(T const &y) {
    if (has_value()) {
      val = y;
    } else {
      ::new ((void *)std::addressof(val)) T(y);
      hasVal = true;
    }
    return *this;
  }
  OptionalStorage &operator=(T &&y) {
    if (has_value()) {
      val = std::move(y);
    } else {
      ::new ((void *)std::addressof(val)) T(std::move(y));
      hasVal = true;
    }
    return *this;
  }

  // Assign from another storage. If `other` holds a value, it is assigned into
  // or constructed in this storage. If `other` is empty, this storage is
  // reset. There is no explicit self-assignment check.
  OptionalStorage &operator=(OptionalStorage const &other) {
    if (other.has_value()) {
      if (has_value()) {
        val = other.val;
      } else {
        ::new ((void *)std::addressof(val)) T(other.val);
        hasVal = true;
      }
    } else {
      reset();
    }
    return *this;
  }

  OptionalStorage &operator=(OptionalStorage &&other) {
    if (other.has_value()) {
      if (has_value()) {
        val = std::move(other.val);
      } else {
        ::new ((void *)std::addressof(val)) T(std::move(other.val));
        hasVal = true;
      }
    } else {
      reset();
    }
    return *this;
  }
};
158 | ||||
159 | template <typename T> class OptionalStorage<T, true> { | |||
160 | union { | |||
161 | char empty; | |||
162 | T val; | |||
163 | }; | |||
164 | bool hasVal = false; | |||
165 | ||||
166 | public: | |||
167 | ~OptionalStorage() = default; | |||
168 | ||||
169 | constexpr OptionalStorage() noexcept : empty{} {} | |||
170 | ||||
171 | constexpr OptionalStorage(OptionalStorage const &other) = default; | |||
172 | constexpr OptionalStorage(OptionalStorage &&other) = default; | |||
173 | ||||
174 | OptionalStorage &operator=(OptionalStorage const &other) = default; | |||
175 | OptionalStorage &operator=(OptionalStorage &&other) = default; | |||
176 | ||||
177 | template <class... Args> | |||
178 | constexpr explicit OptionalStorage(std::in_place_t, Args &&...args) | |||
179 | : val(std::forward<Args>(args)...), hasVal(true) {} | |||
| ||||
180 | ||||
181 | void reset() noexcept { | |||
182 | if (hasVal) { | |||
183 | val.~T(); | |||
184 | hasVal = false; | |||
185 | } | |||
186 | } | |||
187 | ||||
188 | constexpr bool has_value() const noexcept { return hasVal; } | |||
189 | ||||
190 | T &value() &noexcept { | |||
191 | assert(hasVal)(static_cast <bool> (hasVal) ? void (0) : __assert_fail ("hasVal", "llvm/include/llvm/ADT/Optional.h", 191, __extension__ __PRETTY_FUNCTION__)); | |||
192 | return val; | |||
193 | } | |||
194 | constexpr T const &value() const &noexcept { | |||
195 | assert(hasVal)(static_cast <bool> (hasVal) ? void (0) : __assert_fail ("hasVal", "llvm/include/llvm/ADT/Optional.h", 195, __extension__ __PRETTY_FUNCTION__)); | |||
196 | return val; | |||
197 | } | |||
198 | T &&value() &&noexcept { | |||
199 | assert(hasVal)(static_cast <bool> (hasVal) ? void (0) : __assert_fail ("hasVal", "llvm/include/llvm/ADT/Optional.h", 199, __extension__ __PRETTY_FUNCTION__)); | |||
200 | return std::move(val); | |||
201 | } | |||
202 | ||||
203 | template <class... Args> void emplace(Args &&...args) { | |||
204 | reset(); | |||
205 | ::new ((void *)std::addressof(val)) T(std::forward<Args>(args)...); | |||
206 | hasVal = true; | |||
207 | } | |||
208 | ||||
209 | OptionalStorage &operator=(T const &y) { | |||
210 | if (has_value()) { | |||
211 | val = y; | |||
212 | } else { | |||
213 | ::new ((void *)std::addressof(val)) T(y); | |||
214 | hasVal = true; | |||
215 | } | |||
216 | return *this; | |||
217 | } | |||
218 | OptionalStorage &operator=(T &&y) { | |||
219 | if (has_value()) { | |||
220 | val = std::move(y); | |||
221 | } else { | |||
222 | ::new ((void *)std::addressof(val)) T(std::move(y)); | |||
223 | hasVal = true; | |||
224 | } | |||
225 | return *this; | |||
226 | } | |||
227 | }; | |||
228 | ||||
229 | } // namespace optional_detail | |||
230 | ||||
231 | template <typename T> class Optional { | |||
232 | optional_detail::OptionalStorage<T> Storage; | |||
233 | ||||
234 | public: | |||
235 | using value_type = T; | |||
236 | ||||
237 | constexpr Optional() = default; | |||
238 | constexpr Optional(std::nullopt_t) {} | |||
239 | ||||
240 | constexpr Optional(const T &y) : Storage(std::in_place, y) {} | |||
241 | constexpr Optional(const Optional &O) = default; | |||
242 | ||||
243 | constexpr Optional(T &&y) : Storage(std::in_place, std::move(y)) {} | |||
244 | constexpr Optional(Optional &&O) = default; | |||
245 | ||||
246 | template <typename... ArgTypes> | |||
247 | constexpr Optional(std::in_place_t, ArgTypes &&...Args) | |||
248 | : Storage(std::in_place, std::forward<ArgTypes>(Args)...) {} | |||
249 | ||||
250 | Optional &operator=(T &&y) { | |||
251 | Storage = std::move(y); | |||
252 | return *this; | |||
253 | } | |||
254 | Optional &operator=(Optional &&O) = default; | |||
255 | ||||
256 | /// Create a new object by constructing it in place with the given arguments. | |||
257 | template <typename... ArgTypes> void emplace(ArgTypes &&... Args) { | |||
258 | Storage.emplace(std::forward<ArgTypes>(Args)...); | |||
259 | } | |||
260 | ||||
261 | Optional &operator=(const T &y) { | |||
262 | Storage = y; | |||
263 | return *this; | |||
264 | } | |||
265 | Optional &operator=(const Optional &O) = default; | |||
266 | ||||
267 | void reset() { Storage.reset(); } | |||
268 | ||||
269 | LLVM_DEPRECATED("Use &*X instead.", "&*X")__attribute__((deprecated("Use &*X instead.", "&*X")) ) | |||
270 | constexpr const T *getPointer() const { return &Storage.value(); } | |||
271 | LLVM_DEPRECATED("Use &*X instead.", "&*X")__attribute__((deprecated("Use &*X instead.", "&*X")) ) | |||
272 | T *getPointer() { return &Storage.value(); } | |||
273 | LLVM_DEPRECATED("std::optional::value is throwing. Use *X instead", "*X")__attribute__((deprecated("std::optional::value is throwing. Use *X instead" , "*X"))) | |||
274 | constexpr const T &value() const & { return Storage.value(); } | |||
275 | LLVM_DEPRECATED("std::optional::value is throwing. Use *X instead", "*X")__attribute__((deprecated("std::optional::value is throwing. Use *X instead" , "*X"))) | |||
276 | T &value() & { return Storage.value(); } | |||
277 | ||||
278 | constexpr explicit operator bool() const { return has_value(); } | |||
279 | constexpr bool has_value() const { return Storage.has_value(); } | |||
280 | constexpr const T *operator->() const { return &Storage.value(); } | |||
281 | T *operator->() { return &Storage.value(); } | |||
282 | constexpr const T &operator*() const & { return Storage.value(); } | |||
283 | T &operator*() & { return Storage.value(); } | |||
284 | ||||
285 | template <typename U> constexpr T value_or(U &&alt) const & { | |||
286 | return has_value() ? operator*() : std::forward<U>(alt); | |||
287 | } | |||
288 | ||||
289 | LLVM_DEPRECATED("std::optional::value is throwing. Use *X instead", "*X")__attribute__((deprecated("std::optional::value is throwing. Use *X instead" , "*X"))) | |||
290 | T &&value() && { return std::move(Storage.value()); } | |||
291 | T &&operator*() && { return std::move(Storage.value()); } | |||
292 | ||||
293 | template <typename U> T value_or(U &&alt) && { | |||
294 | return has_value() ? std::move(operator*()) : std::forward<U>(alt); | |||
295 | } | |||
296 | }; | |||
297 | ||||
298 | template<typename T> | |||
299 | Optional(const T&) -> Optional<T>; | |||
300 | ||||
301 | template <class T> llvm::hash_code hash_value(const Optional<T> &O) { | |||
302 | return O ? hash_combine(true, *O) : hash_value(false); | |||
303 | } | |||
304 | ||||
305 | template <typename T, typename U> | |||
306 | constexpr bool operator==(const Optional<T> &X, const Optional<U> &Y) { | |||
307 | if (X && Y) | |||
308 | return *X == *Y; | |||
309 | return X.has_value() == Y.has_value(); | |||
310 | } | |||
311 | ||||
312 | template <typename T, typename U> | |||
313 | constexpr bool operator!=(const Optional<T> &X, const Optional<U> &Y) { | |||
314 | return !(X == Y); | |||
315 | } | |||
316 | ||||
317 | template <typename T, typename U> | |||
318 | constexpr bool operator<(const Optional<T> &X, const Optional<U> &Y) { | |||
319 | if (X && Y) | |||
320 | return *X < *Y; | |||
321 | return X.has_value() < Y.has_value(); | |||
322 | } | |||
323 | ||||
324 | template <typename T, typename U> | |||
325 | constexpr bool operator<=(const Optional<T> &X, const Optional<U> &Y) { | |||
326 | return !(Y < X); | |||
327 | } | |||
328 | ||||
329 | template <typename T, typename U> | |||
330 | constexpr bool operator>(const Optional<T> &X, const Optional<U> &Y) { | |||
331 | return Y < X; | |||
332 | } | |||
333 | ||||
334 | template <typename T, typename U> | |||
335 | constexpr bool operator>=(const Optional<T> &X, const Optional<U> &Y) { | |||
336 | return !(X < Y); | |||
337 | } | |||
338 | ||||
339 | template <typename T> | |||
340 | constexpr bool operator==(const Optional<T> &X, std::nullopt_t) { | |||
341 | return !X; | |||
342 | } | |||
343 | ||||
344 | template <typename T> | |||
345 | constexpr bool operator==(std::nullopt_t, const Optional<T> &X) { | |||
346 | return X == std::nullopt; | |||
347 | } | |||
348 | ||||
349 | template <typename T> | |||
350 | constexpr bool operator!=(const Optional<T> &X, std::nullopt_t) { | |||
351 | return !(X == std::nullopt); | |||
352 | } | |||
353 | ||||
354 | template <typename T> | |||
355 | constexpr bool operator!=(std::nullopt_t, const Optional<T> &X) { | |||
356 | return X != std::nullopt; | |||
357 | } | |||
358 | ||||
359 | template <typename T> | |||
360 | constexpr bool operator<(const Optional<T> &, std::nullopt_t) { | |||
361 | return false; | |||
362 | } | |||
363 | ||||
364 | template <typename T> | |||
365 | constexpr bool operator<(std::nullopt_t, const Optional<T> &X) { | |||
366 | return X.has_value(); | |||
367 | } | |||
368 | ||||
369 | template <typename T> | |||
370 | constexpr bool operator<=(const Optional<T> &X, std::nullopt_t) { | |||
371 | return !(std::nullopt < X); | |||
372 | } | |||
373 | ||||
374 | template <typename T> | |||
375 | constexpr bool operator<=(std::nullopt_t, const Optional<T> &X) { | |||
376 | return !(X < std::nullopt); | |||
377 | } | |||
378 | ||||
379 | template <typename T> | |||
380 | constexpr bool operator>(const Optional<T> &X, std::nullopt_t) { | |||
381 | return std::nullopt < X; | |||
382 | } | |||
383 | ||||
384 | template <typename T> | |||
385 | constexpr bool operator>(std::nullopt_t, const Optional<T> &X) { | |||
386 | return X < std::nullopt; | |||
387 | } | |||
388 | ||||
389 | template <typename T> | |||
390 | constexpr bool operator>=(const Optional<T> &X, std::nullopt_t) { | |||
391 | return std::nullopt <= X; | |||
392 | } | |||
393 | ||||
394 | template <typename T> | |||
395 | constexpr bool operator>=(std::nullopt_t, const Optional<T> &X) { | |||
396 | return X <= std::nullopt; | |||
397 | } | |||
398 | ||||
399 | template <typename T> | |||
400 | constexpr bool operator==(const Optional<T> &X, const T &Y) { | |||
401 | return X && *X == Y; | |||
402 | } | |||
403 | ||||
404 | template <typename T> | |||
405 | constexpr bool operator==(const T &X, const Optional<T> &Y) { | |||
406 | return Y && X == *Y; | |||
407 | } | |||
408 | ||||
409 | template <typename T> | |||
410 | constexpr bool operator!=(const Optional<T> &X, const T &Y) { | |||
411 | return !(X == Y); | |||
412 | } | |||
413 | ||||
414 | template <typename T> | |||
415 | constexpr bool operator!=(const T &X, const Optional<T> &Y) { | |||
416 | return !(X == Y); | |||
417 | } | |||
418 | ||||
419 | template <typename T> | |||
420 | constexpr bool operator<(const Optional<T> &X, const T &Y) { | |||
421 | return !X || *X < Y; | |||
422 | } | |||
423 | ||||
424 | template <typename T> | |||
425 | constexpr bool operator<(const T &X, const Optional<T> &Y) { | |||
426 | return Y && X < *Y; | |||
427 | } | |||
428 | ||||
429 | template <typename T> | |||
430 | constexpr bool operator<=(const Optional<T> &X, const T &Y) { | |||
431 | return !(Y < X); | |||
432 | } | |||
433 | ||||
434 | template <typename T> | |||
435 | constexpr bool operator<=(const T &X, const Optional<T> &Y) { | |||
436 | return !(Y < X); | |||
437 | } | |||
438 | ||||
439 | template <typename T> | |||
440 | constexpr bool operator>(const Optional<T> &X, const T &Y) { | |||
441 | return Y < X; | |||
442 | } | |||
443 | ||||
444 | template <typename T> | |||
445 | constexpr bool operator>(const T &X, const Optional<T> &Y) { | |||
446 | return Y < X; | |||
447 | } | |||
448 | ||||
449 | template <typename T> | |||
450 | constexpr bool operator>=(const Optional<T> &X, const T &Y) { | |||
451 | return !(X < Y); | |||
452 | } | |||
453 | ||||
454 | template <typename T> | |||
455 | constexpr bool operator>=(const T &X, const Optional<T> &Y) { | |||
456 | return !(X < Y); | |||
457 | } | |||
458 | ||||
459 | } // end namespace llvm | |||
460 | ||||
461 | #endif // LLVM_ADT_OPTIONAL_H |