File: | build/llvm-toolchain-snapshot-16~++20220904122748+c444af1c20b3/mlir/lib/Bytecode/Reader/BytecodeReader.cpp |
Warning: | line 183, column 19 The left operand of '>=' is a garbage value |
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- BytecodeReader.cpp - MLIR Bytecode Reader --------------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | ||||
9 | // TODO: Support for big-endian architectures. | |||
10 | // TODO: Properly preserve use lists of values. | |||
11 | ||||
12 | #include "mlir/Bytecode/BytecodeReader.h" | |||
13 | #include "../Encoding.h" | |||
14 | #include "mlir/AsmParser/AsmParser.h" | |||
15 | #include "mlir/Bytecode/BytecodeImplementation.h" | |||
16 | #include "mlir/IR/BuiltinDialect.h" | |||
17 | #include "mlir/IR/BuiltinOps.h" | |||
18 | #include "mlir/IR/OpImplementation.h" | |||
19 | #include "mlir/IR/Verifier.h" | |||
20 | #include "llvm/ADT/MapVector.h" | |||
21 | #include "llvm/ADT/ScopeExit.h" | |||
22 | #include "llvm/ADT/SmallString.h" | |||
23 | #include "llvm/Support/MemoryBufferRef.h" | |||
24 | #include "llvm/Support/SaveAndRestore.h" | |||
25 | ||||
// Debug-output tag for this file (LLVM DEBUG_TYPE convention).
#define DEBUG_TYPE "mlir-bytecode-reader"
27 | ||||
28 | using namespace mlir; | |||
29 | ||||
30 | /// Stringify the given section ID. | |||
31 | static std::string toString(bytecode::Section::ID sectionID) { | |||
32 | switch (sectionID) { | |||
33 | case bytecode::Section::kString: | |||
34 | return "String (0)"; | |||
35 | case bytecode::Section::kDialect: | |||
36 | return "Dialect (1)"; | |||
37 | case bytecode::Section::kAttrType: | |||
38 | return "AttrType (2)"; | |||
39 | case bytecode::Section::kAttrTypeOffset: | |||
40 | return "AttrTypeOffset (3)"; | |||
41 | case bytecode::Section::kIR: | |||
42 | return "IR (4)"; | |||
43 | default: | |||
44 | return ("Unknown (" + Twine(static_cast<unsigned>(sectionID)) + ")").str(); | |||
45 | } | |||
46 | } | |||
47 | ||||
48 | //===----------------------------------------------------------------------===// | |||
49 | // EncodingReader | |||
50 | //===----------------------------------------------------------------------===// | |||
51 | ||||
52 | namespace { | |||
53 | class EncodingReader { | |||
54 | public: | |||
55 | explicit EncodingReader(ArrayRef<uint8_t> contents, Location fileLoc) | |||
56 | : dataIt(contents.data()), dataEnd(contents.end()), fileLoc(fileLoc) {} | |||
57 | explicit EncodingReader(StringRef contents, Location fileLoc) | |||
58 | : EncodingReader({reinterpret_cast<const uint8_t *>(contents.data()), | |||
59 | contents.size()}, | |||
60 | fileLoc) {} | |||
61 | ||||
62 | /// Returns true if the entire section has been read. | |||
63 | bool empty() const { return dataIt == dataEnd; } | |||
64 | ||||
65 | /// Returns the remaining size of the bytecode. | |||
66 | size_t size() const { return dataEnd - dataIt; } | |||
67 | ||||
68 | /// Emit an error using the given arguments. | |||
69 | template <typename... Args> | |||
70 | InFlightDiagnostic emitError(Args &&...args) const { | |||
71 | return ::emitError(fileLoc).append(std::forward<Args>(args)...); | |||
72 | } | |||
73 | ||||
74 | /// Parse a single byte from the stream. | |||
75 | template <typename T> | |||
76 | LogicalResult parseByte(T &value) { | |||
77 | if (empty()) | |||
78 | return emitError("attempting to parse a byte at the end of the bytecode"); | |||
79 | value = static_cast<T>(*dataIt++); | |||
80 | return success(); | |||
81 | } | |||
82 | /// Parse a range of bytes of 'length' into the given result. | |||
83 | LogicalResult parseBytes(size_t length, ArrayRef<uint8_t> &result) { | |||
84 | if (length > size()) { | |||
85 | return emitError("attempting to parse ", length, " bytes when only ", | |||
86 | size(), " remain"); | |||
87 | } | |||
88 | result = {dataIt, length}; | |||
89 | dataIt += length; | |||
90 | return success(); | |||
91 | } | |||
92 | /// Parse a range of bytes of 'length' into the given result, which can be | |||
93 | /// assumed to be large enough to hold `length`. | |||
94 | LogicalResult parseBytes(size_t length, uint8_t *result) { | |||
95 | if (length > size()) { | |||
96 | return emitError("attempting to parse ", length, " bytes when only ", | |||
97 | size(), " remain"); | |||
98 | } | |||
99 | memcpy(result, dataIt, length); | |||
100 | dataIt += length; | |||
101 | return success(); | |||
102 | } | |||
103 | ||||
104 | /// Parse a variable length encoded integer from the byte stream. The first | |||
105 | /// encoded byte contains a prefix in the low bits indicating the encoded | |||
106 | /// length of the value. This length prefix is a bit sequence of '0's followed | |||
107 | /// by a '1'. The number of '0' bits indicate the number of _additional_ bytes | |||
108 | /// (not including the prefix byte). All remaining bits in the first byte, | |||
109 | /// along with all of the bits in additional bytes, provide the value of the | |||
110 | /// integer encoded in little-endian order. | |||
111 | LogicalResult parseVarInt(uint64_t &result) { | |||
112 | // Parse the first byte of the encoding, which contains the length prefix. | |||
113 | if (failed(parseByte(result))) | |||
114 | return failure(); | |||
115 | ||||
116 | // Handle the overwhelmingly common case where the value is stored in a | |||
117 | // single byte. In this case, the first bit is the `1` marker bit. | |||
118 | if (LLVM_LIKELY(result & 1)__builtin_expect((bool)(result & 1), true)) { | |||
119 | result >>= 1; | |||
120 | return success(); | |||
121 | } | |||
122 | ||||
123 | // Handle the overwhelming uncommon case where the value required all 8 | |||
124 | // bytes (i.e. a really really big number). In this case, the marker byte is | |||
125 | // all zeros: `00000000`. | |||
126 | if (LLVM_UNLIKELY(result == 0)__builtin_expect((bool)(result == 0), false)) | |||
127 | return parseBytes(sizeof(result), reinterpret_cast<uint8_t *>(&result)); | |||
128 | return parseMultiByteVarInt(result); | |||
129 | } | |||
130 | ||||
131 | /// Parse a signed variable length encoded integer from the byte stream. A | |||
132 | /// signed varint is encoded as a normal varint with zigzag encoding applied, | |||
133 | /// i.e. the low bit of the value is used to indicate the sign. | |||
134 | LogicalResult parseSignedVarInt(uint64_t &result) { | |||
135 | if (failed(parseVarInt(result))) | |||
136 | return failure(); | |||
137 | // Essentially (but using unsigned): (x >> 1) ^ -(x & 1) | |||
138 | result = (result >> 1) ^ (~(result & 1) + 1); | |||
139 | return success(); | |||
140 | } | |||
141 | ||||
142 | /// Parse a variable length encoded integer whose low bit is used to encode an | |||
143 | /// unrelated flag, i.e: `(integerValue << 1) | (flag ? 1 : 0)`. | |||
144 | LogicalResult parseVarIntWithFlag(uint64_t &result, bool &flag) { | |||
145 | if (failed(parseVarInt(result))) | |||
146 | return failure(); | |||
147 | flag = result & 1; | |||
148 | result >>= 1; | |||
149 | return success(); | |||
150 | } | |||
151 | ||||
152 | /// Skip the first `length` bytes within the reader. | |||
153 | LogicalResult skipBytes(size_t length) { | |||
154 | if (length > size()) { | |||
155 | return emitError("attempting to skip ", length, " bytes when only ", | |||
156 | size(), " remain"); | |||
157 | } | |||
158 | dataIt += length; | |||
159 | return success(); | |||
160 | } | |||
161 | ||||
162 | /// Parse a null-terminated string into `result` (without including the NUL | |||
163 | /// terminator). | |||
164 | LogicalResult parseNullTerminatedString(StringRef &result) { | |||
165 | const char *startIt = (const char *)dataIt; | |||
166 | const char *nulIt = (const char *)memchr(startIt, 0, size()); | |||
167 | if (!nulIt) | |||
168 | return emitError( | |||
169 | "malformed null-terminated string, no null character found"); | |||
170 | ||||
171 | result = StringRef(startIt, nulIt - startIt); | |||
172 | dataIt = (const uint8_t *)nulIt + 1; | |||
173 | return success(); | |||
174 | } | |||
175 | ||||
176 | /// Parse a section header, placing the kind of section in `sectionID` and the | |||
177 | /// contents of the section in `sectionData`. | |||
178 | LogicalResult parseSection(bytecode::Section::ID §ionID, | |||
179 | ArrayRef<uint8_t> §ionData) { | |||
180 | uint64_t length; | |||
181 | if (failed(parseByte(sectionID)) || failed(parseVarInt(length))) | |||
182 | return failure(); | |||
183 | if (sectionID >= bytecode::Section::kNumSections) | |||
| ||||
184 | return emitError("invalid section ID: ", unsigned(sectionID)); | |||
185 | ||||
186 | // Parse the actua section data now that we have its length. | |||
187 | return parseBytes(static_cast<size_t>(length), sectionData); | |||
188 | } | |||
189 | ||||
190 | private: | |||
191 | /// Parse a variable length encoded integer from the byte stream. This method | |||
192 | /// is a fallback when the number of bytes used to encode the value is greater | |||
193 | /// than 1, but less than the max (9). The provided `result` value can be | |||
194 | /// assumed to already contain the first byte of the value. | |||
195 | /// NOTE: This method is marked noinline to avoid pessimizing the common case | |||
196 | /// of single byte encoding. | |||
197 | LLVM_ATTRIBUTE_NOINLINE__attribute__((noinline)) LogicalResult parseMultiByteVarInt(uint64_t &result) { | |||
198 | // Count the number of trailing zeros in the marker byte, this indicates the | |||
199 | // number of trailing bytes that are part of the value. We use `uint32_t` | |||
200 | // here because we only care about the first byte, and so that be actually | |||
201 | // get ctz intrinsic calls when possible (the `uint8_t` overload uses a loop | |||
202 | // implementation). | |||
203 | uint32_t numBytes = | |||
204 | llvm::countTrailingZeros<uint32_t>(result, llvm::ZB_Undefined); | |||
205 | assert(numBytes > 0 && numBytes <= 7 &&(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 206, __extension__ __PRETTY_FUNCTION__)) | |||
206 | "unexpected number of trailing zeros in varint encoding")(static_cast <bool> (numBytes > 0 && numBytes <= 7 && "unexpected number of trailing zeros in varint encoding" ) ? void (0) : __assert_fail ("numBytes > 0 && numBytes <= 7 && \"unexpected number of trailing zeros in varint encoding\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 206, __extension__ __PRETTY_FUNCTION__)); | |||
207 | ||||
208 | // Parse in the remaining bytes of the value. | |||
209 | if (failed(parseBytes(numBytes, reinterpret_cast<uint8_t *>(&result) + 1))) | |||
210 | return failure(); | |||
211 | ||||
212 | // Shift out the low-order bits that were used to mark how the value was | |||
213 | // encoded. | |||
214 | result >>= (numBytes + 1); | |||
215 | return success(); | |||
216 | } | |||
217 | ||||
218 | /// The current data iterator, and an iterator to the end of the buffer. | |||
219 | const uint8_t *dataIt, *dataEnd; | |||
220 | ||||
221 | /// A location for the bytecode used to report errors. | |||
222 | Location fileLoc; | |||
223 | }; | |||
224 | } // namespace | |||
225 | ||||
226 | /// Resolve an index into the given entry list. `entry` may either be a | |||
227 | /// reference, in which case it is assigned to the corresponding value in | |||
228 | /// `entries`, or a pointer, in which case it is assigned to the address of the | |||
229 | /// element in `entries`. | |||
230 | template <typename RangeT, typename T> | |||
231 | static LogicalResult resolveEntry(EncodingReader &reader, RangeT &entries, | |||
232 | uint64_t index, T &entry, | |||
233 | StringRef entryStr) { | |||
234 | if (index >= entries.size()) | |||
235 | return reader.emitError("invalid ", entryStr, " index: ", index); | |||
236 | ||||
237 | // If the provided entry is a pointer, resolve to the address of the entry. | |||
238 | if constexpr (std::is_convertible_v<llvm::detail::ValueOfRange<RangeT>, T>) | |||
239 | entry = entries[index]; | |||
240 | else | |||
241 | entry = &entries[index]; | |||
242 | return success(); | |||
243 | } | |||
244 | ||||
245 | /// Parse and resolve an index into the given entry list. | |||
246 | template <typename RangeT, typename T> | |||
247 | static LogicalResult parseEntry(EncodingReader &reader, RangeT &entries, | |||
248 | T &entry, StringRef entryStr) { | |||
249 | uint64_t entryIdx; | |||
250 | if (failed(reader.parseVarInt(entryIdx))) | |||
251 | return failure(); | |||
252 | return resolveEntry(reader, entries, entryIdx, entry, entryStr); | |||
253 | } | |||
254 | ||||
255 | //===----------------------------------------------------------------------===// | |||
256 | // StringSectionReader | |||
257 | //===----------------------------------------------------------------------===// | |||
258 | ||||
259 | namespace { | |||
260 | /// This class is used to read references to the string section from the | |||
261 | /// bytecode. | |||
262 | class StringSectionReader { | |||
263 | public: | |||
264 | /// Initialize the string section reader with the given section data. | |||
265 | LogicalResult initialize(Location fileLoc, ArrayRef<uint8_t> sectionData); | |||
266 | ||||
267 | /// Parse a shared string from the string section. The shared string is | |||
268 | /// encoded using an index to a corresponding string in the string section. | |||
269 | LogicalResult parseString(EncodingReader &reader, StringRef &result) { | |||
270 | return parseEntry(reader, strings, result, "string"); | |||
271 | } | |||
272 | ||||
273 | private: | |||
274 | /// The table of strings referenced within the bytecode file. | |||
275 | SmallVector<StringRef> strings; | |||
276 | }; | |||
277 | } // namespace | |||
278 | ||||
279 | LogicalResult StringSectionReader::initialize(Location fileLoc, | |||
280 | ArrayRef<uint8_t> sectionData) { | |||
281 | EncodingReader stringReader(sectionData, fileLoc); | |||
282 | ||||
283 | // Parse the number of strings in the section. | |||
284 | uint64_t numStrings; | |||
285 | if (failed(stringReader.parseVarInt(numStrings))) | |||
286 | return failure(); | |||
287 | strings.resize(numStrings); | |||
288 | ||||
289 | // Parse each of the strings. The sizes of the strings are encoded in reverse | |||
290 | // order, so that's the order we populate the table. | |||
291 | size_t stringDataEndOffset = sectionData.size(); | |||
292 | for (StringRef &string : llvm::reverse(strings)) { | |||
293 | uint64_t stringSize; | |||
294 | if (failed(stringReader.parseVarInt(stringSize))) | |||
295 | return failure(); | |||
296 | if (stringDataEndOffset < stringSize) { | |||
297 | return stringReader.emitError( | |||
298 | "string size exceeds the available data size"); | |||
299 | } | |||
300 | ||||
301 | // Extract the string from the data, dropping the null character. | |||
302 | size_t stringOffset = stringDataEndOffset - stringSize; | |||
303 | string = StringRef( | |||
304 | reinterpret_cast<const char *>(sectionData.data() + stringOffset), | |||
305 | stringSize - 1); | |||
306 | stringDataEndOffset = stringOffset; | |||
307 | } | |||
308 | ||||
309 | // Check that the only remaining data was for the strings, i.e. the reader | |||
310 | // should be at the same offset as the first string. | |||
311 | if ((sectionData.size() - stringReader.size()) != stringDataEndOffset) { | |||
312 | return stringReader.emitError("unexpected trailing data between the " | |||
313 | "offsets for strings and their data"); | |||
314 | } | |||
315 | return success(); | |||
316 | } | |||
317 | ||||
318 | //===----------------------------------------------------------------------===// | |||
319 | // BytecodeDialect | |||
320 | //===----------------------------------------------------------------------===// | |||
321 | ||||
322 | namespace { | |||
323 | /// This struct represents a dialect entry within the bytecode. | |||
324 | struct BytecodeDialect { | |||
325 | /// Load the dialect into the provided context if it hasn't been loaded yet. | |||
326 | /// Returns failure if the dialect couldn't be loaded *and* the provided | |||
327 | /// context does not allow unregistered dialects. The provided reader is used | |||
328 | /// for error emission if necessary. | |||
329 | LogicalResult load(EncodingReader &reader, MLIRContext *ctx) { | |||
330 | if (dialect) | |||
331 | return success(); | |||
332 | Dialect *loadedDialect = ctx->getOrLoadDialect(name); | |||
333 | if (!loadedDialect && !ctx->allowsUnregisteredDialects()) { | |||
334 | return reader.emitError( | |||
335 | "dialect '", name, | |||
336 | "' is unknown. If this is intended, please call " | |||
337 | "allowUnregisteredDialects() on the MLIRContext, or use " | |||
338 | "-allow-unregistered-dialect with the MLIR tool used."); | |||
339 | } | |||
340 | dialect = loadedDialect; | |||
341 | ||||
342 | // If the dialect was actually loaded, check to see if it has a bytecode | |||
343 | // interface. | |||
344 | if (loadedDialect) | |||
345 | interface = dyn_cast<BytecodeDialectInterface>(loadedDialect); | |||
346 | return success(); | |||
347 | } | |||
348 | ||||
349 | /// The loaded dialect entry. This field is None if we haven't attempted to | |||
350 | /// load, nullptr if we failed to load, otherwise the loaded dialect. | |||
351 | Optional<Dialect *> dialect; | |||
352 | ||||
353 | /// The bytecode interface of the dialect, or nullptr if the dialect does not | |||
354 | /// implement the bytecode interface. This field should only be checked if the | |||
355 | /// `dialect` field is non-None. | |||
356 | const BytecodeDialectInterface *interface = nullptr; | |||
357 | ||||
358 | /// The name of the dialect. | |||
359 | StringRef name; | |||
360 | }; | |||
361 | ||||
362 | /// This struct represents an operation name entry within the bytecode. | |||
363 | struct BytecodeOperationName { | |||
364 | BytecodeOperationName(BytecodeDialect *dialect, StringRef name) | |||
365 | : dialect(dialect), name(name) {} | |||
366 | ||||
367 | /// The loaded operation name, or None if it hasn't been processed yet. | |||
368 | Optional<OperationName> opName; | |||
369 | ||||
370 | /// The dialect that owns this operation name. | |||
371 | BytecodeDialect *dialect; | |||
372 | ||||
373 | /// The name of the operation, without the dialect prefix. | |||
374 | StringRef name; | |||
375 | }; | |||
376 | } // namespace | |||
377 | ||||
378 | /// Parse a single dialect group encoded in the byte stream. | |||
379 | static LogicalResult parseDialectGrouping( | |||
380 | EncodingReader &reader, MutableArrayRef<BytecodeDialect> dialects, | |||
381 | function_ref<LogicalResult(BytecodeDialect *)> entryCallback) { | |||
382 | // Parse the dialect and the number of entries in the group. | |||
383 | BytecodeDialect *dialect; | |||
384 | if (failed(parseEntry(reader, dialects, dialect, "dialect"))) | |||
385 | return failure(); | |||
386 | uint64_t numEntries; | |||
387 | if (failed(reader.parseVarInt(numEntries))) | |||
388 | return failure(); | |||
389 | ||||
390 | for (uint64_t i = 0; i < numEntries; ++i) | |||
391 | if (failed(entryCallback(dialect))) | |||
392 | return failure(); | |||
393 | return success(); | |||
394 | } | |||
395 | ||||
396 | //===----------------------------------------------------------------------===// | |||
397 | // Attribute/Type Reader | |||
398 | //===----------------------------------------------------------------------===// | |||
399 | ||||
400 | namespace { | |||
401 | /// This class provides support for reading attribute and type entries from the | |||
402 | /// bytecode. Attribute and Type entries are read lazily on demand, so we use | |||
403 | /// this reader to manage when to actually parse them from the bytecode. | |||
404 | class AttrTypeReader { | |||
405 | /// This class represents a single attribute or type entry. | |||
406 | template <typename T> | |||
407 | struct Entry { | |||
408 | /// The entry, or null if it hasn't been resolved yet. | |||
409 | T entry = {}; | |||
410 | /// The parent dialect of this entry. | |||
411 | BytecodeDialect *dialect = nullptr; | |||
412 | /// A flag indicating if the entry was encoded using a custom encoding, | |||
413 | /// instead of using the textual assembly format. | |||
414 | bool hasCustomEncoding = false; | |||
415 | /// The raw data of this entry in the bytecode. | |||
416 | ArrayRef<uint8_t> data; | |||
417 | }; | |||
418 | using AttrEntry = Entry<Attribute>; | |||
419 | using TypeEntry = Entry<Type>; | |||
420 | ||||
421 | public: | |||
422 | AttrTypeReader(StringSectionReader &stringReader, Location fileLoc) | |||
423 | : stringReader(stringReader), fileLoc(fileLoc) {} | |||
424 | ||||
425 | /// Initialize the attribute and type information within the reader. | |||
426 | LogicalResult initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
427 | ArrayRef<uint8_t> sectionData, | |||
428 | ArrayRef<uint8_t> offsetSectionData); | |||
429 | ||||
430 | /// Resolve the attribute or type at the given index. Returns nullptr on | |||
431 | /// failure. | |||
432 | Attribute resolveAttribute(size_t index) { | |||
433 | return resolveEntry(attributes, index, "Attribute"); | |||
434 | } | |||
435 | Type resolveType(size_t index) { return resolveEntry(types, index, "Type"); } | |||
436 | ||||
437 | /// Parse a reference to an attribute or type using the given reader. | |||
438 | LogicalResult parseAttribute(EncodingReader &reader, Attribute &result) { | |||
439 | uint64_t attrIdx; | |||
440 | if (failed(reader.parseVarInt(attrIdx))) | |||
441 | return failure(); | |||
442 | result = resolveAttribute(attrIdx); | |||
443 | return success(!!result); | |||
444 | } | |||
445 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
446 | uint64_t typeIdx; | |||
447 | if (failed(reader.parseVarInt(typeIdx))) | |||
448 | return failure(); | |||
449 | result = resolveType(typeIdx); | |||
450 | return success(!!result); | |||
451 | } | |||
452 | ||||
453 | template <typename T> | |||
454 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
455 | Attribute baseResult; | |||
456 | if (failed(parseAttribute(reader, baseResult))) | |||
457 | return failure(); | |||
458 | if ((result = baseResult.dyn_cast<T>())) | |||
459 | return success(); | |||
460 | return reader.emitError("expected attribute of type: ", | |||
461 | llvm::getTypeName<T>(), ", but got: ", baseResult); | |||
462 | } | |||
463 | ||||
464 | private: | |||
465 | /// Resolve the given entry at `index`. | |||
466 | template <typename T> | |||
467 | T resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
468 | StringRef entryType); | |||
469 | ||||
470 | /// Parse an entry using the given reader that was encoded using the textual | |||
471 | /// assembly format. | |||
472 | template <typename T> | |||
473 | LogicalResult parseAsmEntry(T &result, EncodingReader &reader, | |||
474 | StringRef entryType); | |||
475 | ||||
476 | /// Parse an entry using the given reader that was encoded using a custom | |||
477 | /// bytecode format. | |||
478 | template <typename T> | |||
479 | LogicalResult parseCustomEntry(Entry<T> &entry, EncodingReader &reader, | |||
480 | StringRef entryType); | |||
481 | ||||
482 | /// The string section reader used to resolve string references when parsing | |||
483 | /// custom encoded attribute/type entries. | |||
484 | StringSectionReader &stringReader; | |||
485 | ||||
486 | /// The set of attribute and type entries. | |||
487 | SmallVector<AttrEntry> attributes; | |||
488 | SmallVector<TypeEntry> types; | |||
489 | ||||
490 | /// A location used for error emission. | |||
491 | Location fileLoc; | |||
492 | }; | |||
493 | ||||
494 | class DialectReader : public DialectBytecodeReader { | |||
495 | public: | |||
496 | DialectReader(AttrTypeReader &attrTypeReader, | |||
497 | StringSectionReader &stringReader, EncodingReader &reader) | |||
498 | : attrTypeReader(attrTypeReader), stringReader(stringReader), | |||
499 | reader(reader) {} | |||
500 | ||||
501 | InFlightDiagnostic emitError(const Twine &msg) override { | |||
502 | return reader.emitError(msg); | |||
503 | } | |||
504 | ||||
505 | //===--------------------------------------------------------------------===// | |||
506 | // IR | |||
507 | //===--------------------------------------------------------------------===// | |||
508 | ||||
509 | LogicalResult readAttribute(Attribute &result) override { | |||
510 | return attrTypeReader.parseAttribute(reader, result); | |||
511 | } | |||
512 | ||||
513 | LogicalResult readType(Type &result) override { | |||
514 | return attrTypeReader.parseType(reader, result); | |||
515 | } | |||
516 | ||||
517 | //===--------------------------------------------------------------------===// | |||
518 | // Primitives | |||
519 | //===--------------------------------------------------------------------===// | |||
520 | ||||
521 | LogicalResult readVarInt(uint64_t &result) override { | |||
522 | return reader.parseVarInt(result); | |||
523 | } | |||
524 | ||||
525 | LogicalResult readSignedVarInt(int64_t &result) override { | |||
526 | uint64_t unsignedResult; | |||
527 | if (failed(reader.parseSignedVarInt(unsignedResult))) | |||
528 | return failure(); | |||
529 | result = static_cast<int64_t>(unsignedResult); | |||
530 | return success(); | |||
531 | } | |||
532 | ||||
533 | FailureOr<APInt> readAPIntWithKnownWidth(unsigned bitWidth) override { | |||
534 | // Small values are encoded using a single byte. | |||
535 | if (bitWidth <= 8) { | |||
536 | uint8_t value; | |||
537 | if (failed(reader.parseByte(value))) | |||
538 | return failure(); | |||
539 | return APInt(bitWidth, value); | |||
540 | } | |||
541 | ||||
542 | // Large values up to 64 bits are encoded using a single varint. | |||
543 | if (bitWidth <= 64) { | |||
544 | uint64_t value; | |||
545 | if (failed(reader.parseSignedVarInt(value))) | |||
546 | return failure(); | |||
547 | return APInt(bitWidth, value); | |||
548 | } | |||
549 | ||||
550 | // Otherwise, for really big values we encode the array of active words in | |||
551 | // the value. | |||
552 | uint64_t numActiveWords; | |||
553 | if (failed(reader.parseVarInt(numActiveWords))) | |||
554 | return failure(); | |||
555 | SmallVector<uint64_t, 4> words(numActiveWords); | |||
556 | for (uint64_t i = 0; i < numActiveWords; ++i) | |||
557 | if (failed(reader.parseSignedVarInt(words[i]))) | |||
558 | return failure(); | |||
559 | return APInt(bitWidth, words); | |||
560 | } | |||
561 | ||||
562 | FailureOr<APFloat> | |||
563 | readAPFloatWithKnownSemantics(const llvm::fltSemantics &semantics) override { | |||
564 | FailureOr<APInt> intVal = | |||
565 | readAPIntWithKnownWidth(APFloat::getSizeInBits(semantics)); | |||
566 | if (failed(intVal)) | |||
567 | return failure(); | |||
568 | return APFloat(semantics, *intVal); | |||
569 | } | |||
570 | ||||
571 | LogicalResult readString(StringRef &result) override { | |||
572 | return stringReader.parseString(reader, result); | |||
573 | } | |||
574 | ||||
575 | private: | |||
576 | AttrTypeReader &attrTypeReader; | |||
577 | StringSectionReader &stringReader; | |||
578 | EncodingReader &reader; | |||
579 | }; | |||
580 | } // namespace | |||
581 | ||||
582 | LogicalResult | |||
583 | AttrTypeReader::initialize(MutableArrayRef<BytecodeDialect> dialects, | |||
584 | ArrayRef<uint8_t> sectionData, | |||
585 | ArrayRef<uint8_t> offsetSectionData) { | |||
586 | EncodingReader offsetReader(offsetSectionData, fileLoc); | |||
587 | ||||
588 | // Parse the number of attribute and type entries. | |||
589 | uint64_t numAttributes, numTypes; | |||
590 | if (failed(offsetReader.parseVarInt(numAttributes)) || | |||
591 | failed(offsetReader.parseVarInt(numTypes))) | |||
592 | return failure(); | |||
593 | attributes.resize(numAttributes); | |||
594 | types.resize(numTypes); | |||
595 | ||||
596 | // A functor used to accumulate the offsets for the entries in the given | |||
597 | // range. | |||
598 | uint64_t currentOffset = 0; | |||
599 | auto parseEntries = [&](auto &&range) { | |||
600 | size_t currentIndex = 0, endIndex = range.size(); | |||
601 | ||||
602 | // Parse an individual entry. | |||
603 | auto parseEntryFn = [&](BytecodeDialect *dialect) -> LogicalResult { | |||
604 | auto &entry = range[currentIndex++]; | |||
605 | ||||
606 | uint64_t entrySize; | |||
607 | if (failed(offsetReader.parseVarIntWithFlag(entrySize, | |||
608 | entry.hasCustomEncoding))) | |||
609 | return failure(); | |||
610 | ||||
611 | // Verify that the offset is actually valid. | |||
612 | if (currentOffset + entrySize > sectionData.size()) { | |||
613 | return offsetReader.emitError( | |||
614 | "Attribute or Type entry offset points past the end of section"); | |||
615 | } | |||
616 | ||||
617 | entry.data = sectionData.slice(currentOffset, entrySize); | |||
618 | entry.dialect = dialect; | |||
619 | currentOffset += entrySize; | |||
620 | return success(); | |||
621 | }; | |||
622 | while (currentIndex != endIndex) | |||
623 | if (failed(parseDialectGrouping(offsetReader, dialects, parseEntryFn))) | |||
624 | return failure(); | |||
625 | return success(); | |||
626 | }; | |||
627 | ||||
628 | // Process each of the attributes, and then the types. | |||
629 | if (failed(parseEntries(attributes)) || failed(parseEntries(types))) | |||
630 | return failure(); | |||
631 | ||||
632 | // Ensure that we read everything from the section. | |||
633 | if (!offsetReader.empty()) { | |||
634 | return offsetReader.emitError( | |||
635 | "unexpected trailing data in the Attribute/Type offset section"); | |||
636 | } | |||
637 | return success(); | |||
638 | } | |||
639 | ||||
640 | template <typename T> | |||
641 | T AttrTypeReader::resolveEntry(SmallVectorImpl<Entry<T>> &entries, size_t index, | |||
642 | StringRef entryType) { | |||
643 | if (index >= entries.size()) { | |||
644 | emitError(fileLoc) << "invalid " << entryType << " index: " << index; | |||
645 | return {}; | |||
646 | } | |||
647 | ||||
648 | // If the entry has already been resolved, there is nothing left to do. | |||
649 | Entry<T> &entry = entries[index]; | |||
650 | if (entry.entry) | |||
651 | return entry.entry; | |||
652 | ||||
653 | // Parse the entry. | |||
654 | EncodingReader reader(entry.data, fileLoc); | |||
655 | ||||
656 | // Parse based on how the entry was encoded. | |||
657 | if (entry.hasCustomEncoding) { | |||
658 | if (failed(parseCustomEntry(entry, reader, entryType))) | |||
659 | return T(); | |||
660 | } else if (failed(parseAsmEntry(entry.entry, reader, entryType))) { | |||
661 | return T(); | |||
662 | } | |||
663 | ||||
664 | if (!reader.empty()) { | |||
665 | reader.emitError("unexpected trailing bytes after " + entryType + " entry"); | |||
666 | return T(); | |||
667 | } | |||
668 | return entry.entry; | |||
669 | } | |||
670 | ||||
671 | template <typename T> | |||
672 | LogicalResult AttrTypeReader::parseAsmEntry(T &result, EncodingReader &reader, | |||
673 | StringRef entryType) { | |||
674 | StringRef asmStr; | |||
675 | if (failed(reader.parseNullTerminatedString(asmStr))) | |||
676 | return failure(); | |||
677 | ||||
678 | // Invoke the MLIR assembly parser to parse the entry text. | |||
679 | size_t numRead = 0; | |||
680 | MLIRContext *context = fileLoc->getContext(); | |||
681 | if constexpr (std::is_same_v<T, Type>) | |||
682 | result = ::parseType(asmStr, context, numRead); | |||
683 | else | |||
684 | result = ::parseAttribute(asmStr, context, numRead); | |||
685 | if (!result) | |||
686 | return failure(); | |||
687 | ||||
688 | // Ensure there weren't dangling characters after the entry. | |||
689 | if (numRead != asmStr.size()) { | |||
690 | return reader.emitError("trailing characters found after ", entryType, | |||
691 | " assembly format: ", asmStr.drop_front(numRead)); | |||
692 | } | |||
693 | return success(); | |||
694 | } | |||
695 | ||||
696 | template <typename T> | |||
697 | LogicalResult AttrTypeReader::parseCustomEntry(Entry<T> &entry, | |||
698 | EncodingReader &reader, | |||
699 | StringRef entryType) { | |||
700 | if (failed(entry.dialect->load(reader, fileLoc.getContext()))) | |||
701 | return failure(); | |||
702 | ||||
703 | // Ensure that the dialect implements the bytecode interface. | |||
704 | if (!entry.dialect->interface) { | |||
705 | return reader.emitError("dialect '", entry.dialect->name, | |||
706 | "' does not implement the bytecode interface"); | |||
707 | } | |||
708 | ||||
709 | // Ask the dialect to parse the entry. | |||
710 | DialectReader dialectReader(*this, stringReader, reader); | |||
711 | if constexpr (std::is_same_v<T, Type>) | |||
712 | entry.entry = entry.dialect->interface->readType(dialectReader); | |||
713 | else | |||
714 | entry.entry = entry.dialect->interface->readAttribute(dialectReader); | |||
715 | return success(!!entry.entry); | |||
716 | } | |||
717 | ||||
718 | //===----------------------------------------------------------------------===// | |||
719 | // Bytecode Reader | |||
720 | //===----------------------------------------------------------------------===// | |||
721 | ||||
722 | namespace { | |||
723 | /// This class is used to read a bytecode buffer and translate it into MLIR. | |||
724 | class BytecodeReader { | |||
725 | public: | |||
726 | BytecodeReader(Location fileLoc, const ParserConfig &config) | |||
727 | : config(config), fileLoc(fileLoc), attrTypeReader(stringReader, fileLoc), | |||
728 | // Use the builtin unrealized conversion cast operation to represent | |||
729 | // forward references to values that aren't yet defined. | |||
730 | forwardRefOpState(UnknownLoc::get(config.getContext()), | |||
731 | "builtin.unrealized_conversion_cast", ValueRange(), | |||
732 | NoneType::get(config.getContext())) {} | |||
733 | ||||
734 | /// Read the bytecode defined within `buffer` into the given block. | |||
735 | LogicalResult read(llvm::MemoryBufferRef buffer, Block *block); | |||
736 | ||||
737 | private: | |||
738 | /// Return the context for this config. | |||
739 | MLIRContext *getContext() const { return config.getContext(); } | |||
740 | ||||
741 | /// Parse the bytecode version. | |||
742 | LogicalResult parseVersion(EncodingReader &reader); | |||
743 | ||||
744 | //===--------------------------------------------------------------------===// | |||
745 | // Dialect Section | |||
746 | ||||
747 | LogicalResult parseDialectSection(ArrayRef<uint8_t> sectionData); | |||
748 | ||||
749 | /// Parse an operation name reference using the given reader. | |||
750 | FailureOr<OperationName> parseOpName(EncodingReader &reader); | |||
751 | ||||
752 | //===--------------------------------------------------------------------===// | |||
753 | // Attribute/Type Section | |||
754 | ||||
755 | /// Parse an attribute or type using the given reader. | |||
756 | template <typename T> | |||
757 | LogicalResult parseAttribute(EncodingReader &reader, T &result) { | |||
758 | return attrTypeReader.parseAttribute(reader, result); | |||
759 | } | |||
760 | LogicalResult parseType(EncodingReader &reader, Type &result) { | |||
761 | return attrTypeReader.parseType(reader, result); | |||
762 | } | |||
763 | ||||
764 | //===--------------------------------------------------------------------===// | |||
765 | // IR Section | |||
766 | ||||
767 | /// This struct represents the current read state of a range of regions. This | |||
768 | /// struct is used to enable iterative parsing of regions. | |||
769 | struct RegionReadState { | |||
770 | RegionReadState(Operation *op, bool isIsolatedFromAbove) | |||
771 | : RegionReadState(op->getRegions(), isIsolatedFromAbove) {} | |||
772 | RegionReadState(MutableArrayRef<Region> regions, bool isIsolatedFromAbove) | |||
773 | : curRegion(regions.begin()), endRegion(regions.end()), | |||
774 | isIsolatedFromAbove(isIsolatedFromAbove) {} | |||
775 | ||||
776 | /// The current regions being read. | |||
777 | MutableArrayRef<Region>::iterator curRegion, endRegion; | |||
778 | ||||
779 | /// The number of values defined immediately within this region. | |||
780 | unsigned numValues = 0; | |||
781 | ||||
782 | /// The current blocks of the region being read. | |||
783 | SmallVector<Block *> curBlocks; | |||
784 | Region::iterator curBlock = {}; | |||
785 | ||||
786 | /// The number of operations remaining to be read from the current block | |||
787 | /// being read. | |||
788 | uint64_t numOpsRemaining = 0; | |||
789 | ||||
790 | /// A flag indicating if the regions being read are isolated from above. | |||
791 | bool isIsolatedFromAbove = false; | |||
792 | }; | |||
793 | ||||
794 | LogicalResult parseIRSection(ArrayRef<uint8_t> sectionData, Block *block); | |||
795 | LogicalResult parseRegions(EncodingReader &reader, | |||
796 | std::vector<RegionReadState> ®ionStack, | |||
797 | RegionReadState &readState); | |||
798 | FailureOr<Operation *> parseOpWithoutRegions(EncodingReader &reader, | |||
799 | RegionReadState &readState, | |||
800 | bool &isIsolatedFromAbove); | |||
801 | ||||
802 | LogicalResult parseRegion(EncodingReader &reader, RegionReadState &readState); | |||
803 | LogicalResult parseBlock(EncodingReader &reader, RegionReadState &readState); | |||
804 | LogicalResult parseBlockArguments(EncodingReader &reader, Block *block); | |||
805 | ||||
806 | //===--------------------------------------------------------------------===// | |||
807 | // Value Processing | |||
808 | ||||
809 | /// Parse an operand reference using the given reader. Returns nullptr in the | |||
810 | /// case of failure. | |||
811 | Value parseOperand(EncodingReader &reader); | |||
812 | ||||
813 | /// Sequentially define the given value range. | |||
814 | LogicalResult defineValues(EncodingReader &reader, ValueRange values); | |||
815 | ||||
816 | /// Create a value to use for a forward reference. | |||
817 | Value createForwardRef(); | |||
818 | ||||
819 | //===--------------------------------------------------------------------===// | |||
820 | // Fields | |||
821 | ||||
822 | /// This class represents a single value scope, in which a value scope is | |||
823 | /// delimited by isolated from above regions. | |||
824 | struct ValueScope { | |||
825 | /// Push a new region state onto this scope, reserving enough values for | |||
826 | /// those defined within the current region of the provided state. | |||
827 | void push(RegionReadState &readState) { | |||
828 | nextValueIDs.push_back(values.size()); | |||
829 | values.resize(values.size() + readState.numValues); | |||
830 | } | |||
831 | ||||
832 | /// Pop the values defined for the current region within the provided region | |||
833 | /// state. | |||
834 | void pop(RegionReadState &readState) { | |||
835 | values.resize(values.size() - readState.numValues); | |||
836 | nextValueIDs.pop_back(); | |||
837 | } | |||
838 | ||||
839 | /// The set of values defined in this scope. | |||
840 | std::vector<Value> values; | |||
841 | ||||
842 | /// The ID for the next defined value for each region current being | |||
843 | /// processed in this scope. | |||
844 | SmallVector<unsigned, 4> nextValueIDs; | |||
845 | }; | |||
846 | ||||
847 | /// The configuration of the parser. | |||
848 | const ParserConfig &config; | |||
849 | ||||
850 | /// A location to use when emitting errors. | |||
851 | Location fileLoc; | |||
852 | ||||
853 | /// The reader used to process attribute and types within the bytecode. | |||
854 | AttrTypeReader attrTypeReader; | |||
855 | ||||
856 | /// The version of the bytecode being read. | |||
857 | uint64_t version = 0; | |||
858 | ||||
859 | /// The producer of the bytecode being read. | |||
860 | StringRef producer; | |||
861 | ||||
862 | /// The table of IR units referenced within the bytecode file. | |||
863 | SmallVector<BytecodeDialect> dialects; | |||
864 | SmallVector<BytecodeOperationName> opNames; | |||
865 | ||||
866 | /// The table of strings referenced within the bytecode file. | |||
867 | StringSectionReader stringReader; | |||
868 | ||||
869 | /// The current set of available IR value scopes. | |||
870 | std::vector<ValueScope> valueScopes; | |||
871 | /// A block containing the set of operations defined to create forward | |||
872 | /// references. | |||
873 | Block forwardRefOps; | |||
874 | /// A block containing previously created, and no longer used, forward | |||
875 | /// reference operations. | |||
876 | Block openForwardRefOps; | |||
877 | /// An operation state used when instantiating forward references. | |||
878 | OperationState forwardRefOpState; | |||
879 | }; | |||
880 | } // namespace | |||
881 | ||||
882 | LogicalResult BytecodeReader::read(llvm::MemoryBufferRef buffer, Block *block) { | |||
883 | EncodingReader reader(buffer.getBuffer(), fileLoc); | |||
884 | ||||
885 | // Skip over the bytecode header, this should have already been checked. | |||
886 | if (failed(reader.skipBytes(StringRef("ML\xefR").size()))) | |||
887 | return failure(); | |||
888 | // Parse the bytecode version and producer. | |||
889 | if (failed(parseVersion(reader)) || | |||
890 | failed(reader.parseNullTerminatedString(producer))) | |||
891 | return failure(); | |||
892 | ||||
893 | // Add a diagnostic handler that attaches a note that includes the original | |||
894 | // producer of the bytecode. | |||
895 | ScopedDiagnosticHandler diagHandler(getContext(), [&](Diagnostic &diag) { | |||
896 | diag.attachNote() << "in bytecode version " << version | |||
897 | << " produced by: " << producer; | |||
898 | return failure(); | |||
899 | }); | |||
900 | ||||
901 | // Parse the raw data for each of the top-level sections of the bytecode. | |||
902 | Optional<ArrayRef<uint8_t>> sectionDatas[bytecode::Section::kNumSections]; | |||
903 | while (!reader.empty()) { | |||
904 | // Read the next section from the bytecode. | |||
905 | bytecode::Section::ID sectionID; | |||
906 | ArrayRef<uint8_t> sectionData; | |||
907 | if (failed(reader.parseSection(sectionID, sectionData))) | |||
908 | return failure(); | |||
909 | ||||
910 | // Check for duplicate sections, we only expect one instance of each. | |||
911 | if (sectionDatas[sectionID]) { | |||
912 | return reader.emitError("duplicate top-level section: ", | |||
913 | toString(sectionID)); | |||
914 | } | |||
915 | sectionDatas[sectionID] = sectionData; | |||
916 | } | |||
917 | // Check that all of the sections were found. | |||
918 | for (int i = 0; i < bytecode::Section::kNumSections; ++i) { | |||
919 | if (!sectionDatas[i]) { | |||
920 | return reader.emitError("missing data for top-level section: ", | |||
921 | toString(bytecode::Section::ID(i))); | |||
922 | } | |||
923 | } | |||
924 | ||||
925 | // Process the string section first. | |||
926 | if (failed(stringReader.initialize( | |||
927 | fileLoc, *sectionDatas[bytecode::Section::kString]))) | |||
928 | return failure(); | |||
929 | ||||
930 | // Process the dialect section. | |||
931 | if (failed(parseDialectSection(*sectionDatas[bytecode::Section::kDialect]))) | |||
932 | return failure(); | |||
933 | ||||
934 | // Process the attribute and type section. | |||
935 | if (failed(attrTypeReader.initialize( | |||
936 | dialects, *sectionDatas[bytecode::Section::kAttrType], | |||
937 | *sectionDatas[bytecode::Section::kAttrTypeOffset]))) | |||
938 | return failure(); | |||
939 | ||||
940 | // Finally, process the IR section. | |||
941 | return parseIRSection(*sectionDatas[bytecode::Section::kIR], block); | |||
942 | } | |||
943 | ||||
944 | LogicalResult BytecodeReader::parseVersion(EncodingReader &reader) { | |||
945 | if (failed(reader.parseVarInt(version))) | |||
946 | return failure(); | |||
947 | ||||
948 | // Validate the bytecode version. | |||
949 | uint64_t currentVersion = bytecode::kVersion; | |||
950 | if (version < currentVersion) { | |||
951 | return reader.emitError("bytecode version ", version, | |||
952 | " is older than the current version of ", | |||
953 | currentVersion, ", and upgrade is not supported"); | |||
954 | } | |||
955 | if (version > currentVersion) { | |||
956 | return reader.emitError("bytecode version ", version, | |||
957 | " is newer than the current version ", | |||
958 | currentVersion); | |||
959 | } | |||
960 | return success(); | |||
961 | } | |||
962 | ||||
963 | //===----------------------------------------------------------------------===// | |||
964 | // Dialect Section | |||
965 | ||||
966 | LogicalResult | |||
967 | BytecodeReader::parseDialectSection(ArrayRef<uint8_t> sectionData) { | |||
968 | EncodingReader sectionReader(sectionData, fileLoc); | |||
969 | ||||
970 | // Parse the number of dialects in the section. | |||
971 | uint64_t numDialects; | |||
972 | if (failed(sectionReader.parseVarInt(numDialects))) | |||
973 | return failure(); | |||
974 | dialects.resize(numDialects); | |||
975 | ||||
976 | // Parse each of the dialects. | |||
977 | for (uint64_t i = 0; i < numDialects; ++i) | |||
978 | if (failed(stringReader.parseString(sectionReader, dialects[i].name))) | |||
979 | return failure(); | |||
980 | ||||
981 | // Parse the operation names, which are grouped by dialect. | |||
982 | auto parseOpName = [&](BytecodeDialect *dialect) { | |||
983 | StringRef opName; | |||
984 | if (failed(stringReader.parseString(sectionReader, opName))) | |||
985 | return failure(); | |||
986 | opNames.emplace_back(dialect, opName); | |||
987 | return success(); | |||
988 | }; | |||
989 | while (!sectionReader.empty()) | |||
990 | if (failed(parseDialectGrouping(sectionReader, dialects, parseOpName))) | |||
991 | return failure(); | |||
992 | return success(); | |||
993 | } | |||
994 | ||||
995 | FailureOr<OperationName> BytecodeReader::parseOpName(EncodingReader &reader) { | |||
996 | BytecodeOperationName *opName = nullptr; | |||
997 | if (failed(parseEntry(reader, opNames, opName, "operation name"))) | |||
998 | return failure(); | |||
999 | ||||
1000 | // Check to see if this operation name has already been resolved. If we | |||
1001 | // haven't, load the dialect and build the operation name. | |||
1002 | if (!opName->opName) { | |||
1003 | if (failed(opName->dialect->load(reader, getContext()))) | |||
1004 | return failure(); | |||
1005 | opName->opName.emplace((opName->dialect->name + "." + opName->name).str(), | |||
1006 | getContext()); | |||
1007 | } | |||
1008 | return *opName->opName; | |||
1009 | } | |||
1010 | ||||
1011 | //===----------------------------------------------------------------------===// | |||
1012 | // IR Section | |||
1013 | ||||
1014 | LogicalResult BytecodeReader::parseIRSection(ArrayRef<uint8_t> sectionData, | |||
1015 | Block *block) { | |||
1016 | EncodingReader reader(sectionData, fileLoc); | |||
1017 | ||||
1018 | // A stack of operation regions currently being read from the bytecode. | |||
1019 | std::vector<RegionReadState> regionStack; | |||
1020 | ||||
1021 | // Parse the top-level block using a temporary module operation. | |||
1022 | OwningOpRef<ModuleOp> moduleOp = ModuleOp::create(fileLoc); | |||
1023 | regionStack.emplace_back(*moduleOp, /*isIsolatedFromAbove=*/true); | |||
1024 | regionStack.back().curBlocks.push_back(moduleOp->getBody()); | |||
1025 | regionStack.back().curBlock = regionStack.back().curRegion->begin(); | |||
1026 | if (failed(parseBlock(reader, regionStack.back()))) | |||
1027 | return failure(); | |||
1028 | valueScopes.emplace_back(); | |||
1029 | valueScopes.back().push(regionStack.back()); | |||
1030 | ||||
1031 | // Iteratively parse regions until everything has been resolved. | |||
1032 | while (!regionStack.empty()) | |||
1033 | if (failed(parseRegions(reader, regionStack, regionStack.back()))) | |||
1034 | return failure(); | |||
1035 | if (!forwardRefOps.empty()) { | |||
1036 | return reader.emitError( | |||
1037 | "not all forward unresolved forward operand references"); | |||
1038 | } | |||
1039 | ||||
1040 | // Verify that the parsed operations are valid. | |||
1041 | if (failed(verify(*moduleOp))) | |||
1042 | return failure(); | |||
1043 | ||||
1044 | // Splice the parsed operations over to the provided top-level block. | |||
1045 | auto &parsedOps = moduleOp->getBody()->getOperations(); | |||
1046 | auto &destOps = block->getOperations(); | |||
1047 | destOps.splice(destOps.empty() ? destOps.end() : std::prev(destOps.end()), | |||
1048 | parsedOps, parsedOps.begin(), parsedOps.end()); | |||
1049 | return success(); | |||
1050 | } | |||
1051 | ||||
1052 | LogicalResult | |||
1053 | BytecodeReader::parseRegions(EncodingReader &reader, | |||
1054 | std::vector<RegionReadState> ®ionStack, | |||
1055 | RegionReadState &readState) { | |||
1056 | // Read the regions of this operation. | |||
1057 | for (; readState.curRegion != readState.endRegion; ++readState.curRegion) { | |||
1058 | // If the current block hasn't been setup yet, parse the header for this | |||
1059 | // region. | |||
1060 | if (readState.curBlock == Region::iterator()) { | |||
1061 | if (failed(parseRegion(reader, readState))) | |||
1062 | return failure(); | |||
1063 | ||||
1064 | // If the region is empty, there is nothing to more to do. | |||
1065 | if (readState.curRegion->empty()) | |||
1066 | continue; | |||
1067 | } | |||
1068 | ||||
1069 | // Parse the blocks within the region. | |||
1070 | do { | |||
1071 | while (readState.numOpsRemaining--) { | |||
1072 | // Read in the next operation. We don't read its regions directly, we | |||
1073 | // handle those afterwards as necessary. | |||
1074 | bool isIsolatedFromAbove = false; | |||
1075 | FailureOr<Operation *> op = | |||
1076 | parseOpWithoutRegions(reader, readState, isIsolatedFromAbove); | |||
1077 | if (failed(op)) | |||
1078 | return failure(); | |||
1079 | ||||
1080 | // If the op has regions, add it to the stack for processing. | |||
1081 | if ((*op)->getNumRegions()) { | |||
1082 | regionStack.emplace_back(*op, isIsolatedFromAbove); | |||
1083 | ||||
1084 | // If the op is isolated from above, push a new value scope. | |||
1085 | if (isIsolatedFromAbove) | |||
1086 | valueScopes.emplace_back(); | |||
1087 | return success(); | |||
1088 | } | |||
1089 | } | |||
1090 | ||||
1091 | // Move to the next block of the region. | |||
1092 | if (++readState.curBlock == readState.curRegion->end()) | |||
1093 | break; | |||
1094 | if (failed(parseBlock(reader, readState))) | |||
1095 | return failure(); | |||
1096 | } while (true); | |||
1097 | ||||
1098 | // Reset the current block and any values reserved for this region. | |||
1099 | readState.curBlock = {}; | |||
1100 | valueScopes.back().pop(readState); | |||
1101 | } | |||
1102 | ||||
1103 | // When the regions have been fully parsed, pop them off of the read stack. If | |||
1104 | // the regions were isolated from above, we also pop the last value scope. | |||
1105 | if (readState.isIsolatedFromAbove) | |||
1106 | valueScopes.pop_back(); | |||
1107 | regionStack.pop_back(); | |||
1108 | return success(); | |||
1109 | } | |||
1110 | ||||
1111 | FailureOr<Operation *> | |||
1112 | BytecodeReader::parseOpWithoutRegions(EncodingReader &reader, | |||
1113 | RegionReadState &readState, | |||
1114 | bool &isIsolatedFromAbove) { | |||
1115 | // Parse the name of the operation. | |||
1116 | FailureOr<OperationName> opName = parseOpName(reader); | |||
1117 | if (failed(opName)) | |||
1118 | return failure(); | |||
1119 | ||||
1120 | // Parse the operation mask, which indicates which components of the operation | |||
1121 | // are present. | |||
1122 | uint8_t opMask; | |||
1123 | if (failed(reader.parseByte(opMask))) | |||
1124 | return failure(); | |||
1125 | ||||
1126 | /// Parse the location. | |||
1127 | LocationAttr opLoc; | |||
1128 | if (failed(parseAttribute(reader, opLoc))) | |||
1129 | return failure(); | |||
1130 | ||||
1131 | // With the location and name resolved, we can start building the operation | |||
1132 | // state. | |||
1133 | OperationState opState(opLoc, *opName); | |||
1134 | ||||
1135 | // Parse the attributes of the operation. | |||
1136 | if (opMask & bytecode::OpEncodingMask::kHasAttrs) { | |||
1137 | DictionaryAttr dictAttr; | |||
1138 | if (failed(parseAttribute(reader, dictAttr))) | |||
1139 | return failure(); | |||
1140 | opState.attributes = dictAttr; | |||
1141 | } | |||
1142 | ||||
1143 | /// Parse the results of the operation. | |||
1144 | if (opMask & bytecode::OpEncodingMask::kHasResults) { | |||
1145 | uint64_t numResults; | |||
1146 | if (failed(reader.parseVarInt(numResults))) | |||
1147 | return failure(); | |||
1148 | opState.types.resize(numResults); | |||
1149 | for (int i = 0, e = numResults; i < e; ++i) | |||
1150 | if (failed(parseType(reader, opState.types[i]))) | |||
1151 | return failure(); | |||
1152 | } | |||
1153 | ||||
1154 | /// Parse the operands of the operation. | |||
1155 | if (opMask & bytecode::OpEncodingMask::kHasOperands) { | |||
1156 | uint64_t numOperands; | |||
1157 | if (failed(reader.parseVarInt(numOperands))) | |||
1158 | return failure(); | |||
1159 | opState.operands.resize(numOperands); | |||
1160 | for (int i = 0, e = numOperands; i < e; ++i) | |||
1161 | if (!(opState.operands[i] = parseOperand(reader))) | |||
1162 | return failure(); | |||
1163 | } | |||
1164 | ||||
1165 | /// Parse the successors of the operation. | |||
1166 | if (opMask & bytecode::OpEncodingMask::kHasSuccessors) { | |||
1167 | uint64_t numSuccs; | |||
1168 | if (failed(reader.parseVarInt(numSuccs))) | |||
1169 | return failure(); | |||
1170 | opState.successors.resize(numSuccs); | |||
1171 | for (int i = 0, e = numSuccs; i < e; ++i) { | |||
1172 | if (failed(parseEntry(reader, readState.curBlocks, opState.successors[i], | |||
1173 | "successor"))) | |||
1174 | return failure(); | |||
1175 | } | |||
1176 | } | |||
1177 | ||||
1178 | /// Parse the regions of the operation. | |||
1179 | if (opMask & bytecode::OpEncodingMask::kHasInlineRegions) { | |||
1180 | uint64_t numRegions; | |||
1181 | if (failed(reader.parseVarIntWithFlag(numRegions, isIsolatedFromAbove))) | |||
1182 | return failure(); | |||
1183 | ||||
1184 | opState.regions.reserve(numRegions); | |||
1185 | for (int i = 0, e = numRegions; i < e; ++i) | |||
1186 | opState.regions.push_back(std::make_unique<Region>()); | |||
1187 | } | |||
1188 | ||||
1189 | // Create the operation at the back of the current block. | |||
1190 | Operation *op = Operation::create(opState); | |||
1191 | readState.curBlock->push_back(op); | |||
1192 | ||||
1193 | // If the operation had results, update the value references. | |||
1194 | if (op->getNumResults() && failed(defineValues(reader, op->getResults()))) | |||
1195 | return failure(); | |||
1196 | ||||
1197 | return op; | |||
1198 | } | |||
1199 | ||||
1200 | LogicalResult BytecodeReader::parseRegion(EncodingReader &reader, | |||
1201 | RegionReadState &readState) { | |||
1202 | // Parse the number of blocks in the region. | |||
1203 | uint64_t numBlocks; | |||
1204 | if (failed(reader.parseVarInt(numBlocks))) | |||
1205 | return failure(); | |||
1206 | ||||
1207 | // If the region is empty, there is nothing else to do. | |||
1208 | if (numBlocks == 0) | |||
1209 | return success(); | |||
1210 | ||||
1211 | // Parse the number of values defined in this region. | |||
1212 | uint64_t numValues; | |||
1213 | if (failed(reader.parseVarInt(numValues))) | |||
1214 | return failure(); | |||
1215 | readState.numValues = numValues; | |||
1216 | ||||
1217 | // Create the blocks within this region. We do this before processing so that | |||
1218 | // we can rely on the blocks existing when creating operations. | |||
1219 | readState.curBlocks.clear(); | |||
1220 | readState.curBlocks.reserve(numBlocks); | |||
1221 | for (uint64_t i = 0; i < numBlocks; ++i) { | |||
1222 | readState.curBlocks.push_back(new Block()); | |||
1223 | readState.curRegion->push_back(readState.curBlocks.back()); | |||
1224 | } | |||
1225 | ||||
1226 | // Prepare the current value scope for this region. | |||
1227 | valueScopes.back().push(readState); | |||
1228 | ||||
1229 | // Parse the entry block of the region. | |||
1230 | readState.curBlock = readState.curRegion->begin(); | |||
1231 | return parseBlock(reader, readState); | |||
1232 | } | |||
1233 | ||||
1234 | LogicalResult BytecodeReader::parseBlock(EncodingReader &reader, | |||
1235 | RegionReadState &readState) { | |||
1236 | bool hasArgs; | |||
1237 | if (failed(reader.parseVarIntWithFlag(readState.numOpsRemaining, hasArgs))) | |||
1238 | return failure(); | |||
1239 | ||||
1240 | // Parse the arguments of the block. | |||
1241 | if (hasArgs && failed(parseBlockArguments(reader, &*readState.curBlock))) | |||
1242 | return failure(); | |||
1243 | ||||
1244 | // We don't parse the operations of the block here, that's done elsewhere. | |||
1245 | return success(); | |||
1246 | } | |||
1247 | ||||
1248 | LogicalResult BytecodeReader::parseBlockArguments(EncodingReader &reader, | |||
1249 | Block *block) { | |||
1250 | // Parse the value ID for the first argument, and the number of arguments. | |||
1251 | uint64_t numArgs; | |||
1252 | if (failed(reader.parseVarInt(numArgs))) | |||
1253 | return failure(); | |||
1254 | ||||
1255 | SmallVector<Type> argTypes; | |||
1256 | SmallVector<Location> argLocs; | |||
1257 | argTypes.reserve(numArgs); | |||
1258 | argLocs.reserve(numArgs); | |||
1259 | ||||
1260 | while (numArgs--) { | |||
1261 | Type argType; | |||
1262 | LocationAttr argLoc; | |||
1263 | if (failed(parseType(reader, argType)) || | |||
1264 | failed(parseAttribute(reader, argLoc))) | |||
1265 | return failure(); | |||
1266 | ||||
1267 | argTypes.push_back(argType); | |||
1268 | argLocs.push_back(argLoc); | |||
1269 | } | |||
1270 | block->addArguments(argTypes, argLocs); | |||
1271 | return defineValues(reader, block->getArguments()); | |||
1272 | } | |||
1273 | ||||
1274 | //===----------------------------------------------------------------------===// | |||
1275 | // Value Processing | |||
1276 | ||||
1277 | Value BytecodeReader::parseOperand(EncodingReader &reader) { | |||
1278 | std::vector<Value> &values = valueScopes.back().values; | |||
1279 | Value *value = nullptr; | |||
1280 | if (failed(parseEntry(reader, values, value, "value"))) | |||
1281 | return Value(); | |||
1282 | ||||
1283 | // Create a new forward reference if necessary. | |||
1284 | if (!*value) | |||
1285 | *value = createForwardRef(); | |||
1286 | return *value; | |||
1287 | } | |||
1288 | ||||
1289 | LogicalResult BytecodeReader::defineValues(EncodingReader &reader, | |||
1290 | ValueRange newValues) { | |||
1291 | ValueScope &valueScope = valueScopes.back(); | |||
1292 | std::vector<Value> &values = valueScope.values; | |||
1293 | ||||
1294 | unsigned &valueID = valueScope.nextValueIDs.back(); | |||
1295 | unsigned valueIDEnd = valueID + newValues.size(); | |||
1296 | if (valueIDEnd > values.size()) { | |||
1297 | return reader.emitError( | |||
1298 | "value index range was outside of the expected range for " | |||
1299 | "the parent region, got [", | |||
1300 | valueID, ", ", valueIDEnd, "), but the maximum index was ", | |||
1301 | values.size() - 1); | |||
1302 | } | |||
1303 | ||||
1304 | // Assign the values and update any forward references. | |||
1305 | for (unsigned i = 0, e = newValues.size(); i != e; ++i, ++valueID) { | |||
1306 | Value newValue = newValues[i]; | |||
1307 | ||||
1308 | // Check to see if a definition for this value already exists. | |||
1309 | if (Value oldValue = std::exchange(values[valueID], newValue)) { | |||
1310 | Operation *forwardRefOp = oldValue.getDefiningOp(); | |||
1311 | ||||
1312 | // Assert that this is a forward reference operation. Given how we compute | |||
1313 | // definition ids (incrementally as we parse), it shouldn't be possible | |||
1314 | // for the value to be defined any other way. | |||
1315 | assert(forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps &&(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1316, __extension__ __PRETTY_FUNCTION__)) | |||
1316 | "value index was already defined?")(static_cast <bool> (forwardRefOp && forwardRefOp ->getBlock() == &forwardRefOps && "value index was already defined?" ) ? void (0) : __assert_fail ("forwardRefOp && forwardRefOp->getBlock() == &forwardRefOps && \"value index was already defined?\"" , "mlir/lib/Bytecode/Reader/BytecodeReader.cpp", 1316, __extension__ __PRETTY_FUNCTION__)); | |||
1317 | ||||
1318 | oldValue.replaceAllUsesWith(newValue); | |||
1319 | forwardRefOp->moveBefore(&openForwardRefOps, openForwardRefOps.end()); | |||
1320 | } | |||
1321 | } | |||
1322 | return success(); | |||
1323 | } | |||
1324 | ||||
1325 | Value BytecodeReader::createForwardRef() { | |||
1326 | // Check for an avaliable existing operation to use. Otherwise, create a new | |||
1327 | // fake operation to use for the reference. | |||
1328 | if (!openForwardRefOps.empty()) { | |||
1329 | Operation *op = &openForwardRefOps.back(); | |||
1330 | op->moveBefore(&forwardRefOps, forwardRefOps.end()); | |||
1331 | } else { | |||
1332 | forwardRefOps.push_back(Operation::create(forwardRefOpState)); | |||
1333 | } | |||
1334 | return forwardRefOps.back().getResult(0); | |||
1335 | } | |||
1336 | ||||
1337 | //===----------------------------------------------------------------------===// | |||
1338 | // Entry Points | |||
1339 | //===----------------------------------------------------------------------===// | |||
1340 | ||||
1341 | bool mlir::isBytecode(llvm::MemoryBufferRef buffer) { | |||
1342 | return buffer.getBuffer().startswith("ML\xefR"); | |||
1343 | } | |||
1344 | ||||
1345 | LogicalResult mlir::readBytecodeFile(llvm::MemoryBufferRef buffer, Block *block, | |||
1346 | const ParserConfig &config) { | |||
1347 | Location sourceFileLoc = | |||
1348 | FileLineColLoc::get(config.getContext(), buffer.getBufferIdentifier(), | |||
1349 | /*line=*/0, /*column=*/0); | |||
1350 | if (!isBytecode(buffer)) { | |||
| ||||
1351 | return emitError(sourceFileLoc, | |||
1352 | "input buffer is not an MLIR bytecode file"); | |||
1353 | } | |||
1354 | ||||
1355 | BytecodeReader reader(sourceFileLoc, config); | |||
1356 | return reader.read(buffer, block); | |||
1357 | } |