File: | build/source/mlir/lib/Bytecode/Reader/BytecodeReader.cpp |
Warning: | line 105, column 19 The left operand of '!=' is a garbage value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | // TODO: Support for big-endian architectures. | |||
10 | // TODO: Properly preserve use lists of values. | |||
11 | ||||
12 | #include "mlir/Bytecode/BytecodeReader.h" | |||
13 | #include "../Encoding.h" | |||
14 | #include "mlir/AsmParser/AsmParser.h" | |||
15 | #include "mlir/Bytecode/BytecodeImplementation.h" | |||
16 | #include "mlir/IR/BuiltinDialect.h" | |||
17 | #include "mlir/IR/BuiltinOps.h" | |||
18 | #include "mlir/IR/OpImplementation.h" | |||
19 | #include "mlir/IR/Verifier.h" | |||
20 | #include "llvm/ADT/MapVector.h" | |||
21 | #include "llvm/ADT/ScopeExit.h" | |||
22 | #include "llvm/ADT/SmallString.h" | |||
23 | #include "llvm/ADT/StringExtras.h" | |||
24 | #include "llvm/Support/MemoryBufferRef.h" | |||
25 | #include "llvm/Support/SaveAndRestore.h" | |||
26 | #include "llvm/Support/SourceMgr.h" | |||
27 | #include <optional> | |||
28 | ||||
29 | #define DEBUG_TYPE"mlir-bytecode-reader" "mlir-bytecode-reader" | |||
30 | ||||
31 | using namespace mlir; | |||
32 | ||||
33 | /// Stringify the given section ID. | |||
34 | static std::string toString(bytecode::Section::ID sectionID) { | |||
35 | switch (sectionID) { | |||
36 | case bytecode::Section::kString: | |||
37 | return "String (0)"; | |||
38 | case bytecode::Section::kDialect: | |||
39 | return "Dialect (1)"; | |||
40 | case bytecode::Section::kAttrType: | |||
41 | return "AttrType (2)"; | |||
42 | case bytecode::Section::kAttrTypeOffset: | |||
43 | return "AttrTypeOffset (3)"; | |||
44 | case bytecode::Section::kIR: | |||
45 | return "IR (4)"; | |||
46 | case bytecode::Section::kResource: | |||
47 | return "Resource (5)"; | |||
48 | case bytecode::Section::kResourceOffset: | |||
49 | return "ResourceOffset (6)"; | |||
50 | case bytecode::Section::kDialectVersions: | |||
51 | return "DialectVersions (7)"; | |||
52 | default: | |||
53 | return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); | |||
54 | } | |||
55 | } | |||
56 | ||||
57 | /// Returns true if the given top-level section ID is optional. | |||
58 | static bool isSectionOptional(bytecode::Section::ID sectionID) { | |||
59 | switch (sectionID) { | |||
60 | case bytecode::Section::kString: | |||
61 | case bytecode::Section::kDialect: | |||
62 | case bytecode::Section::kAttrType: | |||
63 | case bytecode::Section::kAttrTypeOffset: | |||
64 | case bytecode::Section::kIR: | |||
65 | return false; | |||
66 | case bytecode::Section::kResource: | |||
67 | case bytecode::Section::kResourceOffset: | |||
68 | case bytecode::Section::kDialectVersions: | |||
69 | return true; | |||
70 | default: | |||
71 | llvm_unreachable("unknown section ID")::llvm::llvm_unreachable_internal("unknown section ID", "mlir/lib/Bytecode/Reader/BytecodeReader.cpp" , 71); | |||
72 | } | |||
73 | } | |||
74 | ||||
75 | //===----------------------------------------------------------------------===// | |||
76 | // EncodingReader | |||
77 | //===----------------------------------------------------------------------===// | |||
78 | ||||
79 | namespace { | |||
80 | class EncodingReader { | |||
81 | public: | |||
82 | explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) | |||
83 | : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} | |||
84 | explicit EncodingReader(StringRef contents, Location fileLoc) | |||
85 | : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), | |||
86 | contents.size()}, | |||
87 | fileLoc) {} | |||
88 | ||||
89 | /// Returns true if the entire section has been read. | |||
90 | bool empty() const { return dataIt == dataEnd; } | |||
91 | ||||
92 | /// Returns the remaining size of the bytecode. | |||
93 | size_t size() const { return dataEnd - dataIt; } | |||
94 | ||||
95 | /// Align the current reader position to the specified alignment. | |||
96 | LogicalResult alignTo(unsigned alignment) { | |||
97 | if (!llvm::isPowerOf2_32(alignment)) | |||
98 | return emitError("expected alignment to be a power-of-two"); | |||
99 | ||||
100 | // Shift the reader position to the next alignment boundary. | |||
101 | while (uintptr_t(dataIt) & (uintptr_t(alignment) - 1)) { | |||
102 | uint8_t padding; | |||
103 | if (failed(parseByte(padding))) | |||
104 | return failure(); | |||
105 | if (padding != bytecode::kAlignmentByte) { | |||
| ||||
106 | return emitError("expected alignment byte (0xCB), but got: '0x" + | |||
107 | llvm::utohexstr(padding) + "'"); | |||
108 | } | |||
109 | } | |||
110 | ||||
111 | // Ensure the data iterator is now aligned. This case is unlikely because we | |||
112 | // *just* went through the effort to align the data iterator. | |||
113 | if (LLVM_UNLIKELY(!llvm::isAddrAligned(llvm::Align(alignment), dataIt))__builtin_expect((bool)(!llvm::isAddrAligned(llvm::Align(alignment ), dataIt)), false)) { | |||
114 | return emitError("expected data iterator aligned to ", alignment, | |||
115 | ", but got pointer: '0x" + | |||
116 | llvm::utohexstr((uintptr_t)dataIt) + "'"); | |||
117 | } | |||
118 | ||||
119 | return success(); | |||
120 | } | |||
121 | ||||
122 | /// Emit an error using the given arguments. | |||
123 | template <typename... Args> | |||
124 | InFlightDiagnostic emitError(Args &&...args) const { | |||
125 | return ::emitError(fileLoc).append(std::forward<Args>(args)...); | |||
126 | } | |||
127 | InFlightDiagnostic emitError() const { return ::emitError(fileLoc); } | |||
128 | ||||
129 | /// Parse a single byte from the stream. | |||
130 | template <typename T> | |||
131 | LogicalResult parseByte(T &value) { | |||
132 | if (empty()) | |||
133 | return emitError("attempting to parse a byte at the end of the bytecode"); | |||
134 | value = static_cast<T>(*dataIt++); | |||
135 | return success(); | |||
136 | } | |||
137 | /// Parse a range of bytes of 'length' into the given result. | |||
138 | LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { | |||
139 | if (length > size()) { | |||
140 | return emitError("attempting to parse ", length, " bytes when only ", | |||
141 | size(), " remain"); | |||
142 | } | |||
143 | result = {dataIt, length}; | |||
144 | dataIt += length; | |||
145 | return success(); | |||
146 | } | |||
147 | /// Parse a range of bytes of 'length' into the given result, which can be | |||
148 | /// assumed to be large enough to hold `length`. | |||
149 | LogicalResult parseBytes(size_t length, uint8_t *result) { | |||
150 | if (length > size()) { | |||
151 | return emitError("attempting to parse ", length, " bytes when only ", | |||
152 | size(), " remain"); | |||
153 | } | |||
154 | memcpy(result, dataIt, length); | |||
155 | dataIt += length; | |||
156 | return success(); | |||
157 | } | |||
158 | ||||
159 | /// Parse an aligned blob of data, where the alignment was encoded alongside | |||
160 | /// the data. | |||
161 | LogicalResult parseBlobAndAlignment(ArrayRef<uint8_t> &data, | |||
162 | uint64_t &alignment) { | |||
163 | uint64_t dataSize; | |||
164 | if (failed(parseVarInt(alignment)) || failed(parseVarInt(dataSize)) || | |||
165 | failed(alignTo(alignment))) | |||
166 | return failure(); | |||
167 | return parseBytes(dataSize, data); | |||
168 | } | |||
169 | ||||
170 | /// Parse a variable length encoded integer from the byte stream. The first | |||
171 | /// encoded byte contains a prefix in the low bits indicating the encoded | |||
172 | /// length of the value. This length prefix is a bit sequence of '0's followed | |||
173 | /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes | |||
174 | /// (not including the prefix byte). All remaining bits in the first byte, | |||
175 | /// along with all of the bits in additional bytes, provide the value of the | |||
176 | /// integer encoded in little-endian order. | |||
177 | LogicalResult parseVarInt(uint64_t &result) { | |||
178 | // Parse the first byte of the encoding, which contains the length prefix. | |||
179 | if (failed(parseByte(result))) | |||
180 | return failure(); | |||
181 | ||||
182 | // Handle the overwhelmingly common case where the value is stored in a | |||
183 | // single byte. In this case, the first bit is the `1` marker bit. | |||
184 | if (LLVM_LIKELY(result & 1)__builtin_expect((bool)(result & 1), true)) { | |||
185 | result >>= 1; | |||
186 | return success(); | |||
187 | } | |||
188 | ||||
189 | // Handle the overwhelming uncommon case where the value required all 8 | |||
190 | // bytes (i.e. a really really big number). In this case, the marker byte is | |||
191 | // all zeros: `00000000`. | |||
192 | if (LLVM_UNLIKELY(result == 0)__builtin_expect((bool)(result == 0), false)) | |||
193 | return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); | |||
194 | return parseMultiByteVarInt(result); | |||
195 | } | |||
196 | ||||
197 | /// Parse a signed variable length encoded integer from the byte stream. A | |||
198 | /// signed varint is encoded as a normal varint with zigzag encoding applied, | |||
199 | /// i.e. the low bit of the value is used to indicate the sign. | |||
200 | LogicalResult parseSignedVarInt(uint64_t &result) { | |||
201 | if (failed(parseVarInt(result))) | |||
202 | return failure(); | |||
203 | // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) | |||
204 | result = (result >> 1) ^ (~(result & 1) + 1); | |||
205 | return success(); | |||
206 | } | |||
207 | ||||
208 | /// Parse a variable length encoded integer whose low bit is used to encode an | |||
209 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. | |||
210 | LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { | |||
211 | if (failed(parseVarInt(result))) | |||
212 | return failure(); | |||
213 | flag = result & 1; | |||
214 | result >>= 1; | |||
215 | return success(); | |||
216 | } | |||
217 | ||||
218 | /// Skip the first `length` bytes within the reader. | |||
219 | LogicalResult skipBytes(size_t length) { | |||
220 | if (length > size()) { | |||
221 | return emitError("attempting to skip ", length, " bytes when only ", | |||
222 | size(), " remain"); | |||
223 | } | |||
224 | dataIt += length; | |||
225 | return success(); | |||
226 | } | |||
227 | ||||
228 | /// Parse a null-terminated string into `result` (without including the NUL | |||
229 | /// terminator). | |||
230 | LogicalResult parseNullTerminatedString(StringRef &result) { | |||
231 | const char *startIt = (const char *)dataIt; | |||
232 | const char *nulIt = (const char *)memchr(startIt, 0, size()); | |||
233 | if (!nulIt) | |||
234 | return emitError( | |||
235 | "malformed null-terminated string, no null character found"); | |||
236 | ||||
237 | result = StringRef(startIt, nulIt - startIt); | |||
238 | dataIt = (const uint8_t *)nulIt + 1; | |||
239 | return success(); | |||
240 | } | |||
241 | ||||
242 | /// Parse a section header, placing the kind of section in `sectionID` and the | |||
243 | /// contents of the section in `sectionData`. | |||
244 | LogicalResult parseSection(bytecode::Section::ID §ionID, | |||
245 | ArrayRef<uint8_t> §ionData) { | |||
246 | uint8_t sectionIDAndHasAlignment; | |||
247 | uint64_t length; | |||
248 | if (failed(parseByte(sectionIDAndHasAlignment)) || | |||
249 | failed(parseVarInt(length))) | |||
250 | return failure(); | |||
251 | ||||
252 | // Extract the section ID and whether the section is aligned. The high bit | |||
253 | // of the ID is the alignment flag. | |||
254 | sectionID = static_cast<bytecode::Section::ID>(sectionIDAndHasAlignment & | |||
255 | 0b01111111); | |||
256 | bool hasAlignment = sectionIDAndHasAlignment & 0b10000000; | |||
257 | ||||
258 | // Check that the section is actually valid before trying to process its | |||
259 | // data. | |||
260 | if (sectionID >= bytecode::Section::kNumSections) | |||
261 | return emitError("invalid section ID: ", unsigned(sectionID)); | |||
262 | ||||
263 | // Process the section alignment if present. | |||
264 | if (hasAlignment) { | |||
265 | uint64_t alignment; | |||
266 | if (failed(parseVarInt(alignment)) || failed(alignTo(alignment))) | |||
267 | return failure(); | |||
268 | } | |||
269 | ||||
270 | // Parse the actual section data. | |||
271 | return parseBytes(static_cast<size_t>(length), sectionData); | |||
272 | } | |||
273 | ||||
274 | private: | |||
275 | /// Parse a variable length encoded integer from the byte stream. This method | |||
276 | /// is a fallback when the number of bytes used to encode the value is greater | |||
277 | /// than 1, but less than the max (9). The provided `result` value can be | |||
278 | /// assumed to already contain the first byte of the value. | |||
279 | /// NOTE: This method is marked noinline to avoid pessimizing the common case | |||
280 | /// of single byte encoding. | |||
281 | LLVM_ATTRIBUTE_NOINLINE__attribute__((noinline)) LogicalResult parseMultiByteVarInt(uint64_t &result) { | |||
282 | // Count the number of trailing zeros in the marker byte, this indicates the | |||
283 | // number of trailing bytes that are part of the value. We use `uint32_t` | |||
284 | // here because we only care about the first byte, and so that be actually | |||
285 | // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop | |||
286 | // implementation). | |||
287 | uint32_t numBytes = llvm::countr_zero<uint32_t>(result); | |||
288 | assert(numBytes > 0 && numBytes <= 7 &&(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 289, __extension__ __PRETTY_FUNCTION__)) | |||
289 | "unexpected number of trailing zeros in varint encoding")(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 289, __extension__ __PRETTY_FUNCTION__)); | |||
290 | ||||
291 | // Parse in the remaining bytes of the value. | |||
292 | if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) | |||
293 | return failure(); | |||
294 | ||||
295 | // Shift out the low-order bits that were used to mark how the value was | |||
296 | // encoded. | |||
297 | result >>= (numBytes + 1); | |||
298 | return success(); | |||
299 | } | |||
300 | ||||
301 | /// The current data iterator, and an iterator to the end of the buffer. | |||
302 | const uint8_t *dataIt, *dataEnd; | |||
303 | ||||
304 | /// A location for the bytecode used to report errors. | |||
305 | Location fileLoc; | |||
306 | }; | |||
307 | } // namespace | |||
308 | ||||
309 | /// Resolve an index into the given entry list. `entry` may either be a | |||
310 | /// reference, in which case it is assigned to the corresponding value in | |||
311 | /// `entries`, or a pointer, in which case it is assigned to the address of the | |||
312 | /// element in `entries`. | |||
313 | template <typename RangeT, typename T> | |||
314 | static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, | |||
315 | uint64_t index, T &entry, | |||
316 | StringRef entryStr) { | |||
317 | if (index >= entries.size()) | |||
318 | return reader.emitError("invalid ", entryStr, " index: ", index); | |||
319 | ||||
320 | // If the provided entry is a pointer, resolve to the address of the entry. | |||
321 | if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) | |||
322 | entry = entries[index]; | |||
323 | else | |||
324 | entry = &entries[index]; | |||
325 | return success(); | |||
326 | } | |||
327 | ||||
328 | /// Parse and resolve an index into the given entry list. | |||
329 | template <typename RangeT, typename T> | |||
330 | static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, | |||
331 | T &entry, StringRef entryStr) { | |||
332 | uint64_t entryIdx; | |||
333 | if (failed(reader.parseVarInt(entryIdx))) | |||
334 | return failure(); | |||
335 | return resolveEntry(reader, entries, entryIdx, entry, entryStr); | |||
336 | } | |||
337 | ||||
338 | //===----------------------------------------------------------------------===// | |||
339 | // StringSectionReader | |||
340 | //===----------------------------------------------------------------------===// | |||
341 | ||||
342 | namespace { | |||
343 | /// This class is used to read references to the string section from the | |||
344 | /// bytecode. | |||
345 | class StringSectionReader { | |||
346 | public: | |||
347 | /// Initialize the string section reader with the given section data. | |||
348 | LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); | |||
349 | ||||
350 | /// Parse a shared string from the string section. The shared string is | |||
351 | /// encoded using an index to a corresponding string in the string section. | |||
352 | LogicalResult parseString(EncodingReader &reader, StringRef &result) { | |||
353 | return parseEntry(reader, strings, result, "string"); | |||
354 | } | |||
355 | ||||
356 | /// Parse a shared string from the string section. The shared string is | |||
357 | /// encoded using an index to a corresponding string in the string section. | |||
358 | LogicalResult parseStringAtIndex(EncodingReader &reader, uint64_t index, | |||
359 | StringRef &result) { | |||
360 | return resolveEntry(reader, strings, index, result, "string"); | |||
361 | } | |||
362 | ||||
363 | private: | |||
364 | /// The table of strings referenced within the bytecode file. | |||
365 | SmallVector<StringRef> strings; | |||
366 | }; | |||
367 | } // namespace | |||
368 | ||||
369 | LogicalResult StringSectionReader::initialize(Location fileLoc, | |||
370 | ArrayRef<uint8_t> sectionData) { | |||
371 | EncodingReader stringReader(sectionData, fileLoc); | |||
372 | ||||
373 | // Parse the number of strings in the section. | |||
374 | uint64_t numStrings; | |||
375 | if (failed(stringReader.parseVarInt(numStrings))) | |||
376 | return failure(); | |||
377 | strings.resize(numStrings); | |||
378 | ||||
379 | // Parse each of the strings. The sizes of the strings are encoded in reverse | |||
380 | // order, so that's the order we populate the table. | |||
381 | size_t stringDataEndOffset = sectionData.size(); | |||
382 | for (StringRef &string : llvm::reverse(strings)) { | |||
383 | uint64_t stringSize; | |||
384 | if (failed(stringReader.parseVarInt(stringSize))) | |||
385 | return failure(); | |||
386 | if (stringDataEndOffset < stringSize) { | |||
387 | return stringReader.emitError( | |||
388 | "string size exceeds the available data size"); | |||
389 | } | |||
390 | ||||
391 | // Extract the string from the data, dropping the null character. | |||
392 | size_t stringOffset = stringDataEndOffset - stringSize; | |||
393 | string = StringRef( | |||
394 | reinterpret_cast<const char *>(sectionData.data() + stringOffset), | |||
395 | stringSize - 1); | |||
396 | stringDataEndOffset = stringOffset; | |||
397 | } | |||
398 | ||||
399 | // Check that the only remaining data was for the strings, i.e. the reader | |||
400 | // should be at the same offset as the first string. | |||
401 | if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { | |||
402 | return stringReader.emitError("unexpected trailing data between the " | |||
403 | "offsets for strings and their data"); | |||
404 | } | |||
405 | return success(); | |||
406 | } | |||
407 | ||||
408 | //===----------------------------------------------------------------------===// | |||
409 | // BytecodeDialect | |||
410 | //===----------------------------------------------------------------------===// | |||
411 | ||||
412 | namespace { | |||
413 | class DialectReader; | |||
414 | ||||
415 | /// This struct represents a dialect entry within the bytecode. | |||
416 | struct BytecodeDialect { | |||
417 | /// Load the dialect into the provided context if it hasn't been loaded yet. | |||
418 | /// Returns failure if the dialect couldn't be loaded *and* the provided | |||
419 | /// context does not allow unregistered dialects. The provided reader is used | |||
420 | /// for error emission if necessary. | |||
421 | LogicalResult load(DialectReader &reader, MLIRContext *ctx); | |||
422 | ||||
423 | /// Return the loaded dialect, or nullptr if the dialect is unknown. This can | |||
424 | /// only be called after `load`. | |||
425 | Dialect *getLoadedDialect() const { | |||
426 | assert(dialect &&(static_cast <bool> (dialect && "expected `load` to be invoked before `getLoadedDialect`" ) ? void (0) : __assert_fail ("dialect && \"expected `load` to be invoked before `getLoadedDialect`\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 427, __extension__ __PRETTY_FUNCTION__)) | |||
427 | "expected `load` to be invoked before `getLoadedDialect`")(static_cast <bool> (dialect && "expected `load` to be invoked before `getLoadedDialect`" ) ? void (0) : __assert_fail ("dialect && \"expected `load` to be invoked before `getLoadedDialect`\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 427, __extension__ __PRETTY_FUNCTION__)); | |||
428 | return *dialect; | |||
429 | } | |||
430 | ||||
431 | /// The loaded dialect entry. This field is std::nullopt if we haven't | |||
432 | /// attempted to load, nullptr if we failed to load, otherwise the loaded | |||
433 | /// dialect. | |||
434 | std::optional<Dialect *> dialect; | |||
435 | ||||
436 | /// The bytecode interface of the dialect, or nullptr if the dialect does not | |||
437 | /// implement the bytecode interface. This field should only be checked if the | |||
438 | /// `dialect` field is not std::nullopt. | |||
439 | const BytecodeDialectInterface *interface = nullptr; | |||
440 | ||||
441 | /// The name of the dialect. | |||
442 | StringRef name; | |||
443 | ||||
444 | /// A buffer containing the encoding of the dialect version parsed. | |||
445 | ArrayRef<uint8_t> versionBuffer; | |||
446 | ||||
447 | /// Lazy loaded dialect version from the handle above. | |||
448 | std::unique_ptr<DialectVersion> loadedVersion; | |||
449 | }; | |||
450 | ||||
451 | /// This struct represents an operation name entry within the bytecode. | |||
452 | struct BytecodeOperationName { | |||
453 | BytecodeOperationName(BytecodeDialect *dialect, StringRef name) | |||
454 | : dialect(dialect), name(name) {} | |||
455 | ||||
456 | /// The loaded operation name, or std::nullopt if it hasn't been processed | |||
457 | /// yet. | |||
458 | std::optional<OperationName> opName; | |||
459 | ||||
460 | /// The dialect that owns this operation name. | |||
461 | BytecodeDialect *dialect; | |||
462 | ||||
463 | /// The name of the operation, without the dialect prefix. | |||
464 | StringRef name; | |||
465 | }; | |||
466 | } // namespace | |||
467 | ||||
468 | /// Parse a single dialect group encoded in the byte stream. | |||
469 | static LogicalResult parseDialectGrouping( | |||
470 | EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, | |||
471 | function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { | |||
472 | // Parse the dialect and the number of entries in the group. | |||
473 | BytecodeDialect *dialect; | |||
474 | if (failed(parseEntry(reader, dialects, dialect, "dialect"))) | |||
475 | return failure(); | |||
476 | uint64_t numEntries; | |||
477 | if (failed(reader.parseVarInt(numEntries))) | |||
478 | return failure(); | |||
479 | ||||
480 | for (uint64_t i = 0; i < numEntries; ++i) | |||
481 | if (failed(entryCallback(dialect))) | |||
482 | return failure(); | |||
483 | return success(); | |||
484 | } | |||
485 | ||||
486 | //===----------------------------------------------------------------------===// | |||
487 | // ResourceSectionReader | |||
488 | //===----------------------------------------------------------------------===// | |||
489 | ||||
490 | namespace { | |||
491 | /// This class is used to read the resource section from the bytecode. | |||
492 | class ResourceSectionReader { | |||
493 | public: | |||
494 | /// Initialize the resource section reader with the given section data. | |||
495 | LogicalResult | |||
496 | initialize(Location fileLoc, const ParserConfig &config, | |||
497 | MutableArrayRef<BytecodeDialect> dialects, | |||
498 | StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, | |||
499 | ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, | |||
500 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef); | |||
501 | ||||
502 | /// Parse a dialect resource handle from the resource section. | |||
503 | LogicalResult parseResourceHandle(EncodingReader &reader, | |||
504 | AsmDialectResourceHandle &result) { | |||
505 | return parseEntry(reader, dialectResources, result, "resource handle"); | |||
506 | } | |||
507 | ||||
508 | private: | |||
509 | /// The table of dialect resources within the bytecode file. | |||
510 | SmallVector<AsmDialectResourceHandle> dialectResources; | |||
511 | }; | |||
512 | ||||
513 | class ParsedResourceEntry : public AsmParsedResourceEntry { | |||
514 | public: | |||
515 | ParsedResourceEntry(StringRef key, AsmResourceEntryKind kind, | |||
516 | EncodingReader &reader, StringSectionReader &stringReader, | |||
517 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) | |||
518 | : key(key), kind(kind), reader(reader), stringReader(stringReader), | |||
519 | bufferOwnerRef(bufferOwnerRef) {} | |||
520 | ~ParsedResourceEntry() override = default; | |||
521 | ||||
522 | StringRef getKey() const final { return key; } | |||
523 | ||||
524 | InFlightDiagnostic emitError() const final { return reader.emitError(); } | |||
525 | ||||
526 | AsmResourceEntryKind getKind() const final { return kind; } | |||
527 | ||||
528 | FailureOr<bool> parseAsBool() const final { | |||
529 | if (kind != AsmResourceEntryKind::Bool) | |||
530 | return emitError() << "expected a bool resource entry, but found a " | |||
531 | << toString(kind) << " entry instead"; | |||
532 | ||||
533 | bool value; | |||
534 | if (failed(reader.parseByte(value))) | |||
535 | return failure(); | |||
536 | return value; | |||
537 | } | |||
538 | FailureOr<std::string> parseAsString() const final { | |||
539 | if (kind != AsmResourceEntryKind::String) | |||
540 | return emitError() << "expected a string resource entry, but found a " | |||
541 | << toString(kind) << " entry instead"; | |||
542 | ||||
543 | StringRef string; | |||
544 | if (failed(stringReader.parseString(reader, string))) | |||
545 | return failure(); | |||
546 | return string.str(); | |||
547 | } | |||
548 | ||||
549 | FailureOr<AsmResourceBlob> | |||
550 | parseAsBlob(BlobAllocatorFn allocator) const final { | |||
551 | if (kind != AsmResourceEntryKind::Blob) | |||
552 | return emitError() << "expected a blob resource entry, but found a " | |||
553 | << toString(kind) << " entry instead"; | |||
554 | ||||
555 | ArrayRef<uint8_t> data; | |||
556 | uint64_t alignment; | |||
557 | if (failed(reader.parseBlobAndAlignment(data, alignment))) | |||
558 | return failure(); | |||
559 | ||||
560 | // If we have an extendable reference to the buffer owner, we don't need to | |||
561 | // allocate a new buffer for the data, and can use the data directly. | |||
562 | if (bufferOwnerRef) { | |||
563 | ArrayRef<char> charData(reinterpret_cast<const char *>(data.data()), | |||
564 | data.size()); | |||
565 | ||||
566 | // Allocate an unmanager buffer which captures a reference to the owner. | |||
567 | // For now we just mark this as immutable, but in the future we should | |||
568 | // explore marking this as mutable when desired. | |||
569 | return UnmanagedAsmResourceBlob::allocateWithAlign( | |||
570 | charData, alignment, | |||
571 | [bufferOwnerRef = bufferOwnerRef](void *, size_t, size_t) {}); | |||
572 | } | |||
573 | ||||
574 | // Allocate memory for the blob using the provided allocator and copy the | |||
575 | // data into it. | |||
576 | AsmResourceBlob blob = allocator(data.size(), alignment); | |||
577 | assert(llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) &&(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 579, __extension__ __PRETTY_FUNCTION__)) | |||
578 | blob.isMutable() &&(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 579, __extension__ __PRETTY_FUNCTION__)) | |||
579 | "blob allocator did not return a properly aligned address")(static_cast <bool> (llvm::isAddrAligned(llvm::Align(alignment ), blob.getData().data()) && blob.isMutable() && "blob allocator did not return a properly aligned address") ? void (0) : __assert_fail ("llvm::isAddrAligned(llvm::Align(alignment), blob.getData().data()) && blob.isMutable() && \"blob allocator did not return a properly aligned address\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 579, __extension__ __PRETTY_FUNCTION__)); | |||
580 | memcpy(blob.getMutableData().data(), data.data(), data.size()); | |||
581 | return blob; | |||
582 | } | |||
583 | ||||
584 | private: | |||
585 | StringRef key; | |||
586 | AsmResourceEntryKind kind; | |||
587 | EncodingReader &reader; | |||
588 | StringSectionReader &stringReader; | |||
589 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; | |||
590 | }; | |||
591 | } // namespace | |||
592 | ||||
593 | template <typename T> | |||
594 | static LogicalResult | |||
595 | parseResourceGroup(Location fileLoc, bool allowEmpty, | |||
596 | EncodingReader &offsetReader, EncodingReader &resourceReader, | |||
597 | StringSectionReader &stringReader, T *handler, | |||
598 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef, | |||
599 | function_ref<LogicalResult(StringRef)> processKeyFn = {}) { | |||
600 | uint64_t numResources; | |||
601 | if (failed(offsetReader.parseVarInt(numResources))) | |||
602 | return failure(); | |||
603 | ||||
604 | for (uint64_t i = 0; i < numResources; ++i) { | |||
605 | StringRef key; | |||
606 | AsmResourceEntryKind kind; | |||
607 | uint64_t resourceOffset; | |||
608 | ArrayRef<uint8_t> data; | |||
609 | if (failed(stringReader.parseString(offsetReader, key)) || | |||
610 | failed(offsetReader.parseVarInt(resourceOffset)) || | |||
611 | failed(offsetReader.parseByte(kind)) || | |||
612 | failed(resourceReader.parseBytes(resourceOffset, data))) | |||
613 | return failure(); | |||
614 | ||||
615 | // Process the resource key. | |||
616 | if ((processKeyFn && failed(processKeyFn(key)))) | |||
617 | return failure(); | |||
618 | ||||
619 | // If the resource data is empty and we allow it, don't error out when | |||
620 | // parsing below, just skip it. | |||
621 | if (allowEmpty && data.empty()) | |||
622 | continue; | |||
623 | ||||
624 | // Ignore the entry if we don't have a valid handler. | |||
625 | if (!handler) | |||
626 | continue; | |||
627 | ||||
628 | // Otherwise, parse the resource value. | |||
629 | EncodingReader entryReader(data, fileLoc); | |||
630 | ParsedResourceEntry entry(key, kind, entryReader, stringReader, | |||
631 | bufferOwnerRef); | |||
632 | if (failed(handler->parseResource(entry))) | |||
633 | return failure(); | |||
634 | if (!entryReader.empty()) { | |||
635 | return entryReader.emitError( | |||
636 | "unexpected trailing bytes in resource entry '", key, "'"); | |||
637 | } | |||
638 | } | |||
639 | return success(); | |||
640 | } | |||
641 | ||||
642 | LogicalResult ResourceSectionReader::initialize( | |||
643 | Location fileLoc, const ParserConfig &config, | |||
644 | MutableArrayRef<BytecodeDialect> dialects, | |||
645 | StringSectionReader &stringReader, ArrayRef<uint8_t> sectionData, | |||
646 | ArrayRef<uint8_t> offsetSectionData, DialectReader &dialectReader, | |||
647 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { | |||
648 | EncodingReader resourceReader(sectionData, fileLoc); | |||
649 | EncodingReader offsetReader(offsetSectionData, fileLoc); | |||
650 | ||||
651 | // Read the number of external resource providers. | |||
652 | uint64_t numExternalResourceGroups; | |||
653 | if (failed(offsetReader.parseVarInt(numExternalResourceGroups))) | |||
654 | return failure(); | |||
655 | ||||
656 | // Utility functor that dispatches to `parseResourceGroup`, but implicitly | |||
657 | // provides most of the arguments. | |||
658 | auto parseGroup = [&](auto *handler, bool allowEmpty = false, | |||
659 | function_ref<LogicalResult(StringRef)> keyFn = {}) { | |||
660 | return parseResourceGroup(fileLoc, allowEmpty, offsetReader, resourceReader, | |||
661 | stringReader, handler, bufferOwnerRef, keyFn); | |||
662 | }; | |||
663 | ||||
664 | // Read the external resources from the bytecode. | |||
665 | for (uint64_t i = 0; i < numExternalResourceGroups; ++i) { | |||
666 | StringRef key; | |||
667 | if (failed(stringReader.parseString(offsetReader, key))) | |||
668 | return failure(); | |||
669 | ||||
670 | // Get the handler for these resources. | |||
671 | // TODO: Should we require handling external resources in some scenarios? | |||
672 | AsmResourceParser *handler = config.getResourceParser(key); | |||
673 | if (!handler) { | |||
674 | emitWarning(fileLoc) << "ignoring unknown external resources for '" << key | |||
675 | << "'"; | |||
676 | } | |||
677 | ||||
678 | if (failed(parseGroup(handler))) | |||
679 | return failure(); | |||
680 | } | |||
681 | ||||
682 | // Read the dialect resources from the bytecode. | |||
683 | MLIRContext *ctx = fileLoc->getContext(); | |||
684 | while (!offsetReader.empty()) { | |||
685 | BytecodeDialect *dialect; | |||
686 | if (failed(parseEntry(offsetReader, dialects, dialect, "dialect")) || | |||
687 | failed(dialect->load(dialectReader, ctx))) | |||
688 | return failure(); | |||
689 | Dialect *loadedDialect = dialect->getLoadedDialect(); | |||
690 | if (!loadedDialect) { | |||
691 | return resourceReader.emitError() | |||
692 | << "dialect '" << dialect->name << "' is unknown"; | |||
693 | } | |||
694 | const auto *handler = dyn_cast<OpAsmDialectInterface>(loadedDialect); | |||
695 | if (!handler) { | |||
696 | return resourceReader.emitError() | |||
697 | << "unexpected resources for dialect '" << dialect->name << "'"; | |||
698 | } | |||
699 | ||||
700 | // Ensure that each resource is declared before being processed. | |||
701 | auto processResourceKeyFn = [&](StringRef key) -> LogicalResult { | |||
702 | FailureOr<AsmDialectResourceHandle> handle = | |||
703 | handler->declareResource(key); | |||
704 | if (failed(handle)) { | |||
705 | return resourceReader.emitError() | |||
706 | << "unknown 'resource' key '" << key << "' for dialect '" | |||
707 | << dialect->name << "'"; | |||
708 | } | |||
709 | dialectResources.push_back(*handle); | |||
710 | return success(); | |||
711 | }; | |||
712 | ||||
713 | // Parse the resources for this dialect. We allow empty resources because we | |||
714 | // just treat these as declarations. | |||
715 | if (failed(parseGroup(handler, /*allowEmpty=*/true, processResourceKeyFn))) | |||
716 | return failure(); | |||
717 | } | |||
718 | ||||
719 | return success(); | |||
720 | } | |||
721 | ||||
722 | //===----------------------------------------------------------------------===// | |||
723 | // Attribute/Type Reader | |||
724 | //===----------------------------------------------------------------------===// | |||
725 | ||||
726 | namespace { | |||
727 | /// This class provides support for reading attribute and type entries from the | |||
728 | /// bytecode. Attribute and Type entries are read lazily on demand, so we use | |||
729 | /// this reader to manage when to actually parse them from the bytecode. | |||
730 | class AttrTypeReader { | |||
731 | /// This class represents a single attribute or type entry. | |||
732 | template <typename T> | |||
733 | struct Entry { | |||
734 | /// The entry, or null if it hasn't been resolved yet. | |||
735 | T entry = {}; | |||
736 | /// The parent dialect of this entry. | |||
737 | BytecodeDialect *dialect = nullptr; | |||
738 | /// A flag indicating if the entry was encoded using a custom encoding, | |||
739 | /// instead of using the textual assembly format. | |||
740 | bool hasCustomEncoding = false; | |||
741 | /// The raw data of this entry in the bytecode. | |||
742 | ArrayRef<uint8_t> data; | |||
743 | }; | |||
744 | using AttrEntry = Entry<Attribute>; | |||
745 | using TypeEntry = Entry<Type>; | |||
746 | ||||
747 | public: | |||
748 | AttrTypeReader(StringSectionReader &stringReader, | |||
749 | ResourceSectionReader &resourceReader, Location fileLoc) | |||
750 | : stringReader(stringReader), resourceReader(resourceReader), | |||
751 | fileLoc(fileLoc) {} | |||
752 | ||||
753 | /// Initialize the attribute and type information within the reader. | |||
754 | LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
755 | ArrayRef<uint8_t> sectionData, | |||
756 | ArrayRef<uint8_t> offsetSectionData); | |||
757 | ||||
758 | /// Resolve the attribute or type at the given index. Returns nullptr on | |||
759 | /// failure. | |||
760 | Attribute resolveAttribute(size_t index) { | |||
761 | return resolveEntry(attributes, index, "Attribute"); | |||
762 | } | |||
763 | Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } | |||
764 | ||||
765 | /// Parse a reference to an attribute or type using the given reader. | |||
766 | LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { | |||
767 | uint64_t attrIdx; | |||
768 | if (failed(reader.parseVarInt(attrIdx))) | |||
769 | return failure(); | |||
770 | result = resolveAttribute(attrIdx); | |||
771 | return success(!!result); | |||
772 | } | |||
773 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
774 | uint64_t typeIdx; | |||
775 | if (failed(reader.parseVarInt(typeIdx))) | |||
776 | return failure(); | |||
777 | result = resolveType(typeIdx); | |||
778 | return success(!!result); | |||
779 | } | |||
780 | ||||
781 | template <typename T> | |||
782 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
783 | Attribute baseResult; | |||
784 | if (failed(parseAttribute(reader, baseResult))) | |||
785 | return failure(); | |||
786 | if ((result = baseResult.dyn_cast<T>())) | |||
787 | return success(); | |||
788 | return reader.emitError("expected attribute of type: ", | |||
789 | llvm::getTypeName<T>(), ", but got: ", baseResult); | |||
790 | } | |||
791 | ||||
792 | private: | |||
793 | /// Resolve the given entry at `index`. | |||
794 | template <typename T> | |||
795 | T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
796 | StringRef entryType); | |||
797 | ||||
798 | /// Parse an entry using the given reader that was encoded using the textual | |||
799 | /// assembly format. | |||
800 | template <typename T> | |||
801 | LogicalResult parseAsmEntry(T &result, EncodingReader &reader, | |||
802 | StringRef entryType); | |||
803 | ||||
804 | /// Parse an entry using the given reader that was encoded using a custom | |||
805 | /// bytecode format. | |||
806 | template <typename T> | |||
807 | LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, | |||
808 | StringRef entryType); | |||
809 | ||||
810 | /// The string section reader used to resolve string references when parsing | |||
811 | /// custom encoded attribute/type entries. | |||
812 | StringSectionReader &stringReader; | |||
813 | ||||
814 | /// The resource section reader used to resolve resource references when | |||
815 | /// parsing custom encoded attribute/type entries. | |||
816 | ResourceSectionReader &resourceReader; | |||
817 | ||||
818 | /// The set of attribute and type entries. | |||
819 | SmallVector<AttrEntry> attributes; | |||
820 | SmallVector<TypeEntry> types; | |||
821 | ||||
822 | /// A location used for error emission. | |||
823 | Location fileLoc; | |||
824 | }; | |||
825 | ||||
826 | class DialectReader : public DialectBytecodeReader { | |||
827 | public: | |||
828 | DialectReader(AttrTypeReader &attrTypeReader, | |||
829 | StringSectionReader &stringReader, | |||
830 | ResourceSectionReader &resourceReader, EncodingReader &reader) | |||
831 | : attrTypeReader(attrTypeReader), stringReader(stringReader), | |||
832 | resourceReader(resourceReader), reader(reader) {} | |||
833 | ||||
834 | InFlightDiagnostic emitError(const Twine &msg) override { | |||
835 | return reader.emitError(msg); | |||
836 | } | |||
837 | ||||
838 | //===--------------------------------------------------------------------===// | |||
839 | // IR | |||
840 | //===--------------------------------------------------------------------===// | |||
841 | ||||
842 | LogicalResult readAttribute(Attribute &result) override { | |||
843 | return attrTypeReader.parseAttribute(reader, result); | |||
844 | } | |||
845 | ||||
846 | LogicalResult readType(Type &result) override { | |||
847 | return attrTypeReader.parseType(reader, result); | |||
848 | } | |||
849 | ||||
850 | FailureOr<AsmDialectResourceHandle> readResourceHandle() override { | |||
851 | AsmDialectResourceHandle handle; | |||
852 | if (failed(resourceReader.parseResourceHandle(reader, handle))) | |||
853 | return failure(); | |||
854 | return handle; | |||
855 | } | |||
856 | ||||
857 | //===--------------------------------------------------------------------===// | |||
858 | // Primitives | |||
859 | //===--------------------------------------------------------------------===// | |||
860 | ||||
861 | LogicalResult readVarInt(uint64_t &result) override { | |||
862 | return reader.parseVarInt(result); | |||
863 | } | |||
864 | ||||
865 | LogicalResult readSignedVarInt(int64_t &result) override { | |||
866 | uint64_t unsignedResult; | |||
867 | if (failed(reader.parseSignedVarInt(unsignedResult))) | |||
868 | return failure(); | |||
869 | result = static_cast<int64_t>(unsignedResult); | |||
870 | return success(); | |||
871 | } | |||
872 | ||||
873 | FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { | |||
874 | // Small values are encoded using a single byte. | |||
875 | if (bitWidth <= 8) { | |||
876 | uint8_t value; | |||
877 | if (failed(reader.parseByte(value))) | |||
878 | return failure(); | |||
879 | return APInt(bitWidth, value); | |||
880 | } | |||
881 | ||||
882 | // Large values up to 64 bits are encoded using a single varint. | |||
883 | if (bitWidth <= 64) { | |||
884 | uint64_t value; | |||
885 | if (failed(reader.parseSignedVarInt(value))) | |||
886 | return failure(); | |||
887 | return APInt(bitWidth, value); | |||
888 | } | |||
889 | ||||
890 | // Otherwise, for really big values we encode the array of active words in | |||
891 | // the value. | |||
892 | uint64_t numActiveWords; | |||
893 | if (failed(reader.parseVarInt(numActiveWords))) | |||
894 | return failure(); | |||
895 | SmallVector<uint64_t, 4> words(numActiveWords); | |||
896 | for (uint64_t i = 0; i < numActiveWords; ++i) | |||
897 | if (failed(reader.parseSignedVarInt(words[i]))) | |||
898 | return failure(); | |||
899 | return APInt(bitWidth, words); | |||
900 | } | |||
901 | ||||
902 | FailureOr<APFloat> | |||
903 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { | |||
904 | FailureOr<APInt> intVal = | |||
905 | readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); | |||
906 | if (failed(intVal)) | |||
907 | return failure(); | |||
908 | return APFloat(semantics, *intVal); | |||
909 | } | |||
910 | ||||
911 | LogicalResult readString(StringRef &result) override { | |||
912 | return stringReader.parseString(reader, result); | |||
913 | } | |||
914 | ||||
915 | LogicalResult readBlob(ArrayRef<char> &result) override { | |||
916 | uint64_t dataSize; | |||
917 | ArrayRef<uint8_t> data; | |||
918 | if (failed(reader.parseVarInt(dataSize)) || | |||
919 | failed(reader.parseBytes(dataSize, data))) | |||
920 | return failure(); | |||
921 | result = llvm::ArrayRef(reinterpret_cast<const char *>(data.data()), | |||
922 | data.size()); | |||
923 | return success(); | |||
924 | } | |||
925 | ||||
926 | private: | |||
927 | AttrTypeReader &attrTypeReader; | |||
928 | StringSectionReader &stringReader; | |||
929 | ResourceSectionReader &resourceReader; | |||
930 | EncodingReader &reader; | |||
931 | }; | |||
932 | } // namespace | |||
933 | ||||
934 | LogicalResult | |||
935 | AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
936 | ArrayRef<uint8_t> sectionData, | |||
937 | ArrayRef<uint8_t> offsetSectionData) { | |||
938 | EncodingReader offsetReader(offsetSectionData, fileLoc); | |||
939 | ||||
940 | // Parse the number of attribute and type entries. | |||
941 | uint64_t numAttributes, numTypes; | |||
942 | if (failed(offsetReader.parseVarInt(numAttributes)) || | |||
943 | failed(offsetReader.parseVarInt(numTypes))) | |||
944 | return failure(); | |||
945 | attributes.resize(numAttributes); | |||
946 | types.resize(numTypes); | |||
947 | ||||
948 | // A functor used to accumulate the offsets for the entries in the given | |||
949 | // range. | |||
950 | uint64_t currentOffset = 0; | |||
951 | auto parseEntries = [&](auto &&range) { | |||
952 | size_t currentIndex = 0, endIndex = range.size(); | |||
953 | ||||
954 | // Parse an individual entry. | |||
955 | auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { | |||
956 | auto &entry = range[currentIndex++]; | |||
957 | ||||
958 | uint64_t entrySize; | |||
959 | if (failed(offsetReader.parseVarIntWithFlag(entrySize, | |||
960 | entry.hasCustomEncoding))) | |||
961 | return failure(); | |||
962 | ||||
963 | // Verify that the offset is actually valid. | |||
964 | if (currentOffset + entrySize > sectionData.size()) { | |||
965 | return offsetReader.emitError( | |||
966 | "Attribute or Type entry offset points past the end of section"); | |||
967 | } | |||
968 | ||||
969 | entry.data = sectionData.slice(currentOffset, entrySize); | |||
970 | entry.dialect = dialect; | |||
971 | currentOffset += entrySize; | |||
972 | return success(); | |||
973 | }; | |||
974 | while (currentIndex != endIndex) | |||
975 | if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) | |||
976 | return failure(); | |||
977 | return success(); | |||
978 | }; | |||
979 | ||||
980 | // Process each of the attributes, and then the types. | |||
981 | if (failed(parseEntries(attributes)) || failed(parseEntries(types))) | |||
982 | return failure(); | |||
983 | ||||
984 | // Ensure that we read everything from the section. | |||
985 | if (!offsetReader.empty()) { | |||
986 | return offsetReader.emitError( | |||
987 | "unexpected trailing data in the Attribute/Type offset section"); | |||
988 | } | |||
989 | return success(); | |||
990 | } | |||
991 | ||||
992 | template <typename T> | |||
993 | T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
994 | StringRef entryType) { | |||
995 | if (index >= entries.size()) { | |||
996 | emitError(fileLoc) << "invalid " << entryType << " index: " << index; | |||
997 | return {}; | |||
998 | } | |||
999 | ||||
1000 | // If the entry has already been resolved, there is nothing left to do. | |||
1001 | Entry<T> &entry = entries[index]; | |||
1002 | if (entry.entry) | |||
1003 | return entry.entry; | |||
1004 | ||||
1005 | // Parse the entry. | |||
1006 | EncodingReader reader(entry.data, fileLoc); | |||
1007 | ||||
1008 | // Parse based on how the entry was encoded. | |||
1009 | if (entry.hasCustomEncoding) { | |||
1010 | if (failed(parseCustomEntry(entry, reader, entryType))) | |||
1011 | return T(); | |||
1012 | } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { | |||
1013 | return T(); | |||
1014 | } | |||
1015 | ||||
1016 | if (!reader.empty()) { | |||
1017 | reader.emitError("unexpected trailing bytes after " + entryType + " entry"); | |||
1018 | return T(); | |||
1019 | } | |||
1020 | return entry.entry; | |||
1021 | } | |||
1022 | ||||
1023 | template <typename T> | |||
1024 | LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, | |||
1025 | StringRef entryType) { | |||
1026 | StringRef asmStr; | |||
1027 | if (failed(reader.parseNullTerminatedString(asmStr))) | |||
1028 | return failure(); | |||
1029 | ||||
1030 | // Invoke the MLIR assembly parser to parse the entry text. | |||
1031 | size_t numRead = 0; | |||
1032 | MLIRContext *context = fileLoc->getContext(); | |||
1033 | if constexpr (std::is_same_v<T, Type>) | |||
1034 | result = | |||
1035 | ::parseType(asmStr, context, &numRead, /*isKnownNullTerminated=*/true); | |||
1036 | else | |||
1037 | result = ::parseAttribute(asmStr, context, Type(), &numRead, | |||
1038 | /*isKnownNullTerminated=*/true); | |||
1039 | if (!result) | |||
1040 | return failure(); | |||
1041 | ||||
1042 | // Ensure there weren't dangling characters after the entry. | |||
1043 | if (numRead != asmStr.size()) { | |||
1044 | return reader.emitError("trailing characters found after ", entryType, | |||
1045 | " assembly format: ", asmStr.drop_front(numRead)); | |||
1046 | } | |||
1047 | return success(); | |||
1048 | } | |||
1049 | ||||
1050 | template <typename T> | |||
1051 | LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, | |||
1052 | EncodingReader &reader, | |||
1053 | StringRef entryType) { | |||
1054 | DialectReader dialectReader(*this, stringReader, resourceReader, reader); | |||
1055 | if (failed(entry.dialect->load(dialectReader, fileLoc.getContext()))) | |||
1056 | return failure(); | |||
1057 | ||||
1058 | // Ensure that the dialect implements the bytecode interface. | |||
1059 | if (!entry.dialect->interface) { | |||
1060 | return reader.emitError("dialect '", entry.dialect->name, | |||
1061 | "' does not implement the bytecode interface"); | |||
1062 | } | |||
1063 | ||||
1064 | // Ask the dialect to parse the entry. If the dialect is versioned, parse | |||
1065 | // using the versioned encoding readers. | |||
1066 | if (entry.dialect->loadedVersion.get()) { | |||
1067 | if constexpr (std::is_same_v<T, Type>) | |||
1068 | entry.entry = entry.dialect->interface->readType( | |||
1069 | dialectReader, *entry.dialect->loadedVersion); | |||
1070 | else | |||
1071 | entry.entry = entry.dialect->interface->readAttribute( | |||
1072 | dialectReader, *entry.dialect->loadedVersion); | |||
1073 | ||||
1074 | } else { | |||
1075 | if constexpr (std::is_same_v<T, Type>) | |||
1076 | entry.entry = entry.dialect->interface->readType(dialectReader); | |||
1077 | else | |||
1078 | entry.entry = entry.dialect->interface->readAttribute(dialectReader); | |||
1079 | } | |||
1080 | return success(!!entry.entry); | |||
1081 | } | |||
1082 | ||||
1083 | //===----------------------------------------------------------------------===// | |||
1084 | // Bytecode Reader | |||
1085 | //===----------------------------------------------------------------------===// | |||
1086 | ||||
1087 | namespace { | |||
1088 | /// This class is used to read a bytecode buffer and translate it into MLIR. | |||
1089 | class BytecodeReader { | |||
1090 | public: | |||
1091 | BytecodeReader(Location fileLoc, const ParserConfig &config, | |||
1092 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) | |||
1093 | : config(config), fileLoc(fileLoc), | |||
1094 | attrTypeReader(stringReader, resourceReader, fileLoc), | |||
1095 | // Use the builtin unrealized conversion cast operation to represent | |||
1096 | // forward references to values that aren't yet defined. | |||
1097 | forwardRefOpState(UnknownLoc::get(config.getContext()), | |||
1098 | "builtin.unrealized_conversion_cast", ValueRange(), | |||
1099 | NoneType::get(config.getContext())), | |||
1100 | bufferOwnerRef(bufferOwnerRef) {} | |||
1101 | ||||
1102 | /// Read the bytecode defined within `buffer` into the given block. | |||
1103 | LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); | |||
1104 | ||||
1105 | private: | |||
1106 | /// Return the context for this config. | |||
1107 | MLIRContext *getContext() const { return config.getContext(); } | |||
1108 | ||||
1109 | /// Parse the bytecode version. | |||
1110 | LogicalResult parseVersion(EncodingReader &reader); | |||
1111 | ||||
1112 | //===--------------------------------------------------------------------===// | |||
1113 | // Dialect Section | |||
1114 | ||||
1115 | LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); | |||
1116 | ||||
1117 | /// Parse an operation name reference using the given reader. | |||
1118 | FailureOr<OperationName> parseOpName(EncodingReader &reader); | |||
1119 | ||||
1120 | //===--------------------------------------------------------------------===// | |||
1121 | // Attribute/Type Section | |||
1122 | ||||
1123 | /// Parse an attribute or type using the given reader. | |||
1124 | template <typename T> | |||
1125 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
1126 | return attrTypeReader.parseAttribute(reader, result); | |||
1127 | } | |||
1128 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
1129 | return attrTypeReader.parseType(reader, result); | |||
1130 | } | |||
1131 | ||||
1132 | //===--------------------------------------------------------------------===// | |||
1133 | // Resource Section | |||
1134 | ||||
1135 | LogicalResult | |||
1136 | parseResourceSection(EncodingReader &reader, | |||
1137 | std::optional<ArrayRef<uint8_t>> resourceData, | |||
1138 | std::optional<ArrayRef<uint8_t>> resourceOffsetData); | |||
1139 | ||||
1140 | //===--------------------------------------------------------------------===// | |||
1141 | // IR Section | |||
1142 | ||||
1143 | /// This struct represents the current read state of a range of regions. This | |||
1144 | /// struct is used to enable iterative parsing of regions. | |||
1145 | struct RegionReadState { | |||
1146 | RegionReadState(Operation *op, bool isIsolatedFromAbove) | |||
1147 | : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} | |||
1148 | RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove) | |||
1149 | : curRegion(regions.begin()), endRegion(regions.end()), | |||
1150 | isIsolatedFromAbove(isIsolatedFromAbove) {} | |||
1151 | ||||
1152 | /// The current regions being read. | |||
1153 | MutableArrayRef<Region>::iterator curRegion, endRegion; | |||
1154 | ||||
1155 | /// The number of values defined immediately within this region. | |||
1156 | unsigned numValues = 0; | |||
1157 | ||||
1158 | /// The current blocks of the region being read. | |||
1159 | SmallVector<Block *> curBlocks; | |||
1160 | Region::iterator curBlock = {}; | |||
1161 | ||||
1162 | /// The number of operations remaining to be read from the current block | |||
1163 | /// being read. | |||
1164 | uint64_t numOpsRemaining = 0; | |||
1165 | ||||
1166 | /// A flag indicating if the regions being read are isolated from above. | |||
1167 | bool isIsolatedFromAbove = false; | |||
1168 | }; | |||
1169 | ||||
1170 | LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); | |||
1171 | LogicalResult parseRegions(EncodingReader &reader, | |||
1172 | std::vector<RegionReadState> ®ionStack, | |||
1173 | RegionReadState &readState); | |||
1174 | FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, | |||
1175 | RegionReadState &readState, | |||
1176 | bool &isIsolatedFromAbove); | |||
1177 | ||||
1178 | LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); | |||
1179 | LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); | |||
1180 | LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); | |||
1181 | ||||
1182 | //===--------------------------------------------------------------------===// | |||
1183 | // Value Processing | |||
1184 | ||||
1185 | /// Parse an operand reference using the given reader. Returns nullptr in the | |||
1186 | /// case of failure. | |||
1187 | Value parseOperand(EncodingReader &reader); | |||
1188 | ||||
1189 | /// Sequentially define the given value range. | |||
1190 | LogicalResult defineValues(EncodingReader &reader, ValueRange values); | |||
1191 | ||||
1192 | /// Create a value to use for a forward reference. | |||
1193 | Value createForwardRef(); | |||
1194 | ||||
1195 | //===--------------------------------------------------------------------===// | |||
1196 | // Fields | |||
1197 | ||||
1198 | /// This class represents a single value scope, in which a value scope is | |||
1199 | /// delimited by isolated from above regions. | |||
1200 | struct ValueScope { | |||
1201 | /// Push a new region state onto this scope, reserving enough values for | |||
1202 | /// those defined within the current region of the provided state. | |||
1203 | void push(RegionReadState &readState) { | |||
1204 | nextValueIDs.push_back(values.size()); | |||
1205 | values.resize(values.size() + readState.numValues); | |||
1206 | } | |||
1207 | ||||
1208 | /// Pop the values defined for the current region within the provided region | |||
1209 | /// state. | |||
1210 | void pop(RegionReadState &readState) { | |||
1211 | values.resize(values.size() - readState.numValues); | |||
1212 | nextValueIDs.pop_back(); | |||
1213 | } | |||
1214 | ||||
1215 | /// The set of values defined in this scope. | |||
1216 | std::vector<Value> values; | |||
1217 | ||||
1218 | /// The ID for the next defined value for each region current being | |||
1219 | /// processed in this scope. | |||
1220 | SmallVector<unsigned, 4> nextValueIDs; | |||
1221 | }; | |||
1222 | ||||
1223 | /// The configuration of the parser. | |||
1224 | const ParserConfig &config; | |||
1225 | ||||
1226 | /// A location to use when emitting errors. | |||
1227 | Location fileLoc; | |||
1228 | ||||
1229 | /// The reader used to process attribute and types within the bytecode. | |||
1230 | AttrTypeReader attrTypeReader; | |||
1231 | ||||
1232 | /// The version of the bytecode being read. | |||
1233 | uint64_t version = 0; | |||
1234 | ||||
1235 | /// The producer of the bytecode being read. | |||
1236 | StringRef producer; | |||
1237 | ||||
1238 | /// The table of IR units referenced within the bytecode file. | |||
1239 | SmallVector<BytecodeDialect> dialects; | |||
1240 | SmallVector<BytecodeOperationName> opNames; | |||
1241 | ||||
1242 | /// The reader used to process resources within the bytecode. | |||
1243 | ResourceSectionReader resourceReader; | |||
1244 | ||||
1245 | /// The table of strings referenced within the bytecode file. | |||
1246 | StringSectionReader stringReader; | |||
1247 | ||||
1248 | /// The current set of available IR value scopes. | |||
1249 | std::vector<ValueScope> valueScopes; | |||
1250 | /// A block containing the set of operations defined to create forward | |||
1251 | /// references. | |||
1252 | Block forwardRefOps; | |||
1253 | /// A block containing previously created, and no longer used, forward | |||
1254 | /// reference operations. | |||
1255 | Block openForwardRefOps; | |||
1256 | /// An operation state used when instantiating forward references. | |||
1257 | OperationState forwardRefOpState; | |||
1258 | ||||
1259 | /// The optional owning source manager, which when present may be used to | |||
1260 | /// extend the lifetime of the input buffer. | |||
1261 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef; | |||
1262 | }; | |||
1263 | } // namespace | |||
1264 | ||||
1265 | LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { | |||
1266 | EncodingReader reader(buffer.getBuffer(), fileLoc); | |||
1267 | ||||
1268 | // Skip over the bytecode header, this should have already been checked. | |||
1269 | if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) | |||
1270 | return failure(); | |||
1271 | // Parse the bytecode version and producer. | |||
1272 | if (failed(parseVersion(reader)) || | |||
1273 | failed(reader.parseNullTerminatedString(producer))) | |||
1274 | return failure(); | |||
1275 | ||||
1276 | // Add a diagnostic handler that attaches a note that includes the original | |||
1277 | // producer of the bytecode. | |||
1278 | ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { | |||
1279 | diag.attachNote() << "in bytecode version " << version | |||
1280 | << " produced by: " << producer; | |||
1281 | return failure(); | |||
1282 | }); | |||
1283 | ||||
1284 | // Parse the raw data for each of the top-level sections of the bytecode. | |||
1285 | std::optional<ArrayRef<uint8_t>> | |||
1286 | sectionDatas[bytecode::Section::kNumSections]; | |||
1287 | while (!reader.empty()) { | |||
1288 | // Read the next section from the bytecode. | |||
1289 | bytecode::Section::ID sectionID; | |||
1290 | ArrayRef<uint8_t> sectionData; | |||
1291 | if (failed(reader.parseSection(sectionID, sectionData))) | |||
1292 | return failure(); | |||
1293 | ||||
1294 | // Check for duplicate sections, we only expect one instance of each. | |||
1295 | if (sectionDatas[sectionID]) { | |||
1296 | return reader.emitError("duplicate top-level section: ", | |||
1297 | toString(sectionID)); | |||
1298 | } | |||
1299 | sectionDatas[sectionID] = sectionData; | |||
1300 | } | |||
1301 | // Check that all of the required sections were found. | |||
1302 | for (int i = 0; i < bytecode::Section::kNumSections; ++i) { | |||
1303 | bytecode::Section::ID sectionID = static_cast<bytecode::Section::ID>(i); | |||
1304 | if (!sectionDatas[i] && !isSectionOptional(sectionID)) { | |||
1305 | return reader.emitError("missing data for top-level section: ", | |||
1306 | toString(sectionID)); | |||
1307 | } | |||
1308 | } | |||
1309 | ||||
1310 | // Process the string section first. | |||
1311 | if (failed(stringReader.initialize( | |||
1312 | fileLoc, *sectionDatas[bytecode::Section::kString]))) | |||
1313 | return failure(); | |||
1314 | ||||
1315 | // Process the dialect section. | |||
1316 | if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) | |||
1317 | return failure(); | |||
1318 | ||||
1319 | // Process the resource section if present. | |||
1320 | if (failed(parseResourceSection( | |||
1321 | reader, sectionDatas[bytecode::Section::kResource], | |||
1322 | sectionDatas[bytecode::Section::kResourceOffset]))) | |||
1323 | return failure(); | |||
1324 | ||||
1325 | // Process the attribute and type section. | |||
1326 | if (failed(attrTypeReader.initialize( | |||
1327 | dialects, *sectionDatas[bytecode::Section::kAttrType], | |||
1328 | *sectionDatas[bytecode::Section::kAttrTypeOffset]))) | |||
1329 | return failure(); | |||
1330 | ||||
1331 | // Finally, process the IR section. | |||
1332 | return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); | |||
1333 | } | |||
1334 | ||||
1335 | LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { | |||
1336 | if (failed(reader.parseVarInt(version))) | |||
1337 | return failure(); | |||
1338 | ||||
1339 | // Validate the bytecode version. | |||
1340 | uint64_t currentVersion = bytecode::kVersion; | |||
1341 | uint64_t minSupportedVersion = bytecode::kMinSupportedVersion; | |||
1342 | if (version < minSupportedVersion) { | |||
1343 | return reader.emitError("bytecode version ", version, | |||
1344 | " is older than the current version of ", | |||
1345 | currentVersion, ", and upgrade is not supported"); | |||
1346 | } | |||
1347 | if (version > currentVersion) { | |||
1348 | return reader.emitError("bytecode version ", version, | |||
1349 | " is newer than the current version ", | |||
1350 | currentVersion); | |||
1351 | } | |||
1352 | return success(); | |||
1353 | } | |||
1354 | ||||
1355 | //===----------------------------------------------------------------------===// | |||
1356 | // Dialect Section | |||
1357 | ||||
1358 | LogicalResult BytecodeDialect::load(DialectReader &reader, MLIRContext *ctx) { | |||
1359 | if (dialect) | |||
1360 | return success(); | |||
1361 | Dialect *loadedDialect = ctx->getOrLoadDialect(name); | |||
1362 | if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { | |||
1363 | return reader.emitError("dialect '") | |||
1364 | << name | |||
1365 | << "' is unknown. If this is intended, please call " | |||
1366 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
1367 | "-allow-unregistered-dialect with the MLIR tool used."; | |||
1368 | } | |||
1369 | dialect = loadedDialect; | |||
1370 | ||||
1371 | // If the dialect was actually loaded, check to see if it has a bytecode | |||
1372 | // interface. | |||
1373 | if (loadedDialect) | |||
1374 | interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); | |||
1375 | if (!versionBuffer.empty()) { | |||
1376 | if (!interface) | |||
1377 | return reader.emitError("dialect '") | |||
1378 | << name | |||
1379 | << "' does not implement the bytecode interface, " | |||
1380 | "but found a version entry"; | |||
1381 | loadedVersion = interface->readVersion(reader); | |||
1382 | if (!loadedVersion) | |||
1383 | return failure(); | |||
1384 | } | |||
1385 | return success(); | |||
1386 | } | |||
1387 | ||||
1388 | LogicalResult | |||
1389 | BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { | |||
1390 | EncodingReader sectionReader(sectionData, fileLoc); | |||
1391 | ||||
1392 | // Parse the number of dialects in the section. | |||
1393 | uint64_t numDialects; | |||
1394 | if (failed(sectionReader.parseVarInt(numDialects))) | |||
1395 | return failure(); | |||
1396 | dialects.resize(numDialects); | |||
1397 | ||||
1398 | // Parse each of the dialects. | |||
1399 | for (uint64_t i = 0; i < numDialects; ++i) { | |||
1400 | /// Before version 1, there wasn't any versioning available for dialects, | |||
1401 | /// and the entryIdx represent the string itself. | |||
1402 | if (version == 0) { | |||
1403 | if (failed(stringReader.parseString(sectionReader, dialects[i].name))) | |||
1404 | return failure(); | |||
1405 | continue; | |||
1406 | } | |||
1407 | // Parse ID representing dialect and version. | |||
1408 | uint64_t dialectNameIdx; | |||
1409 | bool versionAvailable; | |||
1410 | if (failed(sectionReader.parseVarIntWithFlag(dialectNameIdx, | |||
1411 | versionAvailable))) | |||
1412 | return failure(); | |||
1413 | if (failed(stringReader.parseStringAtIndex(sectionReader, dialectNameIdx, | |||
1414 | dialects[i].name))) | |||
1415 | return failure(); | |||
1416 | if (versionAvailable) { | |||
1417 | bytecode::Section::ID sectionID; | |||
1418 | if (failed( | |||
1419 | sectionReader.parseSection(sectionID, dialects[i].versionBuffer))) | |||
1420 | return failure(); | |||
1421 | if (sectionID != bytecode::Section::kDialectVersions) { | |||
1422 | emitError(fileLoc, "expected dialect version section"); | |||
1423 | return failure(); | |||
1424 | } | |||
1425 | } | |||
1426 | } | |||
1427 | ||||
1428 | // Parse the operation names, which are grouped by dialect. | |||
1429 | auto parseOpName = [&](BytecodeDialect *dialect) { | |||
1430 | StringRef opName; | |||
1431 | if (failed(stringReader.parseString(sectionReader, opName))) | |||
1432 | return failure(); | |||
1433 | opNames.emplace_back(dialect, opName); | |||
1434 | return success(); | |||
1435 | }; | |||
1436 | while (!sectionReader.empty()) | |||
1437 | if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) | |||
1438 | return failure(); | |||
1439 | return success(); | |||
1440 | } | |||
1441 | ||||
1442 | FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { | |||
1443 | BytecodeOperationName *opName = nullptr; | |||
1444 | if (failed(parseEntry(reader, opNames, opName, "operation name"))) | |||
1445 | return failure(); | |||
1446 | ||||
1447 | // Check to see if this operation name has already been resolved. If we | |||
1448 | // haven't, load the dialect and build the operation name. | |||
1449 | if (!opName->opName) { | |||
1450 | // Load the dialect and its version. | |||
1451 | EncodingReader versionReader(opName->dialect->versionBuffer, fileLoc); | |||
1452 | DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, | |||
1453 | versionReader); | |||
1454 | if (failed(opName->dialect->load(dialectReader, getContext()))) | |||
1455 | return failure(); | |||
1456 | opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), | |||
1457 | getContext()); | |||
1458 | } | |||
1459 | return *opName->opName; | |||
1460 | } | |||
1461 | ||||
1462 | //===----------------------------------------------------------------------===// | |||
1463 | // Resource Section | |||
1464 | ||||
1465 | LogicalResult BytecodeReader::parseResourceSection( | |||
1466 | EncodingReader &reader, std::optional<ArrayRef<uint8_t>> resourceData, | |||
1467 | std::optional<ArrayRef<uint8_t>> resourceOffsetData) { | |||
1468 | // Ensure both sections are either present or not. | |||
1469 | if (resourceData.has_value() != resourceOffsetData.has_value()) { | |||
1470 | if (resourceOffsetData) | |||
1471 | return emitError(fileLoc, "unexpected resource offset section when " | |||
1472 | "resource section is not present"); | |||
1473 | return emitError( | |||
1474 | fileLoc, | |||
1475 | "expected resource offset section when resource section is present"); | |||
1476 | } | |||
1477 | ||||
1478 | // If the resource sections are absent, there is nothing to do. | |||
1479 | if (!resourceData) | |||
1480 | return success(); | |||
1481 | ||||
1482 | // Initialize the resource reader with the resource sections. | |||
1483 | DialectReader dialectReader(attrTypeReader, stringReader, resourceReader, | |||
1484 | reader); | |||
1485 | return resourceReader.initialize(fileLoc, config, dialects, stringReader, | |||
1486 | *resourceData, *resourceOffsetData, | |||
1487 | dialectReader, bufferOwnerRef); | |||
1488 | } | |||
1489 | ||||
1490 | //===----------------------------------------------------------------------===// | |||
1491 | // IR Section | |||
1492 | ||||
1493 | LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, | |||
1494 | Block *block) { | |||
1495 | EncodingReader reader(sectionData, fileLoc); | |||
1496 | ||||
1497 | // A stack of operation regions currently being read from the bytecode. | |||
1498 | std::vector<RegionReadState> regionStack; | |||
1499 | ||||
1500 | // Parse the top-level block using a temporary module operation. | |||
1501 | OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); | |||
1502 | regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); | |||
1503 | regionStack.back().curBlocks.push_back(moduleOp->getBody()); | |||
1504 | regionStack.back().curBlock = regionStack.back().curRegion->begin(); | |||
1505 | if (failed(parseBlock(reader, regionStack.back()))) | |||
1506 | return failure(); | |||
1507 | valueScopes.emplace_back(); | |||
1508 | valueScopes.back().push(regionStack.back()); | |||
1509 | ||||
1510 | // Iteratively parse regions until everything has been resolved. | |||
1511 | while (!regionStack.empty()) | |||
1512 | if (failed(parseRegions(reader, regionStack, regionStack.back()))) | |||
1513 | return failure(); | |||
1514 | if (!forwardRefOps.empty()) { | |||
1515 | return reader.emitError( | |||
1516 | "not all forward unresolved forward operand references"); | |||
1517 | } | |||
1518 | ||||
1519 | // Resolve dialect version. | |||
1520 | for (const BytecodeDialect &byteCodeDialect : dialects) { | |||
1521 | // Parsing is complete, give an opportunity to each dialect to visit the | |||
1522 | // IR and perform upgrades. | |||
1523 | if (!byteCodeDialect.loadedVersion) | |||
1524 | continue; | |||
1525 | if (byteCodeDialect.interface && | |||
1526 | failed(byteCodeDialect.interface->upgradeFromVersion( | |||
1527 | *moduleOp, *byteCodeDialect.loadedVersion))) | |||
1528 | return failure(); | |||
1529 | } | |||
1530 | ||||
1531 | // Verify that the parsed operations are valid. | |||
1532 | if (config.shouldVerifyAfterParse() && failed(verify(*moduleOp))) | |||
1533 | return failure(); | |||
1534 | ||||
1535 | // Splice the parsed operations over to the provided top-level block. | |||
1536 | auto &parsedOps = moduleOp->getBody()->getOperations(); | |||
1537 | auto &destOps = block->getOperations(); | |||
1538 | destOps.splice(destOps.end(), parsedOps, parsedOps.begin(), parsedOps.end()); | |||
1539 | return success(); | |||
1540 | } | |||
1541 | ||||
1542 | LogicalResult | |||
1543 | BytecodeReader::parseRegions(EncodingReader &reader, | |||
1544 | std::vector<RegionReadState> ®ionStack, | |||
1545 | RegionReadState &readState) { | |||
1546 | // Read the regions of this operation. | |||
1547 | for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { | |||
1548 | // If the current block hasn't been setup yet, parse the header for this | |||
1549 | // region. | |||
1550 | if (readState.curBlock == Region::iterator()) { | |||
1551 | if (failed(parseRegion(reader, readState))) | |||
1552 | return failure(); | |||
1553 | ||||
1554 | // If the region is empty, there is nothing to more to do. | |||
1555 | if (readState.curRegion->empty()) | |||
1556 | continue; | |||
1557 | } | |||
1558 | ||||
1559 | // Parse the blocks within the region. | |||
1560 | do { | |||
1561 | while (readState.numOpsRemaining--) { | |||
1562 | // Read in the next operation. We don't read its regions directly, we | |||
1563 | // handle those afterwards as necessary. | |||
1564 | bool isIsolatedFromAbove = false; | |||
1565 | FailureOr<Operation *> op = | |||
1566 | parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); | |||
1567 | if (failed(op)) | |||
1568 | return failure(); | |||
1569 | ||||
1570 | // If the op has regions, add it to the stack for processing. | |||
1571 | if ((*op)->getNumRegions()) { | |||
1572 | regionStack.emplace_back(*op, isIsolatedFromAbove); | |||
1573 | ||||
1574 | // If the op is isolated from above, push a new value scope. | |||
1575 | if (isIsolatedFromAbove) | |||
1576 | valueScopes.emplace_back(); | |||
1577 | return success(); | |||
1578 | } | |||
1579 | } | |||
1580 | ||||
1581 | // Move to the next block of the region. | |||
1582 | if (++readState.curBlock == readState.curRegion->end()) | |||
1583 | break; | |||
1584 | if (failed(parseBlock(reader, readState))) | |||
1585 | return failure(); | |||
1586 | } while (true); | |||
1587 | ||||
1588 | // Reset the current block and any values reserved for this region. | |||
1589 | readState.curBlock = {}; | |||
1590 | valueScopes.back().pop(readState); | |||
1591 | } | |||
1592 | ||||
1593 | // When the regions have been fully parsed, pop them off of the read stack. If | |||
1594 | // the regions were isolated from above, we also pop the last value scope. | |||
1595 | if (readState.isIsolatedFromAbove) | |||
1596 | valueScopes.pop_back(); | |||
1597 | regionStack.pop_back(); | |||
1598 | return success(); | |||
1599 | } | |||
1600 | ||||
1601 | FailureOr<Operation *> | |||
1602 | BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, | |||
1603 | RegionReadState &readState, | |||
1604 | bool &isIsolatedFromAbove) { | |||
1605 | // Parse the name of the operation. | |||
1606 | FailureOr<OperationName> opName = parseOpName(reader); | |||
1607 | if (failed(opName)) | |||
1608 | return failure(); | |||
1609 | ||||
1610 | // Parse the operation mask, which indicates which components of the operation | |||
1611 | // are present. | |||
1612 | uint8_t opMask; | |||
1613 | if (failed(reader.parseByte(opMask))) | |||
1614 | return failure(); | |||
1615 | ||||
1616 | /// Parse the location. | |||
1617 | LocationAttr opLoc; | |||
1618 | if (failed(parseAttribute(reader, opLoc))) | |||
1619 | return failure(); | |||
1620 | ||||
1621 | // With the location and name resolved, we can start building the operation | |||
1622 | // state. | |||
1623 | OperationState opState(opLoc, *opName); | |||
1624 | ||||
1625 | // Parse the attributes of the operation. | |||
1626 | if (opMask & bytecode::OpEncodingMask::kHasAttrs) { | |||
1627 | DictionaryAttr dictAttr; | |||
1628 | if (failed(parseAttribute(reader, dictAttr))) | |||
1629 | return failure(); | |||
1630 | opState.attributes = dictAttr; | |||
1631 | } | |||
1632 | ||||
1633 | /// Parse the results of the operation. | |||
1634 | if (opMask & bytecode::OpEncodingMask::kHasResults) { | |||
1635 | uint64_t numResults; | |||
1636 | if (failed(reader.parseVarInt(numResults))) | |||
1637 | return failure(); | |||
1638 | opState.types.resize(numResults); | |||
1639 | for (int i = 0, e = numResults; i < e; ++i) | |||
1640 | if (failed(parseType(reader, opState.types[i]))) | |||
1641 | return failure(); | |||
1642 | } | |||
1643 | ||||
1644 | /// Parse the operands of the operation. | |||
1645 | if (opMask & bytecode::OpEncodingMask::kHasOperands) { | |||
1646 | uint64_t numOperands; | |||
1647 | if (failed(reader.parseVarInt(numOperands))) | |||
1648 | return failure(); | |||
1649 | opState.operands.resize(numOperands); | |||
1650 | for (int i = 0, e = numOperands; i < e; ++i) | |||
1651 | if (!(opState.operands[i] = parseOperand(reader))) | |||
1652 | return failure(); | |||
1653 | } | |||
1654 | ||||
1655 | /// Parse the successors of the operation. | |||
1656 | if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { | |||
1657 | uint64_t numSuccs; | |||
1658 | if (failed(reader.parseVarInt(numSuccs))) | |||
1659 | return failure(); | |||
1660 | opState.successors.resize(numSuccs); | |||
1661 | for (int i = 0, e = numSuccs; i < e; ++i) { | |||
1662 | if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], | |||
1663 | "successor"))) | |||
1664 | return failure(); | |||
1665 | } | |||
1666 | } | |||
1667 | ||||
1668 | /// Parse the regions of the operation. | |||
1669 | if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { | |||
1670 | uint64_t numRegions; | |||
1671 | if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) | |||
1672 | return failure(); | |||
1673 | ||||
1674 | opState.regions.reserve(numRegions); | |||
1675 | for (int i = 0, e = numRegions; i < e; ++i) | |||
1676 | opState.regions.push_back(std::make_unique<Region>()); | |||
1677 | } | |||
1678 | ||||
1679 | // Create the operation at the back of the current block. | |||
1680 | Operation *op = Operation::create(opState); | |||
1681 | readState.curBlock->push_back(op); | |||
1682 | ||||
1683 | // If the operation had results, update the value references. | |||
1684 | if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) | |||
1685 | return failure(); | |||
1686 | ||||
1687 | return op; | |||
1688 | } | |||
1689 | ||||
1690 | LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, | |||
1691 | RegionReadState &readState) { | |||
1692 | // Parse the number of blocks in the region. | |||
1693 | uint64_t numBlocks; | |||
1694 | if (failed(reader.parseVarInt(numBlocks))) | |||
1695 | return failure(); | |||
1696 | ||||
1697 | // If the region is empty, there is nothing else to do. | |||
1698 | if (numBlocks == 0) | |||
1699 | return success(); | |||
1700 | ||||
1701 | // Parse the number of values defined in this region. | |||
1702 | uint64_t numValues; | |||
1703 | if (failed(reader.parseVarInt(numValues))) | |||
1704 | return failure(); | |||
1705 | readState.numValues = numValues; | |||
1706 | ||||
1707 | // Create the blocks within this region. We do this before processing so that | |||
1708 | // we can rely on the blocks existing when creating operations. | |||
1709 | readState.curBlocks.clear(); | |||
1710 | readState.curBlocks.reserve(numBlocks); | |||
1711 | for (uint64_t i = 0; i < numBlocks; ++i) { | |||
1712 | readState.curBlocks.push_back(new Block()); | |||
1713 | readState.curRegion->push_back(readState.curBlocks.back()); | |||
1714 | } | |||
1715 | ||||
1716 | // Prepare the current value scope for this region. | |||
1717 | valueScopes.back().push(readState); | |||
1718 | ||||
1719 | // Parse the entry block of the region. | |||
1720 | readState.curBlock = readState.curRegion->begin(); | |||
1721 | return parseBlock(reader, readState); | |||
1722 | } | |||
1723 | ||||
1724 | LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, | |||
1725 | RegionReadState &readState) { | |||
1726 | bool hasArgs; | |||
1727 | if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) | |||
1728 | return failure(); | |||
1729 | ||||
1730 | // Parse the arguments of the block. | |||
1731 | if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) | |||
1732 | return failure(); | |||
1733 | ||||
1734 | // We don't parse the operations of the block here, that's done elsewhere. | |||
1735 | return success(); | |||
1736 | } | |||
1737 | ||||
1738 | LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, | |||
1739 | Block *block) { | |||
1740 | // Parse the value ID for the first argument, and the number of arguments. | |||
1741 | uint64_t numArgs; | |||
1742 | if (failed(reader.parseVarInt(numArgs))) | |||
1743 | return failure(); | |||
1744 | ||||
1745 | SmallVector<Type> argTypes; | |||
1746 | SmallVector<Location> argLocs; | |||
1747 | argTypes.reserve(numArgs); | |||
1748 | argLocs.reserve(numArgs); | |||
1749 | ||||
1750 | while (numArgs--) { | |||
1751 | Type argType; | |||
1752 | LocationAttr argLoc; | |||
1753 | if (failed(parseType(reader, argType)) || | |||
1754 | failed(parseAttribute(reader, argLoc))) | |||
1755 | return failure(); | |||
1756 | ||||
1757 | argTypes.push_back(argType); | |||
1758 | argLocs.push_back(argLoc); | |||
1759 | } | |||
1760 | block->addArguments(argTypes, argLocs); | |||
1761 | return defineValues(reader, block->getArguments()); | |||
1762 | } | |||
1763 | ||||
1764 | //===----------------------------------------------------------------------===// | |||
1765 | // Value Processing | |||
1766 | ||||
1767 | Value BytecodeReader::parseOperand(EncodingReader &reader) { | |||
1768 | std::vector<Value> &values = valueScopes.back().values; | |||
1769 | Value *value = nullptr; | |||
1770 | if (failed(parseEntry(reader, values, value, "value"))) | |||
1771 | return Value(); | |||
1772 | ||||
1773 | // Create a new forward reference if necessary. | |||
1774 | if (!*value) | |||
1775 | *value = createForwardRef(); | |||
1776 | return *value; | |||
1777 | } | |||
1778 | ||||
1779 | LogicalResult BytecodeReader::defineValues(EncodingReader &reader, | |||
1780 | ValueRange newValues) { | |||
1781 | ValueScope &valueScope = valueScopes.back(); | |||
1782 | std::vector<Value> &values = valueScope.values; | |||
1783 | ||||
1784 | unsigned &valueID = valueScope.nextValueIDs.back(); | |||
1785 | unsigned valueIDEnd = valueID + newValues.size(); | |||
1786 | if (valueIDEnd > values.size()) { | |||
1787 | return reader.emitError( | |||
1788 | "value index range was outside of the expected range for " | |||
1789 | "the parent region, got [", | |||
1790 | valueID, ", ", valueIDEnd, "), but the maximum index was ", | |||
1791 | values.size() - 1); | |||
1792 | } | |||
1793 | ||||
1794 | // Assign the values and update any forward references. | |||
1795 | for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { | |||
1796 | Value newValue = newValues[i]; | |||
1797 | ||||
1798 | // Check to see if a definition for this value already exists. | |||
1799 | if (Value oldValue = std::exchange(values[valueID], newValue)) { | |||
1800 | Operation *forwardRefOp = oldValue.getDefiningOp(); | |||
1801 | ||||
1802 | // Assert that this is a forward reference operation. Given how we compute | |||
1803 | // definition ids (incrementally as we parse), it shouldn't be possible | |||
1804 | // for the value to be defined any other way. | |||
1805 | assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1806, __extension__ __PRETTY_FUNCTION__)) | |||
1806 | "value index was already defined?")(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1806, __extension__ __PRETTY_FUNCTION__)); | |||
1807 | ||||
1808 | oldValue.replaceAllUsesWith(newValue); | |||
1809 | forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); | |||
1810 | } | |||
1811 | } | |||
1812 | return success(); | |||
1813 | } | |||
1814 | ||||
1815 | Value BytecodeReader::createForwardRef() { | |||
1816 | // Check for an avaliable existing operation to use. Otherwise, create a new | |||
1817 | // fake operation to use for the reference. | |||
1818 | if (!openForwardRefOps.empty()) { | |||
1819 | Operation *op = &openForwardRefOps.back(); | |||
1820 | op->moveBefore(&forwardRefOps, forwardRefOps.end()); | |||
1821 | } else { | |||
1822 | forwardRefOps.push_back(Operation::create(forwardRefOpState)); | |||
1823 | } | |||
1824 | return forwardRefOps.back().getResult(0); | |||
1825 | } | |||
1826 | ||||
1827 | //===----------------------------------------------------------------------===// | |||
1828 | // Entry Points | |||
1829 | //===----------------------------------------------------------------------===// | |||
1830 | ||||
1831 | bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { | |||
1832 | return buffer.getBuffer().startswith("ML\xefR"); | |||
1833 | } | |||
1834 | ||||
1835 | /// Read the bytecode from the provided memory buffer reference. | |||
1836 | /// `bufferOwnerRef` if provided is the owning source manager for the buffer, | |||
1837 | /// and may be used to extend the lifetime of the buffer. | |||
1838 | static LogicalResult | |||
1839 | readBytecodeFileImpl(llvm::MemoryBufferRef buffer, Block *block, | |||
1840 | const ParserConfig &config, | |||
1841 | const std::shared_ptr<llvm::SourceMgr> &bufferOwnerRef) { | |||
1842 | Location sourceFileLoc = | |||
1843 | FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), | |||
1844 | /*line=*/0, /*column=*/0); | |||
1845 | if (!isBytecode(buffer)) { | |||
1846 | return emitError(sourceFileLoc, | |||
1847 | "input buffer is not an MLIR bytecode file"); | |||
1848 | } | |||
1849 | ||||
1850 | BytecodeReader reader(sourceFileLoc, config, bufferOwnerRef); | |||
1851 | return reader.read(buffer, block); | |||
1852 | } | |||
1853 | ||||
1854 | LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, | |||
1855 | const ParserConfig &config) { | |||
1856 | return readBytecodeFileImpl(buffer, block, config, /*bufferOwnerRef=*/{}); | |||
1857 | } | |||
1858 | LogicalResult | |||
1859 | mlir::readBytecodeFile(const std::shared_ptr<llvm::SourceMgr> &sourceMgr, | |||
1860 | Block *block, const ParserConfig &config) { | |||
1861 | return readBytecodeFileImpl( | |||
| ||||
1862 | *sourceMgr->getMemoryBuffer(sourceMgr->getMainFileID()), block, config, | |||
1863 | sourceMgr); | |||
1864 | } |