File: | build/llvm-toolchain-snapshot-16~++20220828101037+f00f2b3e8d40/mlir/lib/AsmParser/TypeParser.cpp |
Warning: line 235, column 23: the 2nd function call argument is an uninitialized value.
Press '?' to see keyboard shortcuts
Keyboard shortcuts:
1 | //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file implements the parser for the MLIR Types. | |||
10 | // | |||
11 | //===----------------------------------------------------------------------===// | |||
12 | ||||
#include "Parser.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/TensorEncoding.h"

#include <limits>
18 | ||||
19 | using namespace mlir; | |||
20 | using namespace mlir::detail; | |||
21 | ||||
22 | /// Optionally parse a type. | |||
23 | OptionalParseResult Parser::parseOptionalType(Type &type) { | |||
24 | // There are many different starting tokens for a type, check them here. | |||
25 | switch (getToken().getKind()) { | |||
26 | case Token::l_paren: | |||
27 | case Token::kw_memref: | |||
28 | case Token::kw_tensor: | |||
29 | case Token::kw_complex: | |||
30 | case Token::kw_tuple: | |||
31 | case Token::kw_vector: | |||
32 | case Token::inttype: | |||
33 | case Token::kw_bf16: | |||
34 | case Token::kw_f16: | |||
35 | case Token::kw_f32: | |||
36 | case Token::kw_f64: | |||
37 | case Token::kw_f80: | |||
38 | case Token::kw_f128: | |||
39 | case Token::kw_index: | |||
40 | case Token::kw_none: | |||
41 | case Token::exclamation_identifier: | |||
42 | return failure(!(type = parseType())); | |||
43 | ||||
44 | default: | |||
45 | return llvm::None; | |||
46 | } | |||
47 | } | |||
48 | ||||
49 | /// Parse an arbitrary type. | |||
50 | /// | |||
51 | /// type ::= function-type | |||
52 | /// | non-function-type | |||
53 | /// | |||
54 | Type Parser::parseType() { | |||
55 | if (getToken().is(Token::l_paren)) | |||
56 | return parseFunctionType(); | |||
57 | return parseNonFunctionType(); | |||
58 | } | |||
59 | ||||
60 | /// Parse a function result type. | |||
61 | /// | |||
62 | /// function-result-type ::= type-list-parens | |||
63 | /// | non-function-type | |||
64 | /// | |||
65 | ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) { | |||
66 | if (getToken().is(Token::l_paren)) | |||
67 | return parseTypeListParens(elements); | |||
68 | ||||
69 | Type t = parseNonFunctionType(); | |||
70 | if (!t) | |||
71 | return failure(); | |||
72 | elements.push_back(t); | |||
73 | return success(); | |||
74 | } | |||
75 | ||||
76 | /// Parse a list of types without an enclosing parenthesis. The list must have | |||
77 | /// at least one member. | |||
78 | /// | |||
79 | /// type-list-no-parens ::= type (`,` type)* | |||
80 | /// | |||
81 | ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) { | |||
82 | auto parseElt = [&]() -> ParseResult { | |||
83 | auto elt = parseType(); | |||
84 | elements.push_back(elt); | |||
85 | return elt ? success() : failure(); | |||
86 | }; | |||
87 | ||||
88 | return parseCommaSeparatedList(parseElt); | |||
89 | } | |||
90 | ||||
91 | /// Parse a parenthesized list of types. | |||
92 | /// | |||
93 | /// type-list-parens ::= `(` `)` | |||
94 | /// | `(` type-list-no-parens `)` | |||
95 | /// | |||
96 | ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) { | |||
97 | if (parseToken(Token::l_paren, "expected '('")) | |||
98 | return failure(); | |||
99 | ||||
100 | // Handle empty lists. | |||
101 | if (getToken().is(Token::r_paren)) | |||
102 | return consumeToken(), success(); | |||
103 | ||||
104 | if (parseTypeListNoParens(elements) || | |||
105 | parseToken(Token::r_paren, "expected ')'")) | |||
106 | return failure(); | |||
107 | return success(); | |||
108 | } | |||
109 | ||||
110 | /// Parse a complex type. | |||
111 | /// | |||
112 | /// complex-type ::= `complex` `<` type `>` | |||
113 | /// | |||
114 | Type Parser::parseComplexType() { | |||
115 | consumeToken(Token::kw_complex); | |||
116 | ||||
117 | // Parse the '<'. | |||
118 | if (parseToken(Token::less, "expected '<' in complex type")) | |||
119 | return nullptr; | |||
120 | ||||
121 | SMLoc elementTypeLoc = getToken().getLoc(); | |||
122 | auto elementType = parseType(); | |||
123 | if (!elementType || | |||
124 | parseToken(Token::greater, "expected '>' in complex type")) | |||
125 | return nullptr; | |||
126 | if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>()) | |||
127 | return emitError(elementTypeLoc, "invalid element type for complex"), | |||
128 | nullptr; | |||
129 | ||||
130 | return ComplexType::get(elementType); | |||
131 | } | |||
132 | ||||
133 | /// Parse a function type. | |||
134 | /// | |||
135 | /// function-type ::= type-list-parens `->` function-result-type | |||
136 | /// | |||
137 | Type Parser::parseFunctionType() { | |||
138 | assert(getToken().is(Token::l_paren))(static_cast <bool> (getToken().is(Token::l_paren)) ? void (0) : __assert_fail ("getToken().is(Token::l_paren)", "mlir/lib/AsmParser/TypeParser.cpp" , 138, __extension__ __PRETTY_FUNCTION__)); | |||
139 | ||||
140 | SmallVector<Type, 4> arguments, results; | |||
141 | if (parseTypeListParens(arguments) || | |||
142 | parseToken(Token::arrow, "expected '->' in function type") || | |||
143 | parseFunctionResultTypes(results)) | |||
144 | return nullptr; | |||
145 | ||||
146 | return builder.getFunctionType(arguments, results); | |||
147 | } | |||
148 | ||||
149 | /// Parse the offset and strides from a strided layout specification. | |||
150 | /// | |||
151 | /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list | |||
152 | /// | |||
153 | ParseResult Parser::parseStridedLayout(int64_t &offset, | |||
154 | SmallVectorImpl<int64_t> &strides) { | |||
155 | // Parse offset. | |||
156 | consumeToken(Token::kw_offset); | |||
157 | if (parseToken(Token::colon, "expected colon after `offset` keyword")) | |||
158 | return failure(); | |||
159 | ||||
160 | auto maybeOffset = getToken().getUnsignedIntegerValue(); | |||
161 | bool question = getToken().is(Token::question); | |||
162 | if (!maybeOffset && !question
| |||
163 | return emitWrongTokenError("invalid offset"); | |||
164 | offset = maybeOffset ? static_cast<int64_t>(*maybeOffset) | |||
165 | : MemRefType::getDynamicStrideOrOffset(); | |||
166 | consumeToken(); | |||
167 | ||||
168 | // Parse stride list. | |||
169 | if (parseToken(Token::comma, "expected comma after offset value") || | |||
170 | parseToken(Token::kw_strides, | |||
171 | "expected `strides` keyword after offset specification") || | |||
172 | parseToken(Token::colon, "expected colon after `strides` keyword") || | |||
173 | parseStrideList(strides)) | |||
174 | return failure(); | |||
175 | return success(); | |||
176 | } | |||
177 | ||||
178 | /// Parse a memref type. | |||
179 | /// | |||
180 | /// memref-type ::= ranked-memref-type | unranked-memref-type | |||
181 | /// | |||
182 | /// ranked-memref-type ::= `memref` `<` dimension-list-ranked type | |||
183 | /// (`,` layout-specification)? (`,` memory-space)? `>` | |||
184 | /// | |||
185 | /// unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>` | |||
186 | /// | |||
187 | /// stride-list ::= `[` (dimension (`,` dimension)*)? `]` | |||
188 | /// strided-layout ::= `offset:` dimension `,` `strides: ` stride-list | |||
189 | /// layout-specification ::= semi-affine-map | strided-layout | attribute | |||
190 | /// memory-space ::= integer-literal | attribute | |||
191 | /// | |||
192 | Type Parser::parseMemRefType() { | |||
193 | SMLoc loc = getToken().getLoc(); | |||
194 | consumeToken(Token::kw_memref); | |||
195 | ||||
196 | if (parseToken(Token::less, "expected '<' in memref type")) | |||
197 | return nullptr; | |||
198 | ||||
199 | bool isUnranked; | |||
200 | SmallVector<int64_t, 4> dimensions; | |||
201 | ||||
202 | if (consumeIf(Token::star)) { | |||
203 | // This is an unranked memref type. | |||
204 | isUnranked = true; | |||
205 | if (parseXInDimensionList()) | |||
206 | return nullptr; | |||
207 | ||||
208 | } else { | |||
209 | isUnranked = false; | |||
210 | if (parseDimensionListRanked(dimensions)) | |||
211 | return nullptr; | |||
212 | } | |||
213 | ||||
214 | // Parse the element type. | |||
215 | auto typeLoc = getToken().getLoc(); | |||
216 | auto elementType = parseType(); | |||
217 | if (!elementType) | |||
218 | return nullptr; | |||
219 | ||||
220 | // Check that memref is formed from allowed types. | |||
221 | if (!BaseMemRefType::isValidElementType(elementType)) | |||
222 | return emitError(typeLoc, "invalid memref element type"), nullptr; | |||
223 | ||||
224 | MemRefLayoutAttrInterface layout; | |||
225 | Attribute memorySpace; | |||
226 | ||||
227 | auto parseElt = [&]() -> ParseResult { | |||
228 | // Check for AffineMap as offset/strides. | |||
229 | if (getToken().is(Token::kw_offset)) { | |||
| ||||
230 | int64_t offset; | |||
231 | SmallVector<int64_t, 4> strides; | |||
232 | if (failed(parseStridedLayout(offset, strides))) | |||
233 | return failure(); | |||
234 | // Construct strided affine map. | |||
235 | AffineMap map = makeStridedLinearLayoutMap(strides, offset, getContext()); | |||
| ||||
236 | layout = AffineMapAttr::get(map); | |||
237 | } else { | |||
238 | // Either it is MemRefLayoutAttrInterface or memory space attribute. | |||
239 | Attribute attr = parseAttribute(); | |||
240 | if (!attr) | |||
241 | return failure(); | |||
242 | ||||
243 | if (attr.isa<MemRefLayoutAttrInterface>()) { | |||
244 | layout = attr.cast<MemRefLayoutAttrInterface>(); | |||
245 | } else if (memorySpace) { | |||
246 | return emitError("multiple memory spaces specified in memref type"); | |||
247 | } else { | |||
248 | memorySpace = attr; | |||
249 | return success(); | |||
250 | } | |||
251 | } | |||
252 | ||||
253 | if (isUnranked) | |||
254 | return emitError("cannot have affine map for unranked memref type"); | |||
255 | if (memorySpace) | |||
256 | return emitError("expected memory space to be last in memref type"); | |||
257 | ||||
258 | return success(); | |||
259 | }; | |||
260 | ||||
261 | // Parse a list of mappings and address space if present. | |||
262 | if (!consumeIf(Token::greater)) { | |||
263 | // Parse comma separated list of affine maps, followed by memory space. | |||
264 | if (parseToken(Token::comma, "expected ',' or '>' in memref type") || | |||
265 | parseCommaSeparatedListUntil(Token::greater, parseElt, | |||
266 | /*allowEmptyList=*/false)) { | |||
267 | return nullptr; | |||
268 | } | |||
269 | } | |||
270 | ||||
271 | if (isUnranked) | |||
272 | return getChecked<UnrankedMemRefType>(loc, elementType, memorySpace); | |||
273 | ||||
274 | return getChecked<MemRefType>(loc, dimensions, elementType, layout, | |||
275 | memorySpace); | |||
276 | } | |||
277 | ||||
278 | /// Parse any type except the function type. | |||
279 | /// | |||
280 | /// non-function-type ::= integer-type | |||
281 | /// | index-type | |||
282 | /// | float-type | |||
283 | /// | extended-type | |||
284 | /// | vector-type | |||
285 | /// | tensor-type | |||
286 | /// | memref-type | |||
287 | /// | complex-type | |||
288 | /// | tuple-type | |||
289 | /// | none-type | |||
290 | /// | |||
291 | /// index-type ::= `index` | |||
292 | /// float-type ::= `f16` | `bf16` | `f32` | `f64` | `f80` | `f128` | |||
293 | /// none-type ::= `none` | |||
294 | /// | |||
295 | Type Parser::parseNonFunctionType() { | |||
296 | switch (getToken().getKind()) { | |||
297 | default: | |||
298 | return (emitWrongTokenError("expected non-function type"), nullptr); | |||
299 | case Token::kw_memref: | |||
300 | return parseMemRefType(); | |||
301 | case Token::kw_tensor: | |||
302 | return parseTensorType(); | |||
303 | case Token::kw_complex: | |||
304 | return parseComplexType(); | |||
305 | case Token::kw_tuple: | |||
306 | return parseTupleType(); | |||
307 | case Token::kw_vector: | |||
308 | return parseVectorType(); | |||
309 | // integer-type | |||
310 | case Token::inttype: { | |||
311 | auto width = getToken().getIntTypeBitwidth(); | |||
312 | if (!width.has_value()) | |||
313 | return (emitError("invalid integer width"), nullptr); | |||
314 | if (width.value() > IntegerType::kMaxWidth) { | |||
315 | emitError(getToken().getLoc(), "integer bitwidth is limited to ") | |||
316 | << IntegerType::kMaxWidth << " bits"; | |||
317 | return nullptr; | |||
318 | } | |||
319 | ||||
320 | IntegerType::SignednessSemantics signSemantics = IntegerType::Signless; | |||
321 | if (Optional<bool> signedness = getToken().getIntTypeSignedness()) | |||
322 | signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned; | |||
323 | ||||
324 | consumeToken(Token::inttype); | |||
325 | return IntegerType::get(getContext(), *width, signSemantics); | |||
326 | } | |||
327 | ||||
328 | // float-type | |||
329 | case Token::kw_bf16: | |||
330 | consumeToken(Token::kw_bf16); | |||
331 | return builder.getBF16Type(); | |||
332 | case Token::kw_f16: | |||
333 | consumeToken(Token::kw_f16); | |||
334 | return builder.getF16Type(); | |||
335 | case Token::kw_f32: | |||
336 | consumeToken(Token::kw_f32); | |||
337 | return builder.getF32Type(); | |||
338 | case Token::kw_f64: | |||
339 | consumeToken(Token::kw_f64); | |||
340 | return builder.getF64Type(); | |||
341 | case Token::kw_f80: | |||
342 | consumeToken(Token::kw_f80); | |||
343 | return builder.getF80Type(); | |||
344 | case Token::kw_f128: | |||
345 | consumeToken(Token::kw_f128); | |||
346 | return builder.getF128Type(); | |||
347 | ||||
348 | // index-type | |||
349 | case Token::kw_index: | |||
350 | consumeToken(Token::kw_index); | |||
351 | return builder.getIndexType(); | |||
352 | ||||
353 | // none-type | |||
354 | case Token::kw_none: | |||
355 | consumeToken(Token::kw_none); | |||
356 | return builder.getNoneType(); | |||
357 | ||||
358 | // extended type | |||
359 | case Token::exclamation_identifier: | |||
360 | return parseExtendedType(); | |||
361 | ||||
362 | // Handle completion of a dialect type. | |||
363 | case Token::code_complete: | |||
364 | if (getToken().isCodeCompletionFor(Token::exclamation_identifier)) | |||
365 | return parseExtendedType(); | |||
366 | return codeCompleteType(); | |||
367 | } | |||
368 | } | |||
369 | ||||
370 | /// Parse a tensor type. | |||
371 | /// | |||
372 | /// tensor-type ::= `tensor` `<` dimension-list type `>` | |||
373 | /// dimension-list ::= dimension-list-ranked | `*x` | |||
374 | /// | |||
375 | Type Parser::parseTensorType() { | |||
376 | consumeToken(Token::kw_tensor); | |||
377 | ||||
378 | if (parseToken(Token::less, "expected '<' in tensor type")) | |||
379 | return nullptr; | |||
380 | ||||
381 | bool isUnranked; | |||
382 | SmallVector<int64_t, 4> dimensions; | |||
383 | ||||
384 | if (consumeIf(Token::star)) { | |||
385 | // This is an unranked tensor type. | |||
386 | isUnranked = true; | |||
387 | ||||
388 | if (parseXInDimensionList()) | |||
389 | return nullptr; | |||
390 | ||||
391 | } else { | |||
392 | isUnranked = false; | |||
393 | if (parseDimensionListRanked(dimensions)) | |||
394 | return nullptr; | |||
395 | } | |||
396 | ||||
397 | // Parse the element type. | |||
398 | auto elementTypeLoc = getToken().getLoc(); | |||
399 | auto elementType = parseType(); | |||
400 | ||||
401 | // Parse an optional encoding attribute. | |||
402 | Attribute encoding; | |||
403 | if (consumeIf(Token::comma)) { | |||
404 | encoding = parseAttribute(); | |||
405 | if (auto v = encoding.dyn_cast_or_null<VerifiableTensorEncoding>()) { | |||
406 | if (failed(v.verifyEncoding(dimensions, elementType, | |||
407 | [&] { return emitError(); }))) | |||
408 | return nullptr; | |||
409 | } | |||
410 | } | |||
411 | ||||
412 | if (!elementType || parseToken(Token::greater, "expected '>' in tensor type")) | |||
413 | return nullptr; | |||
414 | if (!TensorType::isValidElementType(elementType)) | |||
415 | return emitError(elementTypeLoc, "invalid tensor element type"), nullptr; | |||
416 | ||||
417 | if (isUnranked) { | |||
418 | if (encoding) | |||
419 | return emitError("cannot apply encoding to unranked tensor"), nullptr; | |||
420 | return UnrankedTensorType::get(elementType); | |||
421 | } | |||
422 | return RankedTensorType::get(dimensions, elementType, encoding); | |||
423 | } | |||
424 | ||||
425 | /// Parse a tuple type. | |||
426 | /// | |||
427 | /// tuple-type ::= `tuple` `<` (type (`,` type)*)? `>` | |||
428 | /// | |||
429 | Type Parser::parseTupleType() { | |||
430 | consumeToken(Token::kw_tuple); | |||
431 | ||||
432 | // Parse the '<'. | |||
433 | if (parseToken(Token::less, "expected '<' in tuple type")) | |||
434 | return nullptr; | |||
435 | ||||
436 | // Check for an empty tuple by directly parsing '>'. | |||
437 | if (consumeIf(Token::greater)) | |||
438 | return TupleType::get(getContext()); | |||
439 | ||||
440 | // Parse the element types and the '>'. | |||
441 | SmallVector<Type, 4> types; | |||
442 | if (parseTypeListNoParens(types) || | |||
443 | parseToken(Token::greater, "expected '>' in tuple type")) | |||
444 | return nullptr; | |||
445 | ||||
446 | return TupleType::get(getContext(), types); | |||
447 | } | |||
448 | ||||
449 | /// Parse a vector type. | |||
450 | /// | |||
451 | /// vector-type ::= `vector` `<` vector-dim-list vector-element-type `>` | |||
452 | /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? | |||
453 | /// static-dim-list ::= decimal-literal (`x` decimal-literal)* | |||
454 | /// | |||
455 | VectorType Parser::parseVectorType() { | |||
456 | consumeToken(Token::kw_vector); | |||
457 | ||||
458 | if (parseToken(Token::less, "expected '<' in vector type")) | |||
459 | return nullptr; | |||
460 | ||||
461 | SmallVector<int64_t, 4> dimensions; | |||
462 | unsigned numScalableDims; | |||
463 | if (parseVectorDimensionList(dimensions, numScalableDims)) | |||
464 | return nullptr; | |||
465 | if (any_of(dimensions, [](int64_t i) { return i <= 0; })) | |||
466 | return emitError(getToken().getLoc(), | |||
467 | "vector types must have positive constant sizes"), | |||
468 | nullptr; | |||
469 | ||||
470 | // Parse the element type. | |||
471 | auto typeLoc = getToken().getLoc(); | |||
472 | auto elementType = parseType(); | |||
473 | if (!elementType || parseToken(Token::greater, "expected '>' in vector type")) | |||
474 | return nullptr; | |||
475 | ||||
476 | if (!VectorType::isValidElementType(elementType)) | |||
477 | return emitError(typeLoc, "vector elements must be int/index/float type"), | |||
478 | nullptr; | |||
479 | ||||
480 | return VectorType::get(dimensions, elementType, numScalableDims); | |||
481 | } | |||
482 | ||||
483 | /// Parse a dimension list in a vector type. This populates the dimension list, | |||
484 | /// and returns the number of scalable dimensions in `numScalableDims`. | |||
485 | /// | |||
486 | /// vector-dim-list := (static-dim-list `x`)? (`[` static-dim-list `]` `x`)? | |||
487 | /// static-dim-list ::= decimal-literal (`x` decimal-literal)* | |||
488 | /// | |||
489 | ParseResult | |||
490 | Parser::parseVectorDimensionList(SmallVectorImpl<int64_t> &dimensions, | |||
491 | unsigned &numScalableDims) { | |||
492 | numScalableDims = 0; | |||
493 | // If there is a set of fixed-length dimensions, consume it | |||
494 | while (getToken().is(Token::integer)) { | |||
495 | int64_t value; | |||
496 | if (parseIntegerInDimensionList(value)) | |||
497 | return failure(); | |||
498 | dimensions.push_back(value); | |||
499 | // Make sure we have an 'x' or something like 'xbf32'. | |||
500 | if (parseXInDimensionList()) | |||
501 | return failure(); | |||
502 | } | |||
503 | // If there is a set of scalable dimensions, consume it | |||
504 | if (consumeIf(Token::l_square)) { | |||
505 | while (getToken().is(Token::integer)) { | |||
506 | int64_t value; | |||
507 | if (parseIntegerInDimensionList(value)) | |||
508 | return failure(); | |||
509 | dimensions.push_back(value); | |||
510 | numScalableDims++; | |||
511 | // Check if we have reached the end of the scalable dimension list | |||
512 | if (consumeIf(Token::r_square)) { | |||
513 | // Make sure we have something like 'xbf32'. | |||
514 | return parseXInDimensionList(); | |||
515 | } | |||
516 | // Make sure we have an 'x' | |||
517 | if (parseXInDimensionList()) | |||
518 | return failure(); | |||
519 | } | |||
520 | // If we make it here, we've finished parsing the dimension list | |||
521 | // without finding ']' closing the set of scalable dimensions | |||
522 | return emitWrongTokenError( | |||
523 | "missing ']' closing set of scalable dimensions"); | |||
524 | } | |||
525 | ||||
526 | return success(); | |||
527 | } | |||
528 | ||||
529 | /// Parse a dimension list of a tensor or memref type. This populates the | |||
530 | /// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and | |||
531 | /// errors out on `?` otherwise. Parsing the trailing `x` is configurable. | |||
532 | /// | |||
533 | /// dimension-list ::= eps | dimension (`x` dimension)* | |||
534 | /// dimension-list-with-trailing-x ::= (dimension `x`)* | |||
535 | /// dimension ::= `?` | decimal-literal | |||
536 | /// | |||
537 | /// When `allowDynamic` is not set, this is used to parse: | |||
538 | /// | |||
539 | /// static-dimension-list ::= eps | decimal-literal (`x` decimal-literal)* | |||
540 | /// static-dimension-list-with-trailing-x ::= (dimension `x`)* | |||
541 | ParseResult | |||
542 | Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions, | |||
543 | bool allowDynamic, bool withTrailingX) { | |||
544 | auto parseDim = [&]() -> LogicalResult { | |||
545 | auto loc = getToken().getLoc(); | |||
546 | if (consumeIf(Token::question)) { | |||
547 | if (!allowDynamic) | |||
548 | return emitError(loc, "expected static shape"); | |||
549 | dimensions.push_back(-1); | |||
550 | } else { | |||
551 | int64_t value; | |||
552 | if (failed(parseIntegerInDimensionList(value))) | |||
553 | return failure(); | |||
554 | dimensions.push_back(value); | |||
555 | } | |||
556 | return success(); | |||
557 | }; | |||
558 | ||||
559 | if (withTrailingX) { | |||
560 | while (getToken().isAny(Token::integer, Token::question)) { | |||
561 | if (failed(parseDim()) || failed(parseXInDimensionList())) | |||
562 | return failure(); | |||
563 | } | |||
564 | return success(); | |||
565 | } | |||
566 | ||||
567 | if (getToken().isAny(Token::integer, Token::question)) { | |||
568 | if (failed(parseDim())) | |||
569 | return failure(); | |||
570 | while (getToken().is(Token::bare_identifier) && | |||
571 | getTokenSpelling()[0] == 'x') { | |||
572 | if (failed(parseXInDimensionList()) || failed(parseDim())) | |||
573 | return failure(); | |||
574 | } | |||
575 | } | |||
576 | return success(); | |||
577 | } | |||
578 | ||||
579 | ParseResult Parser::parseIntegerInDimensionList(int64_t &value) { | |||
580 | // Hexadecimal integer literals (starting with `0x`) are not allowed in | |||
581 | // aggregate type declarations. Therefore, `0xf32` should be processed as | |||
582 | // a sequence of separate elements `0`, `x`, `f32`. | |||
583 | if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') { | |||
584 | // We can get here only if the token is an integer literal. Hexadecimal | |||
585 | // integer literals can only start with `0x` (`1x` wouldn't lex as a | |||
586 | // literal, just `1` would, at which point we don't get into this | |||
587 | // branch). | |||
588 | assert(getTokenSpelling()[0] == '0' && "invalid integer literal")(static_cast <bool> (getTokenSpelling()[0] == '0' && "invalid integer literal") ? void (0) : __assert_fail ("getTokenSpelling()[0] == '0' && \"invalid integer literal\"" , "mlir/lib/AsmParser/TypeParser.cpp", 588, __extension__ __PRETTY_FUNCTION__ )); | |||
589 | value = 0; | |||
590 | state.lex.resetPointer(getTokenSpelling().data() + 1); | |||
591 | consumeToken(); | |||
592 | } else { | |||
593 | // Make sure this integer value is in bound and valid. | |||
594 | Optional<uint64_t> dimension = getToken().getUInt64IntegerValue(); | |||
595 | if (!dimension || | |||
596 | *dimension > (uint64_t)std::numeric_limits<int64_t>::max()) | |||
597 | return emitError("invalid dimension"); | |||
598 | value = (int64_t)*dimension; | |||
599 | consumeToken(Token::integer); | |||
600 | } | |||
601 | return success(); | |||
602 | } | |||
603 | ||||
604 | /// Parse an 'x' token in a dimension list, handling the case where the x is | |||
605 | /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next | |||
606 | /// token. | |||
607 | ParseResult Parser::parseXInDimensionList() { | |||
608 | if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x') | |||
609 | return emitWrongTokenError("expected 'x' in dimension list"); | |||
610 | ||||
611 | // If we had a prefix of 'x', lex the next token immediately after the 'x'. | |||
612 | if (getTokenSpelling().size() != 1) | |||
613 | state.lex.resetPointer(getTokenSpelling().data() + 1); | |||
614 | ||||
615 | // Consume the 'x'. | |||
616 | consumeToken(Token::bare_identifier); | |||
617 | ||||
618 | return success(); | |||
619 | } | |||
620 | ||||
621 | // Parse a comma-separated list of dimensions, possibly empty: | |||
622 | // stride-list ::= `[` (dimension (`,` dimension)*)? `]` | |||
623 | ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) { | |||
624 | return parseCommaSeparatedList( | |||
625 | Delimiter::Square, | |||
626 | [&]() -> ParseResult { | |||
627 | if (consumeIf(Token::question)) { | |||
628 | dimensions.push_back(MemRefType::getDynamicStrideOrOffset()); | |||
629 | } else { | |||
630 | // This must be an integer value. | |||
631 | int64_t val; | |||
632 | if (getToken().getSpelling().getAsInteger(10, val)) | |||
633 | return emitError("invalid integer value: ") | |||
634 | << getToken().getSpelling(); | |||
635 | // Make sure it is not the one value for `?`. | |||
636 | if (ShapedType::isDynamic(val)) | |||
637 | return emitError("invalid integer value: ") | |||
638 | << getToken().getSpelling() | |||
639 | << ", use `?` to specify a dynamic dimension"; | |||
640 | ||||
641 | if (val == 0) | |||
642 | return emitError("invalid memref stride"); | |||
643 | ||||
644 | dimensions.push_back(val); | |||
645 | consumeToken(Token::integer); | |||
646 | } | |||
647 | return success(); | |||
648 | }, | |||
649 | " in stride list"); | |||
650 | } |