File: | build/source/mlir/include/mlir/IR/OpDefinition.h |
Warning: | line 114, column 5 Called C++ object pointer is null |
1 | //===- SuperVectorize.cpp - Vectorize Pass Impl ---------------------------===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | // | ||||
9 | // This file implements vectorization of loops, operations and data types to | ||||
10 | // a target-independent, n-D super-vector abstraction. | ||||
11 | // | ||||
12 | //===----------------------------------------------------------------------===// | ||||
13 | |||||
14 | #include "mlir/Dialect/Affine/Passes.h" | ||||
15 | |||||
16 | #include "mlir/Analysis/SliceAnalysis.h" | ||||
17 | #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" | ||||
18 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" | ||||
19 | #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" | ||||
20 | #include "mlir/Dialect/Affine/IR/AffineOps.h" | ||||
21 | #include "mlir/Dialect/Affine/Utils.h" | ||||
22 | #include "mlir/Dialect/Arith/IR/Arith.h" | ||||
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" | ||||
24 | #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||||
25 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||||
26 | #include "mlir/IR/IRMapping.h" | ||||
27 | #include "mlir/Pass/Pass.h" | ||||
28 | #include "mlir/Support/LLVM.h" | ||||
29 | #include "llvm/ADT/STLExtras.h" | ||||
30 | #include "llvm/Support/Debug.h" | ||||
31 | #include <optional> | ||||
32 | |||||
33 | namespace mlir { | ||||
34 | namespace affine { | ||||
35 | #define GEN_PASS_DEF_AFFINEVECTORIZE | ||||
36 | #include "mlir/Dialect/Affine/Passes.h.inc" | ||||
37 | } // namespace affine | ||||
38 | } // namespace mlir | ||||
39 | |||||
40 | using namespace mlir; | ||||
41 | using namespace affine; | ||||
42 | using namespace vector; | ||||
43 | |||||
44 | /// | ||||
45 | /// Implements a high-level vectorization strategy on a Function. | ||||
46 | /// The abstraction used is that of super-vectors, which provide a single, | ||||
47 | /// compact, representation in the vector types, information that is expected | ||||
/// to reduce the impact of the phase ordering problem.
49 | /// | ||||
50 | /// Vector granularity: | ||||
51 | /// =================== | ||||
52 | /// This pass is designed to perform vectorization at a super-vector | ||||
53 | /// granularity. A super-vector is loosely defined as a vector type that is a | ||||
54 | /// multiple of a "good" vector size so the HW can efficiently implement a set | ||||
55 | /// of high-level primitives. Multiple is understood along any dimension; e.g. | ||||
56 | /// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a | ||||
57 | /// vector<8xf32> HW vector. Note that a "good vector size so the HW can | ||||
58 | /// efficiently implement a set of high-level primitives" is not necessarily an | ||||
59 | /// integer multiple of actual hardware registers. We leave details of this | ||||
60 | /// distinction unspecified for now. | ||||
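///
/// As a rough illustration only (the helper below is not part of this pass,
/// and the actual legality rules are more involved), the "multiple along every
/// dimension" relation can be sketched as an elementwise divisibility check
/// over trailing dimensions:
/// ```c++
/// #include <cstddef>
/// #include <cstdint>
/// #include <vector>
///
/// // A super-vector shape qualifies for a notional "good" HW vector shape
/// // when each of its trailing dimensions is a multiple of the corresponding
/// // HW dimension; leading dimensions are unconstrained.
/// static bool isSuperVectorShapeOf(const std::vector<std::int64_t> &super,
///                                  const std::vector<std::int64_t> &hw) {
///   if (hw.size() > super.size())
///     return false;
///   std::size_t offset = super.size() - hw.size();
///   for (std::size_t i = 0, e = hw.size(); i < e; ++i)
///     if (super[offset + i] % hw[i] != 0)
///       return false;
///   return true;
/// }
/// // isSuperVectorShapeOf({16}, {8}) and isSuperVectorShapeOf({2, 8}, {8})
/// // hold; isSuperVectorShapeOf({12}, {8}) does not.
/// ```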
61 | /// | ||||
/// Some may prefer the terminology of a "tile of HW vectors". In this case, one
63 | /// should note that super-vectors implement an "always full tile" abstraction. | ||||
64 | /// They guarantee no partial-tile separation is necessary by relying on a | ||||
65 | /// high-level copy-reshape abstraction that we call vector.transfer. This | ||||
/// copy-reshape operation is also responsible for performing layout
67 | /// transposition if necessary. In the general case this will require a scoped | ||||
68 | /// allocation in some notional local memory. | ||||
69 | /// | ||||
70 | /// Whatever the mental model one prefers to use for this abstraction, the key | ||||
71 | /// point is that we burn into a single, compact, representation in the vector | ||||
72 | /// types, information that is expected to reduce the impact of the phase | ||||
73 | /// ordering problem. Indeed, a vector type conveys information that: | ||||
74 | /// 1. the associated loops have dependency semantics that do not prevent | ||||
75 | /// vectorization; | ||||
/// 2. the associated loops have been sliced in chunks of static sizes that are
77 | /// compatible with vector sizes (i.e. similar to unroll-and-jam); | ||||
/// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by the
/// vector type and no vectorization hampering transformations can be
/// applied to them anymore;
82 | /// 4. the underlying memrefs are accessed in some notional contiguous way | ||||
83 | /// that allows loading into vectors with some amount of spatial locality; | ||||
84 | /// In other words, super-vectorization provides a level of separation of | ||||
/// concerns by way of opacity to subsequent passes. This has the effect of
86 | /// encapsulating and propagating vectorization constraints down the list of | ||||
87 | /// passes until we are ready to lower further. | ||||
88 | /// | ||||
89 | /// For a particular target, a notion of minimal n-d vector size will be | ||||
90 | /// specified and vectorization targets a multiple of those. In the following | ||||
91 | /// paragraph, let "k ." represent "a multiple of", to be understood as a | ||||
92 | /// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes | ||||
93 | /// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc). | ||||
94 | /// | ||||
95 | /// Some non-exhaustive notable super-vector sizes of interest include: | ||||
96 | /// - CPU: vector<k . HW_vector_size>, | ||||
97 | /// vector<k' . core_count x k . HW_vector_size>, | ||||
98 | /// vector<socket_count x k' . core_count x k . HW_vector_size>; | ||||
99 | /// - GPU: vector<k . warp_size>, | ||||
100 | /// vector<k . warp_size x float2>, | ||||
101 | /// vector<k . warp_size x float4>, | ||||
/// vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
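///
/// As a purely illustrative instantiation of this notation (the concrete
/// values below are assumptions, not recommendations), take
/// HW_vector_size = 8 (a 256-bit HW vector of f32), k = 16, core_count = 4
/// and k' = 2:
/// ```c++
/// // Worked arithmetic for two of the CPU shapes listed above.
/// constexpr int kHwVectorSize = 8, kK = 16, kKPrime = 2, kCoreCount = 4;
/// // vector<k . HW_vector_size>                     ==> vector<128xf32>
/// constexpr int kDim1 = kK * kHwVectorSize;
/// // vector<k' . core_count x k . HW_vector_size>   ==> vector<8x128xf32>
/// constexpr int kDim0 = kKPrime * kCoreCount;
/// static_assert(kDim1 == 128 && kDim0 == 8, "worked example");
/// ```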
103 | /// | ||||
104 | /// Loops and operations are emitted that operate on those super-vector shapes. | ||||
105 | /// Subsequent lowering passes will materialize to actual HW vector sizes. These | ||||
106 | /// passes are expected to be (gradually) more target-specific. | ||||
107 | /// | ||||
108 | /// At a high level, a vectorized load in a loop will resemble: | ||||
109 | /// ```mlir | ||||
110 | /// affine.for %i = ? to ? step ? { | ||||
111 | /// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32> | ||||
112 | /// } | ||||
113 | /// ``` | ||||
114 | /// It is the responsibility of the implementation of vector.transfer_read to | ||||
115 | /// materialize vector registers from the original scalar memrefs. A later (more | ||||
116 | /// target-dependent) lowering pass will materialize to actual HW vector sizes. | ||||
/// This lowering may occur at different times:
118 | /// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp + | ||||
119 | /// DmaWaitOp + vectorized operations for data transformations and shuffle; | ||||
120 | /// thus opening opportunities for unrolling and pipelining. This is an | ||||
121 | /// instance of library call "whiteboxing"; or | ||||
/// 2. later in a target-specific lowering pass or hand-written library
/// call, achieving full separation of concerns. This is an instance of
/// an opaque library call; or
125 | /// 3. a mix of both, e.g. based on a model. | ||||
126 | /// In the future, these operations will expose a contract to constrain the | ||||
127 | /// search on vectorization patterns and sizes. | ||||
128 | /// | ||||
129 | /// Occurrence of super-vectorization in the compiler flow: | ||||
130 | /// ======================================================= | ||||
131 | /// This is an active area of investigation. We start with 2 remarks to position | ||||
132 | /// super-vectorization in the context of existing ongoing work: LLVM VPLAN | ||||
133 | /// and LLVM SLP Vectorizer. | ||||
134 | /// | ||||
135 | /// LLVM VPLAN: | ||||
136 | /// ----------- | ||||
/// The astute reader may have noticed that, in the limit, super-vectorization
/// can be applied at a similar time and with similar objectives to VPLAN: for
/// instance, at the end of a traditional polyhedral compilation flow (e.g. the
/// PPCG project uses ISL to provide dependence analysis, multi-level
/// scheduling + tiling, lifting of footprints to fast memory, communication
/// synthesis, mapping and register optimizations) and just before unrolling.
/// When vectorization is applied at this *late* level in a typical
144 | /// polyhedral flow, and is instantiated with actual hardware vector sizes, | ||||
145 | /// super-vectorization is expected to match (or subsume) the type of patterns | ||||
146 | /// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR | ||||
147 | /// is higher level and our implementation should be significantly simpler. Also | ||||
148 | /// note that in this mode, recursive patterns are probably a bit of an overkill | ||||
149 | /// although it is reasonable to expect that mixing a bit of outer loop and | ||||
150 | /// inner loop vectorization + unrolling will provide interesting choices to | ||||
151 | /// MLIR. | ||||
152 | /// | ||||
153 | /// LLVM SLP Vectorizer: | ||||
154 | /// -------------------- | ||||
155 | /// Super-vectorization however is not meant to be usable in a similar fashion | ||||
156 | /// to the SLP vectorizer. The main difference lies in the information that | ||||
157 | /// both vectorizers use: super-vectorization examines contiguity of memory | ||||
158 | /// references along fastest varying dimensions and loops with recursive nested | ||||
159 | /// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on | ||||
160 | /// the other hand, performs flat pattern matching inside a single unrolled loop | ||||
161 | /// body and stitches together pieces of load and store operations into full | ||||
162 | /// 1-D vectors. We envision that the SLP vectorizer is a good way to capture | ||||
163 | /// innermost loop, control-flow dependent patterns that super-vectorization may | ||||
164 | /// not be able to capture easily. In other words, super-vectorization does not | ||||
165 | /// aim at replacing the SLP vectorizer and the two solutions are complementary. | ||||
166 | /// | ||||
167 | /// Ongoing investigations: | ||||
168 | /// ----------------------- | ||||
169 | /// We discuss the following *early* places where super-vectorization is | ||||
/// applicable and touch on the expected benefits and risks. We list the
171 | /// opportunities in the context of the traditional polyhedral compiler flow | ||||
172 | /// described in PPCG. There are essentially 6 places in the MLIR pass pipeline | ||||
173 | /// we expect to experiment with super-vectorization: | ||||
174 | /// 1. Right after language lowering to MLIR: this is the earliest time where | ||||
175 | /// super-vectorization is expected to be applied. At this level, all the | ||||
176 | /// language/user/library-level annotations are available and can be fully | ||||
177 | /// exploited. Examples include loop-type annotations (such as parallel, | ||||
178 | /// reduction, scan, dependence distance vector, vectorizable) as well as | ||||
/// memory access annotations (such as non-aliasing writes guaranteed,
/// indirect accesses that are permutations by construction), or the fact
/// that a particular operation is prescribed atomic by the user. At this
182 | /// level, anything that enriches what dependence analysis can do should be | ||||
183 | /// aggressively exploited. At this level we are close to having explicit | ||||
184 | /// vector types in the language, except we do not impose that burden on the | ||||
185 | /// programmer/library: we derive information from scalar code + annotations. | ||||
186 | /// 2. After dependence analysis and before polyhedral scheduling: the | ||||
187 | /// information that supports vectorization does not need to be supplied by a | ||||
188 | /// higher level of abstraction. Traditional dependence analysis is available | ||||
189 | /// in MLIR and will be used to drive vectorization and cost models. | ||||
190 | /// | ||||
191 | /// Let's pause here and remark that applying super-vectorization as described | ||||
192 | /// in 1. and 2. presents clear opportunities and risks: | ||||
193 | /// - the opportunity is that vectorization is burned in the type system and | ||||
194 | /// is protected from the adverse effect of loop scheduling, tiling, loop | ||||
195 | /// interchange and all passes downstream. Provided that subsequent passes are | ||||
/// able to operate on vector types, the vector shapes, associated loop
197 | /// iterator properties, alignment, and contiguity of fastest varying | ||||
198 | /// dimensions are preserved until we lower the super-vector types. We expect | ||||
/// this to significantly rein in the adverse effects of phase ordering.
200 | /// - the risks are that a. all passes after super-vectorization have to work | ||||
/// on elemental vector types (note that this is always true, wherever
202 | /// vectorization is applied) and b. that imposing vectorization constraints | ||||
203 | /// too early may be overall detrimental to loop fusion, tiling and other | ||||
204 | /// transformations because the dependence distances are coarsened when | ||||
205 | /// operating on elemental vector types. For this reason, the pattern | ||||
206 | /// profitability analysis should include a component that also captures the | ||||
207 | /// maximal amount of fusion available under a particular pattern. This is | ||||
208 | /// still at the stage of rough ideas but in this context, search is our | ||||
209 | /// friend as the Tensor Comprehensions and auto-TVM contributions | ||||
210 | /// demonstrated previously. | ||||
211 | /// Bottom-line is we do not yet have good answers for the above but aim at | ||||
212 | /// making it easy to answer such questions. | ||||
213 | /// | ||||
214 | /// Back to our listing, the last places where early super-vectorization makes | ||||
215 | /// sense are: | ||||
216 | /// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known | ||||
217 | /// to improve locality, parallelism and be configurable (e.g. max-fuse, | ||||
218 | /// smart-fuse etc). They can also have adverse effects on contiguity | ||||
219 | /// properties that are required for vectorization but the vector.transfer | ||||
220 | /// copy-reshape-pad-transpose abstraction is expected to help recapture | ||||
221 | /// these properties. | ||||
222 | /// 4. right after polyhedral-style scheduling+tiling; | ||||
223 | /// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent | ||||
224 | /// probably the most promising places because applying tiling achieves a | ||||
225 | /// separation of concerns that allows rescheduling to worry less about | ||||
226 | /// locality and more about parallelism and distribution (e.g. min-fuse). | ||||
227 | /// | ||||
228 | /// At these levels the risk-reward looks different: on one hand we probably | ||||
229 | /// lost a good deal of language/user/library-level annotation; on the other | ||||
230 | /// hand we gained parallelism and locality through scheduling and tiling. | ||||
231 | /// However we probably want to ensure tiling is compatible with the | ||||
232 | /// full-tile-only abstraction used in super-vectorization or suffer the | ||||
233 | /// consequences. It is too early to place bets on what will win but we expect | ||||
234 | /// super-vectorization to be the right abstraction to allow exploring at all | ||||
235 | /// these levels. And again, search is our friend. | ||||
236 | /// | ||||
237 | /// Lastly, we mention it again here: | ||||
238 | /// 6. as a MLIR-based alternative to VPLAN. | ||||
239 | /// | ||||
240 | /// Lowering, unrolling, pipelining: | ||||
241 | /// ================================ | ||||
242 | /// TODO: point to the proper places. | ||||
243 | /// | ||||
244 | /// Algorithm: | ||||
245 | /// ========== | ||||
246 | /// The algorithm proceeds in a few steps: | ||||
247 | /// 1. defining super-vectorization patterns and matching them on the tree of | ||||
248 | /// AffineForOp. A super-vectorization pattern is defined as a recursive | ||||
/// data structure that matches and captures nested, imperfectly-nested
250 | /// loops that have a. conformable loop annotations attached (e.g. parallel, | ||||
251 | /// reduction, vectorizable, ...) as well as b. all contiguous load/store | ||||
252 | /// operations along a specified minor dimension (not necessarily the | ||||
/// fastest varying);
254 | /// 2. analyzing those patterns for profitability (TODO: and | ||||
255 | /// interference); | ||||
256 | /// 3. then, for each pattern in order: | ||||
257 | /// a. applying iterative rewriting of the loops and all their nested | ||||
258 | /// operations in topological order. Rewriting is implemented by | ||||
259 | /// coarsening the loops and converting operations and operands to their | ||||
260 | /// vector forms. Processing operations in topological order is relatively | ||||
261 | /// simple due to the structured nature of the control-flow | ||||
262 | /// representation. This order ensures that all the operands of a given | ||||
263 | /// operation have been vectorized before the operation itself in a single | ||||
264 | /// traversal, except for operands defined outside of the loop nest. The | ||||
265 | /// algorithm can convert the following operations to their vector form: | ||||
266 | /// * Affine load and store operations are converted to opaque vector | ||||
267 | /// transfer read and write operations. | ||||
268 | /// * Scalar constant operations/operands are converted to vector | ||||
269 | /// constant operations (splat). | ||||
270 | /// * Uniform operands (only induction variables of loops not mapped to | ||||
271 | /// a vector dimension, or operands defined outside of the loop nest | ||||
272 | /// for now) are broadcasted to a vector. | ||||
273 | /// TODO: Support more uniform cases. | ||||
274 | /// * Affine for operations with 'iter_args' are vectorized by | ||||
275 | /// vectorizing their 'iter_args' operands and results. | ||||
276 | /// TODO: Support more complex loops with divergent lbs and/or ubs. | ||||
277 | /// * The remaining operations in the loop nest are vectorized by | ||||
278 | /// widening their scalar types to vector types. | ||||
279 | /// b. if everything under the root AffineForOp in the current pattern | ||||
280 | /// is vectorized properly, we commit that loop to the IR and remove the | ||||
281 | /// scalar loop. Otherwise, we discard the vectorized loop and keep the | ||||
282 | /// original scalar loop. | ||||
283 | /// c. vectorization is applied on the next pattern in the list. Because | ||||
284 | /// pattern interference avoidance is not yet implemented and that we do | ||||
285 | /// not support further vectorizing an already vector load we need to | ||||
286 | /// re-verify that the pattern is still vectorizable. This is expected to | ||||
287 | /// make cost models more difficult to write and is subject to improvement | ||||
288 | /// in the future. | ||||
289 | /// | ||||
290 | /// Choice of loop transformation to support the algorithm: | ||||
291 | /// ======================================================= | ||||
292 | /// The choice of loop transformation to apply for coarsening vectorized loops | ||||
293 | /// is still subject to exploratory tradeoffs. In particular, say we want to | ||||
294 | /// vectorize by a factor 128, we want to transform the following input: | ||||
295 | /// ```mlir | ||||
296 | /// affine.for %i = %M to %N { | ||||
297 | /// %a = affine.load %A[%i] : memref<?xf32> | ||||
298 | /// } | ||||
299 | /// ``` | ||||
300 | /// | ||||
301 | /// Traditionally, one would vectorize late (after scheduling, tiling, | ||||
302 | /// memory promotion etc) say after stripmining (and potentially unrolling in | ||||
303 | /// the case of LLVM's SLP vectorizer): | ||||
304 | /// ```mlir | ||||
305 | /// affine.for %i = floor(%M, 128) to ceil(%N, 128) { | ||||
306 | /// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { | ||||
307 | /// %a = affine.load %A[%ii] : memref<?xf32> | ||||
308 | /// } | ||||
309 | /// } | ||||
310 | /// ``` | ||||
311 | /// | ||||
312 | /// Instead, we seek to vectorize early and freeze vector types before | ||||
313 | /// scheduling, so we want to generate a pattern that resembles: | ||||
314 | /// ```mlir | ||||
315 | /// affine.for %i = ? to ? step ? { | ||||
316 | /// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> | ||||
317 | /// } | ||||
318 | /// ``` | ||||
///
/// Two alternatives come to mind for materializing the coarsening:
/// i. simply dividing the lower / upper bounds by 128 creates issues
321 | /// when representing expressions such as ii + 1 because now we only | ||||
322 | /// have access to original values that have been divided. Additional | ||||
323 | /// information is needed to specify accesses at below-128 granularity; | ||||
324 | /// ii. another alternative is to coarsen the loop step but this may have | ||||
325 | /// consequences on dependence analysis and fusability of loops: fusable | ||||
326 | /// loops probably need to have the same step (because we don't want to | ||||
327 | /// stripmine/unroll to enable fusion). | ||||
328 | /// As a consequence, we choose to represent the coarsening using the loop | ||||
329 | /// step for now and reevaluate in the future. Note that we can renormalize | ||||
330 | /// loop steps later if/when we have evidence that they are problematic. | ||||
331 | /// | ||||
332 | /// For the simple strawman example above, vectorizing for a 1-D vector | ||||
333 | /// abstraction of size 128 returns code similar to: | ||||
334 | /// ```mlir | ||||
335 | /// affine.for %i = %M to %N step 128 { | ||||
336 | /// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> | ||||
337 | /// } | ||||
338 | /// ``` | ||||
339 | /// | ||||
340 | /// Unsupported cases, extensions, and work in progress (help welcome :-) ): | ||||
341 | /// ======================================================================== | ||||
342 | /// 1. lowering to concrete vector types for various HW; | ||||
343 | /// 2. reduction support for n-D vectorization and non-unit steps; | ||||
344 | /// 3. non-effecting padding during vector.transfer_read and filter during | ||||
345 | /// vector.transfer_write; | ||||
/// 4. misalignment support in vector.transfer_read / vector.transfer_write
347 | /// (hopefully without read-modify-writes); | ||||
348 | /// 5. control-flow support; | ||||
349 | /// 6. cost-models, heuristics and search; | ||||
350 | /// 7. Op implementation, extensions and implication on memref views; | ||||
351 | /// 8. many TODOs left around. | ||||
352 | /// | ||||
353 | /// Examples: | ||||
354 | /// ========= | ||||
355 | /// Consider the following Function: | ||||
356 | /// ```mlir | ||||
357 | /// func @vector_add_2d(%M : index, %N : index) -> f32 { | ||||
358 | /// %A = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
359 | /// %B = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
360 | /// %C = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
361 | /// %f1 = arith.constant 1.0 : f32 | ||||
362 | /// %f2 = arith.constant 2.0 : f32 | ||||
363 | /// affine.for %i0 = 0 to %M { | ||||
364 | /// affine.for %i1 = 0 to %N { | ||||
365 | /// // non-scoped %f1 | ||||
366 | /// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0> | ||||
367 | /// } | ||||
368 | /// } | ||||
369 | /// affine.for %i2 = 0 to %M { | ||||
370 | /// affine.for %i3 = 0 to %N { | ||||
371 | /// // non-scoped %f2 | ||||
372 | /// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0> | ||||
373 | /// } | ||||
374 | /// } | ||||
375 | /// affine.for %i4 = 0 to %M { | ||||
376 | /// affine.for %i5 = 0 to %N { | ||||
377 | /// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0> | ||||
378 | /// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0> | ||||
379 | /// %s5 = arith.addf %a5, %b5 : f32 | ||||
380 | /// // non-scoped %f1 | ||||
381 | /// %s6 = arith.addf %s5, %f1 : f32 | ||||
382 | /// // non-scoped %f2 | ||||
383 | /// %s7 = arith.addf %s5, %f2 : f32 | ||||
384 | /// // diamond dependency. | ||||
385 | /// %s8 = arith.addf %s7, %s6 : f32 | ||||
386 | /// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0> | ||||
387 | /// } | ||||
388 | /// } | ||||
389 | /// %c7 = arith.constant 7 : index | ||||
390 | /// %c42 = arith.constant 42 : index | ||||
391 | /// %res = load %C[%c7, %c42] : memref<?x?xf32, 0> | ||||
392 | /// return %res : f32 | ||||
393 | /// } | ||||
394 | /// ``` | ||||
395 | /// | ||||
396 | /// The -affine-super-vectorize pass with the following arguments: | ||||
397 | /// ``` | ||||
398 | /// -affine-super-vectorize="virtual-vector-size=256 test-fastest-varying=0" | ||||
399 | /// ``` | ||||
400 | /// | ||||
401 | /// produces this standard innermost-loop vectorized code: | ||||
402 | /// ```mlir | ||||
403 | /// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { | ||||
404 | /// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
405 | /// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
406 | /// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
407 | /// %cst = arith.constant 1.0 : f32 | ||||
408 | /// %cst_0 = arith.constant 2.0 : f32 | ||||
409 | /// affine.for %i0 = 0 to %arg0 { | ||||
410 | /// affine.for %i1 = 0 to %arg1 step 256 { | ||||
411 | /// %cst_1 = arith.constant dense<vector<256xf32>, 1.0> : | ||||
412 | /// vector<256xf32> | ||||
413 | /// vector.transfer_write %cst_1, %0[%i0, %i1] : | ||||
414 | /// vector<256xf32>, memref<?x?xf32> | ||||
415 | /// } | ||||
416 | /// } | ||||
417 | /// affine.for %i2 = 0 to %arg0 { | ||||
418 | /// affine.for %i3 = 0 to %arg1 step 256 { | ||||
419 | /// %cst_2 = arith.constant dense<vector<256xf32>, 2.0> : | ||||
420 | /// vector<256xf32> | ||||
421 | /// vector.transfer_write %cst_2, %1[%i2, %i3] : | ||||
422 | /// vector<256xf32>, memref<?x?xf32> | ||||
423 | /// } | ||||
424 | /// } | ||||
425 | /// affine.for %i4 = 0 to %arg0 { | ||||
426 | /// affine.for %i5 = 0 to %arg1 step 256 { | ||||
427 | /// %3 = vector.transfer_read %0[%i4, %i5] : | ||||
428 | /// memref<?x?xf32>, vector<256xf32> | ||||
429 | /// %4 = vector.transfer_read %1[%i4, %i5] : | ||||
430 | /// memref<?x?xf32>, vector<256xf32> | ||||
431 | /// %5 = arith.addf %3, %4 : vector<256xf32> | ||||
432 | /// %cst_3 = arith.constant dense<vector<256xf32>, 1.0> : | ||||
433 | /// vector<256xf32> | ||||
434 | /// %6 = arith.addf %5, %cst_3 : vector<256xf32> | ||||
435 | /// %cst_4 = arith.constant dense<vector<256xf32>, 2.0> : | ||||
436 | /// vector<256xf32> | ||||
437 | /// %7 = arith.addf %5, %cst_4 : vector<256xf32> | ||||
438 | /// %8 = arith.addf %7, %6 : vector<256xf32> | ||||
439 | /// vector.transfer_write %8, %2[%i4, %i5] : | ||||
440 | /// vector<256xf32>, memref<?x?xf32> | ||||
441 | /// } | ||||
442 | /// } | ||||
443 | /// %c7 = arith.constant 7 : index | ||||
444 | /// %c42 = arith.constant 42 : index | ||||
445 | /// %9 = load %2[%c7, %c42] : memref<?x?xf32> | ||||
446 | /// return %9 : f32 | ||||
447 | /// } | ||||
448 | /// ``` | ||||
449 | /// | ||||
450 | /// The -affine-super-vectorize pass with the following arguments: | ||||
451 | /// ``` | ||||
452 | /// -affine-super-vectorize="virtual-vector-size=32,256 \ | ||||
453 | /// test-fastest-varying=1,0" | ||||
454 | /// ``` | ||||
455 | /// | ||||
456 | /// produces this more interesting mixed outer-innermost-loop vectorized code: | ||||
457 | /// ```mlir | ||||
458 | /// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { | ||||
459 | /// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
460 | /// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
461 | /// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
462 | /// %cst = arith.constant 1.0 : f32 | ||||
463 | /// %cst_0 = arith.constant 2.0 : f32 | ||||
464 | /// affine.for %i0 = 0 to %arg0 step 32 { | ||||
465 | /// affine.for %i1 = 0 to %arg1 step 256 { | ||||
466 | /// %cst_1 = arith.constant dense<vector<32x256xf32>, 1.0> : | ||||
467 | /// vector<32x256xf32> | ||||
468 | /// vector.transfer_write %cst_1, %0[%i0, %i1] : | ||||
469 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
470 | /// } | ||||
471 | /// } | ||||
472 | /// affine.for %i2 = 0 to %arg0 step 32 { | ||||
473 | /// affine.for %i3 = 0 to %arg1 step 256 { | ||||
474 | /// %cst_2 = arith.constant dense<vector<32x256xf32>, 2.0> : | ||||
475 | /// vector<32x256xf32> | ||||
476 | /// vector.transfer_write %cst_2, %1[%i2, %i3] : | ||||
477 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
478 | /// } | ||||
479 | /// } | ||||
480 | /// affine.for %i4 = 0 to %arg0 step 32 { | ||||
481 | /// affine.for %i5 = 0 to %arg1 step 256 { | ||||
482 | /// %3 = vector.transfer_read %0[%i4, %i5] : | ||||
/// memref<?x?xf32>, vector<32x256xf32>
484 | /// %4 = vector.transfer_read %1[%i4, %i5] : | ||||
485 | /// memref<?x?xf32>, vector<32x256xf32> | ||||
486 | /// %5 = arith.addf %3, %4 : vector<32x256xf32> | ||||
487 | /// %cst_3 = arith.constant dense<vector<32x256xf32>, 1.0> : | ||||
488 | /// vector<32x256xf32> | ||||
489 | /// %6 = arith.addf %5, %cst_3 : vector<32x256xf32> | ||||
490 | /// %cst_4 = arith.constant dense<vector<32x256xf32>, 2.0> : | ||||
491 | /// vector<32x256xf32> | ||||
492 | /// %7 = arith.addf %5, %cst_4 : vector<32x256xf32> | ||||
493 | /// %8 = arith.addf %7, %6 : vector<32x256xf32> | ||||
494 | /// vector.transfer_write %8, %2[%i4, %i5] : | ||||
495 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
496 | /// } | ||||
497 | /// } | ||||
498 | /// %c7 = arith.constant 7 : index | ||||
499 | /// %c42 = arith.constant 42 : index | ||||
500 | /// %9 = load %2[%c7, %c42] : memref<?x?xf32> | ||||
501 | /// return %9 : f32 | ||||
502 | /// } | ||||
503 | /// ``` | ||||
504 | /// | ||||
505 | /// Of course, much more intricate n-D imperfectly-nested patterns can be | ||||
506 | /// vectorized too and specified in a fully declarative fashion. | ||||
507 | /// | ||||
508 | /// Reduction: | ||||
509 | /// ========== | ||||
510 | /// Vectorizing reduction loops along the reduction dimension is supported if: | ||||
511 | /// - the reduction kind is supported, | ||||
512 | /// - the vectorization is 1-D, and | ||||
/// - the step size of the loop is equal to one.
514 | /// | ||||
/// Compared to the non-vector-dimension case, two additional things are done
516 | /// during vectorization of such loops: | ||||
517 | /// - The resulting vector returned from the loop is reduced to a scalar using | ||||
/// `vector.reduction`.
519 | /// - In some cases a mask is applied to the vector yielded at the end of the | ||||
520 | /// loop to prevent garbage values from being written to the accumulator. | ||||
521 | /// | ||||
/// Reduction vectorization is switched off by default; it can be enabled by
523 | /// passing a map from loops to reductions to utility functions, or by passing | ||||
524 | /// `vectorize-reductions=true` to the vectorization pass. | ||||
525 | /// | ||||
526 | /// Consider the following example: | ||||
527 | /// ```mlir | ||||
528 | /// func @vecred(%in: memref<512xf32>) -> f32 { | ||||
529 | /// %cst = arith.constant 0.000000e+00 : f32 | ||||
530 | /// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { | ||||
531 | /// %ld = affine.load %in[%i] : memref<512xf32> | ||||
532 | /// %cos = math.cos %ld : f32 | ||||
533 | /// %add = arith.addf %part_sum, %cos : f32 | ||||
534 | /// affine.yield %add : f32 | ||||
535 | /// } | ||||
536 | /// return %sum : f32 | ||||
537 | /// } | ||||
538 | /// ``` | ||||
539 | /// | ||||
540 | /// The -affine-super-vectorize pass with the following arguments: | ||||
541 | /// ``` | ||||
542 | /// -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 \ | ||||
543 | /// vectorize-reductions=true" | ||||
544 | /// ``` | ||||
545 | /// produces the following output: | ||||
546 | /// ```mlir | ||||
547 | /// #map = affine_map<(d0) -> (-d0 + 500)> | ||||
548 | /// func @vecred(%arg0: memref<512xf32>) -> f32 { | ||||
549 | /// %cst = arith.constant 0.000000e+00 : f32 | ||||
550 | /// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32> | ||||
551 | /// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) | ||||
552 | /// -> (vector<128xf32>) { | ||||
553 | /// // %2 is the number of iterations left in the original loop. | ||||
554 | /// %2 = affine.apply #map(%arg1) | ||||
555 | /// %3 = vector.create_mask %2 : vector<128xi1> | ||||
556 | /// %cst_1 = arith.constant 0.000000e+00 : f32 | ||||
557 | /// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : | ||||
558 | /// memref<512xf32>, vector<128xf32> | ||||
559 | /// %5 = math.cos %4 : vector<128xf32> | ||||
560 | /// %6 = arith.addf %arg2, %5 : vector<128xf32> | ||||
/// // We filter out the effect of the last 12 elements using the mask.
562 | /// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> | ||||
563 | /// affine.yield %7 : vector<128xf32> | ||||
564 | /// } | ||||
565 | /// %1 = vector.reduction <add>, %0 : vector<128xf32> into f32 | ||||
566 | /// return %1 : f32 | ||||
567 | /// } | ||||
568 | /// ``` | ||||
569 | /// | ||||
570 | /// Note that because of loop misalignment we needed to apply a mask to prevent | ||||
/// the last 12 elements from affecting the final result. The mask is full of ones
572 | /// in every iteration except for the last one, in which it has the form | ||||
573 | /// `11...100...0` with 116 ones and 12 zeros. | ||||
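///
/// The 116/12 split follows directly from the trip count and the vector width
/// used in the example (a worked computation, for illustration only):
/// ```c++
/// constexpr int kTripCount = 500, kVectorWidth = 128;
/// // The last vectorized iteration starts at the largest multiple of 128
/// // that is smaller than 500.
/// constexpr int kLastIterStart = (kTripCount / kVectorWidth) * kVectorWidth;
/// constexpr int kValidLanes = kTripCount - kLastIterStart;  // 116 ones
/// constexpr int kMaskedLanes = kVectorWidth - kValidLanes;  // 12 zeros
/// static_assert(kLastIterStart == 384 && kValidLanes == 116 &&
///               kMaskedLanes == 12, "mask arithmetic");
/// ```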
574 | |||||
#define DEBUG_TYPE "early-vect"
576 | |||||
577 | using llvm::dbgs; | ||||
578 | |||||
579 | /// Forward declaration. | ||||
580 | static FilterFunctionType | ||||
581 | isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, | ||||
582 | int fastestVaryingMemRefDimension); | ||||
583 | |||||
584 | /// Creates a vectorization pattern from the command line arguments. | ||||
585 | /// Up to 3-D patterns are supported. | ||||
586 | /// If the command line argument requests a pattern of higher order, returns an | ||||
587 | /// empty pattern list which will conservatively result in no vectorization. | ||||
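/// For instance (an illustration of the existing behavior, not a new API),
/// `virtual-vector-size=32,256 test-fastest-varying=1,0` yields
/// `vectorRank == 2` and `fastestVaryingPattern == {1, 0}`, i.e. a pattern
/// matching a 2-deep nest of parallel loops in which the outer loop's memory
/// accesses vary along the second-fastest-varying memref dimension and the
/// inner loop's accesses vary along the fastest-varying one.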
588 | static std::optional<NestedPattern> | ||||
589 | makePattern(const DenseSet<Operation *> ¶llelLoops, int vectorRank, | ||||
590 | ArrayRef<int64_t> fastestVaryingPattern) { | ||||
591 | using affine::matcher::For; | ||||
592 | int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0]; | ||||
593 | int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1]; | ||||
594 | int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2]; | ||||
595 | switch (vectorRank) { | ||||
596 | case 1: | ||||
597 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0)); | ||||
598 | case 2: | ||||
599 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0), | ||||
600 | For(isVectorizableLoopPtrFactory(parallelLoops, d1))); | ||||
601 | case 3: | ||||
602 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0), | ||||
603 | For(isVectorizableLoopPtrFactory(parallelLoops, d1), | ||||
604 | For(isVectorizableLoopPtrFactory(parallelLoops, d2)))); | ||||
605 | default: { | ||||
606 | return std::nullopt; | ||||
607 | } | ||||
608 | } | ||||
609 | } | ||||
610 | |||||
611 | static NestedPattern &vectorTransferPattern() { | ||||
612 | static auto pattern = affine::matcher::Op([](Operation &op) { | ||||
613 | return isa<vector::TransferReadOp, vector::TransferWriteOp>(op); | ||||
614 | }); | ||||
615 | return pattern; | ||||
616 | } | ||||
617 | |||||
618 | namespace { | ||||
619 | |||||
620 | /// Base state for the vectorize pass. | ||||
621 | /// Command line arguments are preempted by non-empty pass arguments. | ||||
622 | struct Vectorize : public affine::impl::AffineVectorizeBase<Vectorize> { | ||||
623 | using Base::Base; | ||||
624 | |||||
625 | void runOnOperation() override; | ||||
626 | }; | ||||
627 | |||||
628 | } // namespace | ||||
629 | |||||
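/// Records in `strategy->loopToVectorDim` the vector dimension assigned to
/// `loop`, or leaves the loop scalar when it sits too far above the
/// vectorizable depth of the pattern. Worked example (illustration only): with
/// `patternDepth == 3` and a 2-D strategy (`strategy->vectorSizes.size() == 2`),
/// the outermost matched loop (`depthInPattern == 0`) is not assigned a vector
/// dimension, the middle loop is mapped to vector dimension 0, and the
/// innermost loop to vector dimension 1.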
630 | static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern, | ||||
631 | unsigned patternDepth, | ||||
632 | VectorizationStrategy *strategy) { | ||||
  assert(patternDepth > depthInPattern &&
         "patternDepth is greater than depthInPattern");
635 | if (patternDepth - depthInPattern > strategy->vectorSizes.size()) { | ||||
636 | // Don't vectorize this loop | ||||
637 | return; | ||||
638 | } | ||||
639 | strategy->loopToVectorDim[loop] = | ||||
640 | strategy->vectorSizes.size() - (patternDepth - depthInPattern); | ||||
641 | } | ||||
642 | |||||
643 | /// Implements a simple strawman strategy for vectorization. | ||||
644 | /// Given a matched pattern `matches` of depth `patternDepth`, this strategy | ||||
645 | /// greedily assigns the fastest varying dimension ** of the vector ** to the | ||||
646 | /// innermost loop in the pattern. | ||||
647 | /// When coupled with a pattern that looks for the fastest varying dimension in | ||||
648 | /// load/store MemRefs, this creates a generic vectorization strategy that works | ||||
649 | /// for any loop in a hierarchy (outermost, innermost or intermediate). | ||||
650 | /// | ||||
651 | /// TODO: In the future we should additionally increase the power of the | ||||
652 | /// profitability analysis along 3 directions: | ||||
653 | /// 1. account for loop extents (both static and parametric + annotations); | ||||
654 | /// 2. account for data layout permutations; | ||||
655 | /// 3. account for impact of vectorization on maximal loop fusion. | ||||
656 | /// Then we can quantify the above to build a cost model and search over | ||||
657 | /// strategies. | ||||
658 | static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches, | ||||
659 | unsigned depthInPattern, | ||||
660 | unsigned patternDepth, | ||||
661 | VectorizationStrategy *strategy) { | ||||
662 | for (auto m : matches) { | ||||
663 | if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, | ||||
664 | patternDepth, strategy))) { | ||||
665 | return failure(); | ||||
666 | } | ||||
667 | vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern, | ||||
668 | patternDepth, strategy); | ||||
669 | } | ||||
670 | return success(); | ||||
671 | } | ||||
672 | |||||
673 | ///// end TODO: Hoist to a VectorizationStrategy.cpp when appropriate ///// | ||||
674 | |||||
675 | namespace { | ||||
676 | |||||
677 | struct VectorizationState { | ||||
678 | |||||
679 | VectorizationState(MLIRContext *context) : builder(context) {} | ||||
680 | |||||
681 | /// Registers the vector replacement of a scalar operation and its result | ||||
682 | /// values. Both operations must have the same number of results. | ||||
683 | /// | ||||
684 | /// This utility is used to register the replacement for the vast majority of | ||||
685 | /// the vectorized operations. | ||||
686 | /// | ||||
687 | /// Example: | ||||
688 | /// * 'replaced': %0 = arith.addf %1, %2 : f32 | ||||
689 | /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> | ||||
690 | void registerOpVectorReplacement(Operation *replaced, Operation *replacement); | ||||
691 | |||||
692 | /// Registers the vector replacement of a scalar value. The replacement | ||||
693 | /// operation should have a single result, which replaces the scalar value. | ||||
694 | /// | ||||
695 | /// This utility is used to register the vector replacement of block arguments | ||||
696 | /// and operation results which are not directly vectorized (i.e., their | ||||
697 | /// scalar version still exists after vectorization), like uniforms. | ||||
698 | /// | ||||
699 | /// Example: | ||||
700 | /// * 'replaced': block argument or operation outside of the vectorized | ||||
701 | /// loop. | ||||
702 | /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32> | ||||
703 | void registerValueVectorReplacement(Value replaced, Operation *replacement); | ||||
704 | |||||
705 | /// Registers the vector replacement of a block argument (e.g., iter_args). | ||||
706 | /// | ||||
707 | /// Example: | ||||
708 | /// * 'replaced': 'iter_arg' block argument. | ||||
709 | /// * 'replacement': vectorized 'iter_arg' block argument. | ||||
710 | void registerBlockArgVectorReplacement(BlockArgument replaced, | ||||
711 | BlockArgument replacement); | ||||
712 | |||||
713 | /// Registers the scalar replacement of a scalar value. 'replacement' must be | ||||
714 | /// scalar. Both values must be block arguments. Operation results should be | ||||
/// replaced using the 'registerOp*' utilities.
716 | /// | ||||
717 | /// This utility is used to register the replacement of block arguments | ||||
718 | /// that are within the loop to be vectorized and will continue being scalar | ||||
719 | /// within the vector loop. | ||||
720 | /// | ||||
721 | /// Example: | ||||
722 | /// * 'replaced': induction variable of a loop to be vectorized. | ||||
723 | /// * 'replacement': new induction variable in the new vector loop. | ||||
724 | void registerValueScalarReplacement(BlockArgument replaced, | ||||
725 | BlockArgument replacement); | ||||
726 | |||||
727 | /// Registers the scalar replacement of a scalar result returned from a | ||||
728 | /// reduction loop. 'replacement' must be scalar. | ||||
729 | /// | ||||
730 | /// This utility is used to register the replacement for scalar results of | ||||
731 | /// vectorized reduction loops with iter_args. | ||||
732 | /// | ||||
/// Example:
734 | /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32) | ||||
735 | /// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into | ||||
736 | /// f32 | ||||
737 | void registerLoopResultScalarReplacement(Value replaced, Value replacement); | ||||
738 | |||||
739 | /// Returns in 'replacedVals' the scalar replacement for values in | ||||
740 | /// 'inputVals'. | ||||
741 | void getScalarValueReplacementsFor(ValueRange inputVals, | ||||
742 | SmallVectorImpl<Value> &replacedVals); | ||||
743 | |||||
744 | /// Erases the scalar loop nest after its successful vectorization. | ||||
745 | void finishVectorizationPattern(AffineForOp rootLoop); | ||||
746 | |||||
747 | // Used to build and insert all the new operations created. The insertion | ||||
748 | // point is preserved and updated along the vectorization process. | ||||
749 | OpBuilder builder; | ||||
750 | |||||
751 | // Maps input scalar operations to their vector counterparts. | ||||
752 | DenseMap<Operation *, Operation *> opVectorReplacement; | ||||
753 | // Maps input scalar values to their vector counterparts. | ||||
754 | IRMapping valueVectorReplacement; | ||||
755 | // Maps input scalar values to their new scalar counterparts in the vector | ||||
756 | // loop nest. | ||||
757 | IRMapping valueScalarReplacement; | ||||
758 | // Maps results of reduction loops to their new scalar counterparts. | ||||
759 | DenseMap<Value, Value> loopResultScalarReplacement; | ||||
760 | |||||
761 | // Maps the newly created vector loops to their vector dimension. | ||||
762 | DenseMap<Operation *, unsigned> vecLoopToVecDim; | ||||
763 | |||||
764 | // Maps the new vectorized loops to the corresponding vector masks if it is | ||||
765 | // required. | ||||
766 | DenseMap<Operation *, Value> vecLoopToMask; | ||||
767 | |||||
768 | // The strategy drives which loop to vectorize by which amount. | ||||
769 | const VectorizationStrategy *strategy = nullptr; | ||||
770 | |||||
771 | private: | ||||
772 | /// Internal implementation to map input scalar values to new vector or scalar | ||||
773 | /// values. | ||||
774 | void registerValueVectorReplacementImpl(Value replaced, Value replacement); | ||||
775 | void registerValueScalarReplacementImpl(Value replaced, Value replacement); | ||||
776 | }; | ||||
777 | |||||
778 | } // namespace | ||||
779 | |||||
780 | /// Registers the vector replacement of a scalar operation and its result | ||||
781 | /// values. Both operations must have the same number of results. | ||||
782 | /// | ||||
783 | /// This utility is used to register the replacement for the vast majority of | ||||
784 | /// the vectorized operations. | ||||
785 | /// | ||||
786 | /// Example: | ||||
787 | /// * 'replaced': %0 = arith.addf %1, %2 : f32 | ||||
788 | /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> | ||||
789 | void VectorizationState::registerOpVectorReplacement(Operation *replaced, | ||||
790 | Operation *replacement) { | ||||
  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
  LLVM_DEBUG(dbgs() << *replaced << "\n");
  LLVM_DEBUG(dbgs() << "into\n");
  LLVM_DEBUG(dbgs() << *replacement << "\n");
795 | |||||
  assert(replaced->getNumResults() == replacement->getNumResults() &&
         "Unexpected replaced and replacement results");
  assert(opVectorReplacement.count(replaced) == 0 && "already registered");
799 | opVectorReplacement[replaced] = replacement; | ||||
800 | |||||
801 | for (auto resultTuple : | ||||
802 | llvm::zip(replaced->getResults(), replacement->getResults())) | ||||
803 | registerValueVectorReplacementImpl(std::get<0>(resultTuple), | ||||
804 | std::get<1>(resultTuple)); | ||||
805 | } | ||||
806 | |||||
807 | /// Registers the vector replacement of a scalar value. The replacement | ||||
808 | /// operation should have a single result, which replaces the scalar value. | ||||
809 | /// | ||||
810 | /// This utility is used to register the vector replacement of block arguments | ||||
811 | /// and operation results which are not directly vectorized (i.e., their | ||||
812 | /// scalar version still exists after vectorization), like uniforms. | ||||
813 | /// | ||||
814 | /// Example: | ||||
815 | /// * 'replaced': block argument or operation outside of the vectorized loop. | ||||
816 | /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32> | ||||
817 | void VectorizationState::registerValueVectorReplacement( | ||||
818 | Value replaced, Operation *replacement) { | ||||
  assert(replacement->getNumResults() == 1 &&
         "Expected single-result replacement");
821 | if (Operation *defOp = replaced.getDefiningOp()) | ||||
822 | registerOpVectorReplacement(defOp, replacement); | ||||
823 | else | ||||
824 | registerValueVectorReplacementImpl(replaced, replacement->getResult(0)); | ||||
825 | } | ||||
826 | |||||
827 | /// Registers the vector replacement of a block argument (e.g., iter_args). | ||||
828 | /// | ||||
829 | /// Example: | ||||
830 | /// * 'replaced': 'iter_arg' block argument. | ||||
831 | /// * 'replacement': vectorized 'iter_arg' block argument. | ||||
832 | void VectorizationState::registerBlockArgVectorReplacement( | ||||
833 | BlockArgument replaced, BlockArgument replacement) { | ||||
834 | registerValueVectorReplacementImpl(replaced, replacement); | ||||
835 | } | ||||
836 | |||||
837 | void VectorizationState::registerValueVectorReplacementImpl(Value replaced, | ||||
838 | Value replacement) { | ||||
  assert(!valueVectorReplacement.contains(replaced) &&
         "Vector replacement already registered");
  assert(replacement.getType().isa<VectorType>() &&
         "Expected vector type in vector replacement");
843 | valueVectorReplacement.map(replaced, replacement); | ||||
844 | } | ||||
845 | |||||
846 | /// Registers the scalar replacement of a scalar value. 'replacement' must be | ||||
847 | /// scalar. Both values must be block arguments. Operation results should be | ||||
/// replaced using the 'registerOp*' utilities.
849 | /// | ||||
850 | /// This utility is used to register the replacement of block arguments | ||||
851 | /// that are within the loop to be vectorized and will continue being scalar | ||||
852 | /// within the vector loop. | ||||
853 | /// | ||||
854 | /// Example: | ||||
855 | /// * 'replaced': induction variable of a loop to be vectorized. | ||||
856 | /// * 'replacement': new induction variable in the new vector loop. | ||||
857 | void VectorizationState::registerValueScalarReplacement( | ||||
858 | BlockArgument replaced, BlockArgument replacement) { | ||||
859 | registerValueScalarReplacementImpl(replaced, replacement); | ||||
860 | } | ||||
861 | |||||
862 | /// Registers the scalar replacement of a scalar result returned from a | ||||
863 | /// reduction loop. 'replacement' must be scalar. | ||||
864 | /// | ||||
865 | /// This utility is used to register the replacement for scalar results of | ||||
866 | /// vectorized reduction loops with iter_args. | ||||
867 | /// | ||||
/// Example:
869 | /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32) | ||||
870 | /// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into f32 | ||||
871 | void VectorizationState::registerLoopResultScalarReplacement( | ||||
872 | Value replaced, Value replacement) { | ||||
  assert(isa<AffineForOp>(replaced.getDefiningOp()));
  assert(loopResultScalarReplacement.count(replaced) == 0 &&
         "already registered");
  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
                       "with scalar: "
                    << replacement);
879 | loopResultScalarReplacement[replaced] = replacement; | ||||
880 | } | ||||
881 | |||||
882 | void VectorizationState::registerValueScalarReplacementImpl(Value replaced, | ||||
883 | Value replacement) { | ||||
  assert(!valueScalarReplacement.contains(replaced) &&
         "Scalar value replacement already registered");
  assert(!replacement.getType().isa<VectorType>() &&
         "Expected scalar type in scalar replacement");
888 | valueScalarReplacement.map(replaced, replacement); | ||||
889 | } | ||||
890 | |||||
891 | /// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'. | ||||
892 | void VectorizationState::getScalarValueReplacementsFor( | ||||
893 | ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) { | ||||
894 | for (Value inputVal : inputVals) | ||||
895 | replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal)); | ||||
896 | } | ||||
897 | |||||
898 | /// Erases a loop nest, including all its nested operations. | ||||
899 | static void eraseLoopNest(AffineForOp forOp) { | ||||
900 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n"); | ||||
901 | forOp.erase(); | ||||
902 | } | ||||
903 | |||||
904 | /// Erases the scalar loop nest after its successful vectorization. | ||||
905 | void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) { | ||||
906 | LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n"); | ||||
907 | eraseLoopNest(rootLoop); | ||||
908 | } | ||||
909 | |||||
910 | // Applies 'map' to 'mapOperands', returning the resulting values in 'results'. | ||||
911 | static void computeMemoryOpIndices(Operation *op, AffineMap map, | ||||
912 | ValueRange mapOperands, | ||||
913 | VectorizationState &state, | ||||
914 | SmallVectorImpl<Value> &results) { | ||||
915 | for (auto resultExpr : map.getResults()) { | ||||
916 | auto singleResMap = | ||||
917 | AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); | ||||
918 | auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap, | ||||
919 | mapOperands); | ||||
920 | results.push_back(afOp); | ||||
921 | } | ||||
922 | } | ||||
923 | |||||
924 | /// Returns a FilterFunctionType that can be used in NestedPattern to match a | ||||
925 | /// loop whose underlying load/store accesses are either invariant or all | ||||
926 | /// varying along the `fastestVaryingMemRefDimension`. | ||||
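// Illustrative sketch (hypothetical names and shapes, not from a test): with
// `fastestVaryingMemRefDimension == 0`, the returned filter would accept a
// loop such as
//
//   affine.for %j = 0 to 128 {
//     %0 = affine.load %A[%i, %j] : memref<64x128xf32>  // varies along the
//                                                       // fastest dimension
//     %1 = affine.load %B[%i] : memref<64xf32>          // invariant in %j
//   }
//
// and reject a loop whose accesses vary along a slower-varying dimension,
// e.g. one loading %A[%j, %i].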
927 | static FilterFunctionType | ||||
928 | isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, | ||||
929 | int fastestVaryingMemRefDimension) { | ||||
930 | return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) { | ||||
931 | auto loop = cast<AffineForOp>(forOp); | ||||
932 | auto parallelIt = parallelLoops.find(loop); | ||||
933 | if (parallelIt == parallelLoops.end()) | ||||
934 | return false; | ||||
935 | int memRefDim = -1; | ||||
936 | auto vectorizableBody = | ||||
937 | isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern()); | ||||
938 | if (!vectorizableBody) | ||||
939 | return false; | ||||
940 | return memRefDim == -1 || fastestVaryingMemRefDimension == -1 || | ||||
941 | memRefDim == fastestVaryingMemRefDimension; | ||||
942 | }; | ||||
943 | } | ||||
944 | |||||
945 | /// Returns the vector type resulting from applying the provided vectorization | ||||
946 | /// strategy on the scalar type. | ||||
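// For example, with `strategy->vectorSizes == {4, 8}` (hypothetical factors),
// an `f32` scalar type is mapped to `vector<4x8xf32>`.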
947 | static VectorType getVectorType(Type scalarTy, | ||||
948 | const VectorizationStrategy *strategy) { | ||||
949 | assert(!scalarTy.isa<VectorType>() && "Expected scalar type"); | ||||
950 | return VectorType::get(strategy->vectorSizes, scalarTy); | ||||
951 | } | ||||
952 | |||||
953 | /// Tries to transform a scalar constant into a vector constant. Returns the | ||||
954 | /// vector constant if the scalar type is a valid vector element type. Returns | ||||
955 | /// nullptr, otherwise. | ||||
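// Illustrative sketch (hypothetical 1-D strategy of size 128):
//
//   %cst = arith.constant 1.0 : f32
//
// is rematerialized inside the innermost vectorized loop as
//
//   %vcst = arith.constant dense<1.0> : vector<128xf32>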
956 | static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, | ||||
957 | VectorizationState &state) { | ||||
958 | Type scalarTy = constOp.getType(); | ||||
959 | if (!VectorType::isValidElementType(scalarTy)) | ||||
960 | return nullptr; | ||||
961 | |||||
962 | auto vecTy = getVectorType(scalarTy, state.strategy); | ||||
963 | auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); | ||||
964 | |||||
965 | OpBuilder::InsertionGuard guard(state.builder); | ||||
966 | Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); | ||||
967 | // Find the innermost vectorized ancestor loop to insert the vector constant. | ||||
968 | while (parentOp && !state.vecLoopToVecDim.count(parentOp)) | ||||
969 | parentOp = parentOp->getParentOp(); | ||||
970 | assert(parentOp && state.vecLoopToVecDim.count(parentOp) && | ||||
971 | isa<AffineForOp>(parentOp) && "Expected a vectorized for op"); | ||||
972 | auto vecForOp = cast<AffineForOp>(parentOp); | ||||
973 | state.builder.setInsertionPointToStart(vecForOp.getBody()); | ||||
974 | auto newConstOp = | ||||
975 | state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr); | ||||
976 | |||||
977 | // Register vector replacement for future uses in the scope. | ||||
978 | state.registerOpVectorReplacement(constOp, newConstOp); | ||||
979 | return newConstOp; | ||||
980 | } | ||||
981 | |||||
982 | /// Creates a constant vector filled with the neutral elements of the given | ||||
983 | /// reduction. The scalar type of vector elements will be taken from | ||||
984 | /// `oldOperand`. | ||||
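// For example (hypothetical 1-D strategy of size 128), an `addf` reduction
// over f32 would start from the neutral-element vector
//
//   %init = arith.constant dense<0.0> : vector<128xf32>
//
// since 0.0 is the identity of floating-point addition.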
985 | static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, | ||||
986 | Value oldOperand, | ||||
987 | VectorizationState &state) { | ||||
988 | Type scalarTy = oldOperand.getType(); | ||||
989 | if (!VectorType::isValidElementType(scalarTy)) | ||||
990 | return nullptr; | ||||
991 | |||||
992 | Attribute valueAttr = getIdentityValueAttr( | ||||
993 | reductionKind, scalarTy, state.builder, oldOperand.getLoc()); | ||||
994 | auto vecTy = getVectorType(scalarTy, state.strategy); | ||||
995 | auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); | ||||
996 | auto newConstOp = | ||||
997 | state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr); | ||||
998 | |||||
999 | return newConstOp; | ||||
1000 | } | ||||
1001 | |||||
1002 | /// Creates a mask used to filter out garbage elements in the last iteration | ||||
1003 | /// of unaligned loops. If a mask is not required then `nullptr` is returned. | ||||
1004 | /// The mask will be a vector of booleans representing meaningful vector | ||||
1005 | /// elements in the current iteration. It is filled with ones for each iteration | ||||
1006 | /// except for the last one, where it has the form `11...100...0` with the | ||||
1007 | /// number of ones equal to the number of meaningful elements (i.e. the number | ||||
1008 | /// of iterations that would be left in the original loop). | ||||
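// Rough example (hypothetical bounds): for a loop with bounds 0 to 500
// vectorized by 128, the last vector iteration starts at %iv = 384, so
// %elems_left = 500 - 384 = 116 and the generated
//
//   %mask = vector.create_mask %elems_left : vector<128xi1>
//
// evaluates to 116 ones followed by 12 zeros.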
1009 | static Value createMask(AffineForOp vecForOp, VectorizationState &state) { | ||||
1010 | assert(state.strategy->vectorSizes.size() == 1 && | ||||
1011 | "Creating a mask for non-1-D vectors is not supported."); | ||||
1012 | assert(vecForOp.getStep() == state.strategy->vectorSizes[0] && | ||||
1013 | "Creating a mask for loops with non-unit original step size is not " | ||||
1014 | "supported."); | ||||
1015 | |||||
1016 | // Check if we have already created the mask. | ||||
1017 | if (Value mask = state.vecLoopToMask.lookup(vecForOp)) | ||||
1018 | return mask; | ||||
1019 | |||||
1020 | // If the loop has constant bounds and the original number of iterations is | ||||
1021 | // divisible by the vector size then we don't need a mask. | ||||
1022 | if (vecForOp.hasConstantBounds()) { | ||||
1023 | int64_t originalTripCount = | ||||
1024 | vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound(); | ||||
1025 | if (originalTripCount % vecForOp.getStep() == 0) | ||||
1026 | return nullptr; | ||||
1027 | } | ||||
1028 | |||||
1029 | OpBuilder::InsertionGuard guard(state.builder); | ||||
1030 | state.builder.setInsertionPointToStart(vecForOp.getBody()); | ||||
1031 | |||||
1032 | // We generate the mask using the `vector.create_mask` operation which accepts | ||||
1033 | // the number of meaningful elements (i.e. the length of the prefix of 1s). | ||||
1034 | // To compute the number of meaningful elements we subtract the current value | ||||
1035 | // of the iteration variable from the upper bound of the loop. Example: | ||||
1036 | // | ||||
1037 | // // 500 is the upper bound of the loop | ||||
1038 | // #map = affine_map<(d0) -> (500 - d0)> | ||||
1039 | // %elems_left = affine.apply #map(%iv) | ||||
1040 | // %mask = vector.create_mask %elems_left : vector<128xi1> | ||||
1041 | |||||
1042 | Location loc = vecForOp.getLoc(); | ||||
1043 | |||||
1044 | // First we get the upper bound of the loop using `affine.apply` or | ||||
1045 | // `affine.min`. | ||||
1046 | AffineMap ubMap = vecForOp.getUpperBoundMap(); | ||||
1047 | Value ub; | ||||
1048 | if (ubMap.getNumResults() == 1) | ||||
1049 | ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(), | ||||
1050 | vecForOp.getUpperBoundOperands()); | ||||
1051 | else | ||||
1052 | ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(), | ||||
1053 | vecForOp.getUpperBoundOperands()); | ||||
1054 | // Then we compute the number of (original) iterations left in the loop. | ||||
1055 | AffineExpr subExpr = | ||||
1056 | state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1); | ||||
1057 | Value itersLeft = | ||||
1058 | makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr), | ||||
1059 | {ub, vecForOp.getInductionVar()}); | ||||
1060 | // If the affine maps were successfully composed then `ub` is unneeded. | ||||
1061 | if (ub.use_empty()) | ||||
1062 | ub.getDefiningOp()->erase(); | ||||
1063 | // Finally we create the mask. | ||||
1064 | Type maskTy = VectorType::get(state.strategy->vectorSizes, | ||||
1065 | state.builder.getIntegerType(1)); | ||||
1066 | Value mask = | ||||
1067 | state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft); | ||||
1068 | |||||
1069 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n" | ||||
1070 | << itersLeft << "\n" | ||||
1071 | << mask << "\n"); | ||||
1072 | |||||
1073 | state.vecLoopToMask[vecForOp] = mask; | ||||
1074 | return mask; | ||||
1075 | } | ||||
1076 | |||||
1077 | /// Returns true if the provided value is vector uniform given the vectorization | ||||
1078 | /// strategy. | ||||
1079 | // TODO: For now, only values that are induction variables of loops not in | ||||
1080 | // `loopToVectorDim` or invariants to all the loops in the vectorization | ||||
1081 | // strategy are considered vector uniforms. | ||||
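// Sketch (hypothetical nest): if only the %j loop below is part of the
// vectorization strategy, %i and values defined above the nest are uniform,
// while %j is not:
//
//   affine.for %i = 0 to 64 {      // not vectorized -> %i is uniform
//     affine.for %j = 0 to 128 {   // vectorized     -> %j is not uniform
//       ...
//     }
//   }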
1082 | static bool isUniformDefinition(Value value, | ||||
1083 | const VectorizationStrategy *strategy) { | ||||
1084 | AffineForOp forOp = getForInductionVarOwner(value); | ||||
1085 | if (forOp && strategy->loopToVectorDim.count(forOp) == 0) | ||||
1086 | return true; | ||||
1087 | |||||
1088 | for (auto loopToDim : strategy->loopToVectorDim) { | ||||
1089 | auto loop = cast<AffineForOp>(loopToDim.first); | ||||
1090 | if (!loop.isDefinedOutsideOfLoop(value)) | ||||
1091 | return false; | ||||
1092 | } | ||||
1093 | return true; | ||||
1094 | } | ||||
1095 | |||||
1096 | /// Generates a broadcast op for the provided uniform value using the | ||||
1097 | /// vectorization strategy in 'state'. | ||||
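// Sketch (hypothetical 1-D strategy of size 128): a uniform f32 value %s is
// replaced in vectorized code by
//
//   %v = vector.broadcast %s : f32 to vector<128xf32>
//
// inserted right after the scalar replacement of %s.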
1098 | static Operation *vectorizeUniform(Value uniformVal, | ||||
1099 | VectorizationState &state) { | ||||
1100 | OpBuilder::InsertionGuard guard(state.builder); | ||||
1101 | Value uniformScalarRepl = | ||||
1102 | state.valueScalarReplacement.lookupOrDefault(uniformVal); | ||||
1103 | state.builder.setInsertionPointAfterValue(uniformScalarRepl); | ||||
1104 | |||||
1105 | auto vectorTy = getVectorType(uniformVal.getType(), state.strategy); | ||||
1106 | auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(), | ||||
1107 | vectorTy, uniformScalarRepl); | ||||
1108 | state.registerValueVectorReplacement(uniformVal, bcastOp); | ||||
1109 | return bcastOp; | ||||
1110 | } | ||||
1111 | |||||
1112 | /// Tries to vectorize a given `operand` by applying the following logic: | ||||
1113 | /// 1. if the defining operation has been already vectorized, `operand` is | ||||
1114 | /// already in the proper vector form; | ||||
1115 | /// 2. if the `operand` is a constant, returns the vectorized form of the | ||||
1116 | /// constant; | ||||
1117 | /// 3. if the `operand` is uniform, returns a vector broadcast of the `op`; | ||||
1118 | /// 4. otherwise, the vectorization of `operand` is not supported. | ||||
1119 | /// Newly created vector operations are registered in `state` as replacement | ||||
1120 | /// for their scalar counterparts. | ||||
1121 | /// In particular this logic captures some of the use cases where definitions | ||||
1122 | /// that are not scoped under the current pattern are needed to vectorize. | ||||
1123 | /// One such example is top level function constants that need to be splatted. | ||||
1124 | /// | ||||
1125 | /// Returns an operand that has been vectorized to match `state`'s strategy if | ||||
1126 | /// vectorization is possible with the above logic. Returns nullptr otherwise. | ||||
1127 | /// | ||||
1128 | /// TODO: handle more complex cases. | ||||
1129 | static Value vectorizeOperand(Value operand, VectorizationState &state) { | ||||
1130 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand); | ||||
1131 | // If this value is already vectorized, we are done. | ||||
1132 | if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) { | ||||
1133 | LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl); | ||||
1134 | return vecRepl; | ||||
1135 | } | ||||
1136 | |||||
1137 | // A vector operand that is not in the replacement map should never reach | ||||
1138 | // this point. Reaching this point could mean that the code was already | ||||
1139 | // vectorized and we shouldn't try to vectorize already vectorized code. | ||||
1140 | assert(!operand.getType().isa<VectorType>() && | ||||
1141 | "Vector op not found in replacement map"); | ||||
1142 | |||||
1143 | // Vectorize constant. | ||||
1144 | if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) { | ||||
1145 | auto vecConstant = vectorizeConstant(constOp, state); | ||||
// vectorizeConstant returns a null op when the scalar type is not a valid
// vector element type; bail out here rather than calling getResult() (or
// printing) on a null op, which would be a null object pointer call.
if (!vecConstant)
  return nullptr;
1146 | LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant); | ||||
1147 | return vecConstant.getResult(); | ||||
1148 | } | ||||
1149 | |||||
1150 | // Vectorize uniform values. | ||||
1151 | if (isUniformDefinition(operand, state.strategy)) { | ||||
1152 | Operation *vecUniform = vectorizeUniform(operand, state); | ||||
1153 | LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> uniform: " << * vecUniform; } } while (false); | ||||
1154 | return vecUniform->getResult(0); | ||||
1155 | } | ||||
1156 | |||||
1157 | // Check for unsupported block argument scenarios. A supported block argument | ||||
1158 | // should have been vectorized already. | ||||
1159 | if (!operand.getDefiningOp()) | ||||
1160 | LLVM_DEBUG(dbgs() << "-> unsupported block argument\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> unsupported block argument\n" ; } } while (false); | ||||
1161 | else | ||||
1162 | // Generic unsupported case. | ||||
1163 | LLVM_DEBUG(dbgs() << "-> non-vectorizable\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> non-vectorizable\n"; } } while (false); | ||||
1164 | |||||
1165 | return nullptr; | ||||
1166 | } | ||||
1167 | |||||
1168 | /// Vectorizes an affine load with the vectorization strategy in 'state' by | ||||
1169 | /// generating a 'vector.transfer_read' op with the proper permutation map | ||||
1170 | /// inferred from the indices of the load. The new 'vector.transfer_read' is | ||||
1171 | /// registered as replacement of the scalar load. Returns the newly created | ||||
1172 | /// 'vector.transfer_read' if vectorization was successful. Returns nullptr, | ||||
1173 | /// otherwise. | ||||
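// Illustrative sketch (hypothetical shapes, 1-D strategy of size 128):
//
//   %0 = affine.load %A[%i] : memref<1024xf32>
//
// becomes, in the vectorized loop,
//
//   %0 = vector.transfer_read %A[%i], %pad : memref<1024xf32>, vector<128xf32>
//
// where %pad is a padding scalar and the permutation map encodes how the
// indices vary with the vectorized loops.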
1174 | static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, | ||||
1175 | VectorizationState &state) { | ||||
1176 | MemRefType memRefType = loadOp.getMemRefType(); | ||||
1177 | Type elementType = memRefType.getElementType(); | ||||
1178 | auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType); | ||||
1179 | |||||
1180 | // Replace map operands with operands from the vector loop nest. | ||||
1181 | SmallVector<Value, 8> mapOperands; | ||||
1182 | state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands); | ||||
1183 | |||||
1184 | // Compute indices for the transfer op. AffineApplyOp's may be generated. | ||||
1185 | SmallVector<Value, 8> indices; | ||||
1186 | indices.reserve(memRefType.getRank()); | ||||
1187 | if (loadOp.getAffineMap() != | ||||
1188 | state.builder.getMultiDimIdentityMap(memRefType.getRank())) | ||||
1189 | computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state, | ||||
1190 | indices); | ||||
1191 | else | ||||
1192 | indices.append(mapOperands.begin(), mapOperands.end()); | ||||
1193 | |||||
1194 | // Compute permutation map using the information of new vector loops. | ||||
1195 | auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(), | ||||
1196 | indices, state.vecLoopToVecDim); | ||||
1197 | if (!permutationMap) { | ||||
1198 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n"); | ||||
1199 | return nullptr; | ||||
1200 | } | ||||
1201 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); | ||||
1202 | LLVM_DEBUG(permutationMap.print(dbgs())); | ||||
1203 | |||||
1204 | auto transfer = state.builder.create<vector::TransferReadOp>( | ||||
1205 | loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap); | ||||
1206 | |||||
1207 | // Register replacement for future uses in the scope. | ||||
1208 | state.registerOpVectorReplacement(loadOp, transfer); | ||||
1209 | return transfer; | ||||
1210 | } | ||||
1211 | |||||
1212 | /// Vectorizes an affine store with the vectorization strategy in 'state' by | ||||
1213 | /// generating a 'vector.transfer_write' op with the proper permutation map | ||||
1214 | /// inferred from the indices of the store. The new 'vector.transfer_write' is | ||||
1215 | /// registered as replacement of the scalar store. Returns the newly created | ||||
1216 | /// 'vector.transfer_write' if vectorization was successful. Returns nullptr, | ||||
1217 | /// otherwise. | ||||
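// Illustrative sketch (hypothetical shapes, 1-D strategy of size 128): once
// the stored value %val has a vector replacement %vval,
//
//   affine.store %val, %B[%i] : memref<1024xf32>
//
// becomes
//
//   vector.transfer_write %vval, %B[%i] : vector<128xf32>, memref<1024xf32>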
1218 | static Operation *vectorizeAffineStore(AffineStoreOp storeOp, | ||||
1219 | VectorizationState &state) { | ||||
1220 | MemRefType memRefType = storeOp.getMemRefType(); | ||||
1221 | Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state); | ||||
1222 | if (!vectorValue) | ||||
1223 | return nullptr; | ||||
1224 | |||||
1225 | // Replace map operands with operands from the vector loop nest. | ||||
1226 | SmallVector<Value, 8> mapOperands; | ||||
1227 | state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands); | ||||
1228 | |||||
1229 | // Compute indices for the transfer op. AffineApplyOp's may be generated. | ||||
1230 | SmallVector<Value, 8> indices; | ||||
1231 | indices.reserve(memRefType.getRank()); | ||||
1232 | if (storeOp.getAffineMap() != | ||||
1233 | state.builder.getMultiDimIdentityMap(memRefType.getRank())) | ||||
1234 | computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state, | ||||
1235 | indices); | ||||
1236 | else | ||||
1237 | indices.append(mapOperands.begin(), mapOperands.end()); | ||||
1238 | |||||
1239 | // Compute permutation map using the information of new vector loops. | ||||
1240 | auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(), | ||||
1241 | indices, state.vecLoopToVecDim); | ||||
1242 | if (!permutationMap) | ||||
1243 | return nullptr; | ||||
1244 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: "); | ||||
1245 | LLVM_DEBUG(permutationMap.print(dbgs())); | ||||
1246 | |||||
1247 | auto transfer = state.builder.create<vector::TransferWriteOp>( | ||||
1248 | storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices, | ||||
1249 | permutationMap); | ||||
1250 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer); | ||||
1251 | |||||
1252 | // Register replacement for future uses in the scope. | ||||
1253 | state.registerOpVectorReplacement(storeOp, transfer); | ||||
1254 | return transfer; | ||||
1255 | } | ||||
1256 | |||||
1257 | /// Returns true if `value` is a constant equal to the neutral element of the | ||||
1258 | /// given vectorizable reduction. | ||||
1259 | static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind, | ||||
1260 | Value value, VectorizationState &state) { | ||||
1261 | Type scalarTy = value.getType(); | ||||
1262 | if (!VectorType::isValidElementType(scalarTy)) | ||||
1263 | return false; | ||||
1264 | Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, | ||||
1265 | state.builder, value.getLoc()); | ||||
1266 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp())) | ||||
1267 | return constOp.getValue() == valueAttr; | ||||
1268 | return false; | ||||
1269 | } | ||||
1270 | |||||
1271 | /// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is | ||||
1272 | /// created and registered as replacement for the scalar loop. The builder's | ||||
1273 | /// insertion point is set to the new loop's body so that subsequent vectorized | ||||
1274 | /// operations are inserted into the new loop. If the loop is a vector | ||||
1275 | /// dimension, the step of the newly created loop will reflect the vectorization | ||||
1276 | /// factor used to vectorize that dimension. | ||||
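// Rough sketch (hypothetical bounds, 1-D strategy of size 128): a loop chosen
// as a vector dimension,
//
//   affine.for %i = 0 to 1024 { ... }
//
// is recreated as
//
//   affine.for %i = 0 to 1024 step 128 { ... }
//
// while loops that are not vector dimensions keep their original step.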
1277 | static Operation *vectorizeAffineForOp(AffineForOp forOp, | ||||
1278 | VectorizationState &state) { | ||||
1279 | const VectorizationStrategy &strategy = *state.strategy; | ||||
1280 | auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp); | ||||
1281 | bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end(); | ||||
1282 | |||||
1283 | // TODO: Vectorization of reduction loops is not supported for non-unit steps. | ||||
1284 | if (isLoopVecDim && forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) { | ||||
1285 | LLVM_DEBUG( | ||||
1286 | dbgs() | ||||
1287 | << "\n[early-vect]+++++ unsupported step size for reduction loop: " | ||||
1288 | << forOp.getStep() << "\n"); | ||||
1289 | return nullptr; | ||||
1290 | } | ||||
1291 | |||||
1292 | // If we are vectorizing a vector dimension, compute a new step for the new | ||||
1293 | // vectorized loop using the vectorization factor for the vector dimension. | ||||
1294 | // Otherwise, propagate the step of the scalar loop. | ||||
1295 | unsigned newStep; | ||||
1296 | if (isLoopVecDim) { | ||||
1297 | unsigned vectorDim = loopToVecDimIt->second; | ||||
1298 | assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow"); | ||||
1299 | int64_t forOpVecFactor = strategy.vectorSizes[vectorDim]; | ||||
1300 | newStep = forOp.getStep() * forOpVecFactor; | ||||
1301 | } else { | ||||
1302 | newStep = forOp.getStep(); | ||||
1303 | } | ||||
1304 | |||||
1305 | // Get information about reduction kinds. | ||||
1306 | ArrayRef<LoopReduction> reductions; | ||||
1307 | if (isLoopVecDim && forOp.getNumIterOperands() > 0) { | ||||
1308 | auto it = strategy.reductionLoops.find(forOp); | ||||
1309 | assert(it != strategy.reductionLoops.end() && | ||||
1310 | "Reduction descriptors not found when vectorizing a reduction loop"); | ||||
1311 | reductions = it->second; | ||||
1312 | assert(reductions.size() == forOp.getNumIterOperands() && | ||||
1313 | "The size of reductions array must match the number of iter_args"); | ||||
1314 | } | ||||
1315 | |||||
1316 | // Vectorize 'iter_args'. | ||||
1317 | SmallVector<Value, 8> vecIterOperands; | ||||
1318 | if (!isLoopVecDim) { | ||||
1319 | for (auto operand : forOp.getIterOperands()) | ||||
1320 | vecIterOperands.push_back(vectorizeOperand(operand, state)); | ||||
1321 | } else { | ||||
1322 | // For reduction loops we need to pass a vector of neutral elements as an | ||||
1323 | // initial value of the accumulator. We will add the original initial value | ||||
1324 | // later. | ||||
1325 | for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) { | ||||
1326 | vecIterOperands.push_back(createInitialVector( | ||||
1327 | std::get<0>(redAndOperand).kind, std::get<1>(redAndOperand), state)); | ||||
1328 | } | ||||
1329 | } | ||||
1330 | |||||
1331 | auto vecForOp = state.builder.create<AffineForOp>( | ||||
1332 | forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(), | ||||
1333 | forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep, | ||||
1334 | vecIterOperands, | ||||
1335 | /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) { | ||||
1336 | // Make sure we don't create a default terminator in the loop body as | ||||
1337 | // the proper terminator will be added during vectorization. | ||||
1338 | }); | ||||
1339 | |||||
1340 | // Register loop-related replacements: | ||||
1341 | // 1) The new vectorized loop is registered as vector replacement of the | ||||
1342 | // scalar loop. | ||||
1343 | // 2) The new iv of the vectorized loop is registered as scalar replacement | ||||
1344 | // since a scalar copy of the iv will prevail in the vectorized loop. | ||||
1345 | // TODO: A vector replacement will also be added in the future when | ||||
1346 | // vectorization of linear ops is supported. | ||||
1347 | // 3) The new 'iter_args' region arguments are registered as vector | ||||
1348 | // replacements since they have been vectorized. | ||||
1349 | // 4) If the loop performs a reduction along the vector dimension, a | ||||
1350 | // `vector.reduction` or similar op is inserted for each resulting value | ||||
1351 | // of the loop and its scalar value replaces the corresponding scalar | ||||
1352 | // result of the loop. | ||||
1353 | state.registerOpVectorReplacement(forOp, vecForOp); | ||||
1354 | state.registerValueScalarReplacement(forOp.getInductionVar(), | ||||
1355 | vecForOp.getInductionVar()); | ||||
1356 | for (auto iterTuple : | ||||
1357 | llvm::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs())) | ||||
1358 | state.registerBlockArgVectorReplacement(std::get<0>(iterTuple), | ||||
1359 | std::get<1>(iterTuple)); | ||||
1360 | |||||
1361 | if (isLoopVecDim) { | ||||
1362 | for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) { | ||||
1363 | // First, we reduce the vector returned from the loop into a scalar. | ||||
1364 | Value reducedRes = | ||||
1365 | getVectorReductionOp(reductions[i].kind, state.builder, | ||||
1366 | vecForOp.getLoc(), vecForOp.getResult(i)); | ||||
1367 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: " | ||||
1368 | << reducedRes); | ||||
1369 | // Then we combine it with the original (scalar) initial value unless it | ||||
1370 | // is equal to the neutral element of the reduction. | ||||
1371 | Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i); | ||||
1372 | Value finalRes = reducedRes; | ||||
1373 | if (!isNeutralElementConst(reductions[i].kind, origInit, state)) | ||||
1374 | finalRes = | ||||
1375 | arith::getReductionOp(reductions[i].kind, state.builder, | ||||
1376 | reducedRes.getLoc(), reducedRes, origInit); | ||||
1377 | state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes); | ||||
1378 | } | ||||
1379 | } | ||||
1380 | |||||
1381 | if (isLoopVecDim) | ||||
1382 | state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second; | ||||
1383 | |||||
1384 | // Change insertion point so that upcoming vectorized instructions are | ||||
1385 | // inserted into the vectorized loop's body. | ||||
1386 | state.builder.setInsertionPointToStart(vecForOp.getBody()); | ||||
1387 | |||||
1388 | // If this is a reduction loop then we may need to create a mask to filter out | ||||
1389 | // garbage in the last iteration. | ||||
1390 | if (isLoopVecDim && forOp.getNumIterOperands() > 0) | ||||
1391 | createMask(vecForOp, state); | ||||
1392 | |||||
1393 | return vecForOp; | ||||
1394 | } | ||||
1395 | |||||
1396 | /// Vectorizes an arbitrary operation by plain widening. We apply generic type | ||||
1397 | /// widening of all its results and retrieve the vector counterparts for all its | ||||
1398 | /// operands. | ||||
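// For example (hypothetical 1-D strategy of size 128), once its operands have
// vector replacements %va and %vb,
//
//   %r = arith.mulf %a, %b : f32
//
// is widened to
//
//   %vr = arith.mulf %va, %vb : vector<128xf32>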
1399 | static Operation *widenOp(Operation *op, VectorizationState &state) { | ||||
1400 | SmallVector<Type, 8> vectorTypes; | ||||
1401 | for (Value result : op->getResults()) | ||||
1402 | vectorTypes.push_back( | ||||
1403 | VectorType::get(state.strategy->vectorSizes, result.getType())); | ||||
1404 | |||||
1405 | SmallVector<Value, 8> vectorOperands; | ||||
1406 | for (Value operand : op->getOperands()) { | ||||
1407 | Value vecOperand = vectorizeOperand(operand, state); | ||||
1408 | if (!vecOperand) { | ||||
1409 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n"); | ||||
1410 | return nullptr; | ||||
1411 | } | ||||
1412 | vectorOperands.push_back(vecOperand); | ||||
1413 | } | ||||
1414 | |||||
1415 | // Create a clone of the op with the proper operands and return types. | ||||
1416 | // TODO: The following assumes there is always an op with a fixed | ||||
1417 | // name that works both in scalar mode and vector mode. | ||||
1418 | // TODO: Is it worth considering an Operation.clone operation which | ||||
1419 | // changes the type so we can promote an Operation with less boilerplate? | ||||
1420 | Operation *vecOp = | ||||
1421 | state.builder.create(op->getLoc(), op->getName().getIdentifier(), | ||||
1422 | vectorOperands, vectorTypes, op->getAttrs()); | ||||
1423 | state.registerOpVectorReplacement(op, vecOp); | ||||
1424 | return vecOp; | ||||
1425 | } | ||||
1426 | |||||
1427 | /// Vectorizes a yield operation by widening its types. The builder's insertion | ||||
1428 | /// point is set after the vectorized parent op to continue vectorizing the | ||||
1429 | /// operations after the parent op. When vectorizing a reduction loop a mask may | ||||
1430 | /// be used to prevent adding garbage values to the accumulator. | ||||
1431 | static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, | ||||
1432 | VectorizationState &state) { | ||||
1433 | Operation *newYieldOp = widenOp(yieldOp, state); | ||||
1434 | Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp(); | ||||
1435 | |||||
1436 | // If there is a mask for this loop then we must prevent garbage values from | ||||
1437 | // being added to the accumulator by inserting `select` operations, for | ||||
1438 | // example: | ||||
1439 | // | ||||
1440 | // %val_masked = select %mask, %val, %neutralCst : vector<128xi1>, | ||||
1441 | // vector<128xf32> | ||||
1442 | // %res = arith.addf %acc, %val_masked : vector<128xf32> | ||||
1443 | // affine.yield %res : vector<128xf32> | ||||
1444 | // | ||||
1445 | if (Value mask = state.vecLoopToMask.lookup(newParentOp)) { | ||||
1446 | state.builder.setInsertionPoint(newYieldOp); | ||||
1447 | for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) { | ||||
1448 | SmallVector<Operation *> combinerOps; | ||||
1449 | Value reducedVal = matchReduction( | ||||
1450 | cast<AffineForOp>(newParentOp).getRegionIterArgs(), i, combinerOps); | ||||
1451 | assert(reducedVal && "expect non-null value for parallel reduction loop"); | ||||
1452 | assert(combinerOps.size() == 1 && "expect only one combiner op"); | ||||
1453 | // IterOperands are neutral element vectors. | ||||
1454 | Value neutralVal = cast<AffineForOp>(newParentOp).getIterOperands()[i]; | ||||
1455 | state.builder.setInsertionPoint(combinerOps.back()); | ||||
1456 | Value maskedReducedVal = state.builder.create<arith::SelectOp>( | ||||
1457 | reducedVal.getLoc(), mask, reducedVal, neutralVal); | ||||
1458 | LLVM_DEBUG( | ||||
1459 | dbgs() << "\n[early-vect]+++++ masking an input to a binary op that " | ||||
1460 | "produces value for a yield Op: " | ||||
1461 | << maskedReducedVal); | ||||
1462 | combinerOps.back()->replaceUsesOfWith(reducedVal, maskedReducedVal); | ||||
1463 | } | ||||
1464 | } | ||||
1465 | |||||
1466 | state.builder.setInsertionPointAfter(newParentOp); | ||||
1467 | return newYieldOp; | ||||
1468 | } | ||||
1469 | |||||
1470 | /// Encodes Operation-specific behavior for vectorization. In general we | ||||
1471 | /// assume that all operands of an op must be vectorized but this is not | ||||
1472 | /// always true. In the future, it would be nice to have a trait that | ||||
1473 | /// describes how a particular operation vectorizes. For now we implement the | ||||
1474 | /// case distinction here. Returns a vectorized form of an operation or | ||||
1475 | /// nullptr if vectorization fails. | ||||
1476 | // TODO: consider adding a trait to Op to describe how it gets vectorized. | ||||
1477 | // Maybe some Ops are not vectorizable or require some tricky logic, we cannot | ||||
1478 | // do one-off logic here; ideally it would be TableGen'd. | ||||
1479 | static Operation *vectorizeOneOperation(Operation *op, | ||||
1480 | VectorizationState &state) { | ||||
1481 | // Sanity checks. | ||||
1482 | assert(!isa<vector::TransferReadOp>(op) && | ||||
1483 | "vector.transfer_read cannot be further vectorized"); | ||||
1484 | assert(!isa<vector::TransferWriteOp>(op) && | ||||
1485 | "vector.transfer_write cannot be further vectorized"); | ||||
1486 | |||||
1487 | if (auto loadOp = dyn_cast<AffineLoadOp>(op)) | ||||
1488 | return vectorizeAffineLoad(loadOp, state); | ||||
1489 | if (auto storeOp = dyn_cast<AffineStoreOp>(op)) | ||||
1490 | return vectorizeAffineStore(storeOp, state); | ||||
1491 | if (auto forOp = dyn_cast<AffineForOp>(op)) | ||||
1492 | return vectorizeAffineForOp(forOp, state); | ||||
1493 | if (auto yieldOp = dyn_cast<AffineYieldOp>(op)) | ||||
1494 | return vectorizeAffineYieldOp(yieldOp, state); | ||||
1495 | if (auto constant = dyn_cast<arith::ConstantOp>(op)) | ||||
1496 | return vectorizeConstant(constant, state); | ||||
1497 | |||||
1498 | // Other ops with regions are not supported. | ||||
1499 | if (op->getNumRegions() != 0) | ||||
1500 | return nullptr; | ||||
1501 | |||||
1502 | return widenOp(op, state); | ||||
1503 | } | ||||
1504 | |||||
1505 | /// Recursive implementation to convert all the nested loops in 'match' to a 2D | ||||
1506 | /// vector container that preserves the relative nesting level of each loop with | ||||
1507 | /// respect to the others in 'match'. 'currentLevel' is the nesting level that | ||||
1508 | /// will be assigned to the loop in the current 'match'. | ||||
1509 | static void | ||||
1510 | getMatchedAffineLoopsRec(NestedMatch match, unsigned currentLevel, | ||||
1511 | std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1512 | // Add a new empty level to the output if it doesn't exist already. | ||||
1513 | assert(currentLevel <= loops.size() && "Unexpected currentLevel"); | ||||
1514 | if (currentLevel == loops.size()) | ||||
1515 | loops.emplace_back(); | ||||
1516 | |||||
1517 | // Add current match and recursively visit its children. | ||||
1518 | loops[currentLevel].push_back(cast<AffineForOp>(match.getMatchedOperation())); | ||||
1519 | for (auto childMatch : match.getMatchedChildren()) { | ||||
1520 | getMatchedAffineLoopsRec(childMatch, currentLevel + 1, loops); | ||||
1521 | } | ||||
1522 | } | ||||
1523 | |||||
1524 | /// Converts all the nested loops in 'match' to a 2D vector container that | ||||
1525 | /// preserves the relative nesting level of each loop with respect to the others | ||||
1526 | /// in 'match'. This means that every loop in 'loops[i]' will have a parent loop | ||||
1527 | /// in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in | ||||
1528 | /// 'loops[i+1]'. | ||||
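// Sketch of the resulting structure for a hypothetical match rooted at %i:
//
//   affine.for %i ...        ->  loops[0] = {%i}
//     affine.for %j ...      ->  loops[1] = {%j, %k}
//     affine.for %k ...
//       affine.for %l ...    ->  loops[2] = {%l}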
1529 | static void | ||||
1530 | getMatchedAffineLoops(NestedMatch match, | ||||
1531 | std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1532 | getMatchedAffineLoopsRec(match, /*currentLevel=*/0, loops); | ||||
1533 | } | ||||
1534 | |||||
1535 | /// Internal implementation to vectorize affine loops from a single loop nest | ||||
1536 | /// using an n-D vectorization strategy. | ||||
1537 | static LogicalResult | ||||
1538 | vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops, | ||||
1539 | const VectorizationStrategy &strategy) { | ||||
1540 | assert(loops[0].size() == 1 && "Expected single root loop"); | ||||
1541 | AffineForOp rootLoop = loops[0][0]; | ||||
1542 | VectorizationState state(rootLoop.getContext()); | ||||
1543 | state.builder.setInsertionPointAfter(rootLoop); | ||||
1544 | state.strategy = &strategy; | ||||
1545 | |||||
1546 | // Since patterns are recursive, they can very well intersect. | ||||
1547 | // Since we do not want a fully greedy strategy in general, we decouple | ||||
1548 | // pattern matching from profitability analysis and from application. | ||||
1549 | // As a consequence we must check that each root pattern is still | ||||
1550 | // vectorizable. If a pattern is not vectorizable anymore, we just skip it. | ||||
1551 | // TODO: implement a non-greedy profitability analysis that keeps only | ||||
1552 | // non-intersecting patterns. | ||||
1553 | if (!isVectorizableLoopBody(rootLoop, vectorTransferPattern())) { | ||||
1554 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable"); | ||||
1555 | return failure(); | ||||
1556 | } | ||||
1557 | |||||
1558 | ////////////////////////////////////////////////////////////////////////////// | ||||
1559 | // Vectorize the scalar loop nest following a topological order. A new vector | ||||
1560 | // loop nest with the vectorized operations is created along the process. If | ||||
1561 | // vectorization succeeds, the scalar loop nest is erased. If vectorization | ||||
1562 | // fails, the vector loop nest is erased and the scalar loop nest is not | ||||
1563 | // modified. | ||||
1564 | ////////////////////////////////////////////////////////////////////////////// | ||||
1565 | |||||
1566 | auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) { | ||||
1567 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op); | ||||
1568 | Operation *vectorOp = vectorizeOneOperation(op, state); | ||||
1569 | if (!vectorOp) { | ||||
1570 | LLVM_DEBUG( | ||||
1571 | dbgs() << "[early-vect]+++++ failed vectorizing the operation: " | ||||
1572 | << *op << "\n"); | ||||
1573 | return WalkResult::interrupt(); | ||||
1574 | } | ||||
1575 | |||||
1576 | return WalkResult::advance(); | ||||
1577 | }); | ||||
1578 | |||||
1579 | if (opVecResult.wasInterrupted()) { | ||||
1580 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: " | ||||
1581 | << rootLoop << "\n"); | ||||
1582 | // Erase vector loop nest if it was created. | ||||
1583 | auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop); | ||||
1584 | if (vecRootLoopIt != state.opVectorReplacement.end()) | ||||
1585 | eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second)); | ||||
1586 | |||||
1587 | return failure(); | ||||
1588 | } | ||||
1589 | |||||
1590 | // Replace results of reduction loops with the scalar values computed using | ||||
1591 | // `vector.reduction` or similar ops. | ||||
1592 | for (auto resPair : state.loopResultScalarReplacement) | ||||
1593 | resPair.first.replaceAllUsesWith(resPair.second); | ||||
1594 | |||||
1595 | assert(state.opVectorReplacement.count(rootLoop) == 1 && | ||||
1596 | "Expected vector replacement for loop nest"); | ||||
1597 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern"); | ||||
1598 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n" | ||||
1599 | << *state.opVectorReplacement[rootLoop]); | ||||
1600 | |||||
1601 | // Finish this vectorization pattern. | ||||
1602 | state.finishVectorizationPattern(rootLoop); | ||||
1603 | return success(); | ||||
1604 | } | ||||
1605 | |||||
1606 | /// Extracts the matched loops and vectorizes them following a topological | ||||
1607 | /// order. A new vector loop nest will be created if vectorization succeeds. The | ||||
1608 | /// original loop nest won't be modified in any case. | ||||
1609 | static LogicalResult vectorizeRootMatch(NestedMatch m, | ||||
1610 | const VectorizationStrategy &strategy) { | ||||
1611 | std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize; | ||||
1612 | getMatchedAffineLoops(m, loopsToVectorize); | ||||
1613 | return vectorizeLoopNest(loopsToVectorize, strategy); | ||||
1614 | } | ||||
1615 | |||||
1616 | /// Traverses all the loop matches and classifies them into intersection | ||||
1617 | /// buckets. Two matches intersect if any of them encloses the other one. A | ||||
1618 | /// match intersects with a bucket if the match intersects with the root | ||||
1619 | /// (outermost) loop in that bucket. | ||||
1620 | static void computeIntersectionBuckets( | ||||
1621 | ArrayRef<NestedMatch> matches, | ||||
1622 | std::vector<SmallVector<NestedMatch, 8>> &intersectionBuckets) { | ||||
1623 | assert(intersectionBuckets.empty() && "Expected empty output"); | ||||
1624 | // Keeps track of the root (outermost) loop of each bucket. | ||||
1625 | SmallVector<AffineForOp, 8> bucketRoots; | ||||
1626 | |||||
1627 | for (const NestedMatch &match : matches) { | ||||
1628 | AffineForOp matchRoot = cast<AffineForOp>(match.getMatchedOperation()); | ||||
1629 | bool intersects = false; | ||||
1630 | for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) { | ||||
1631 | AffineForOp bucketRoot = bucketRoots[i]; | ||||
1632 | // Add match to the bucket if the bucket root encloses the match root. | ||||
1633 | if (bucketRoot->isAncestor(matchRoot)) { | ||||
1634 | intersectionBuckets[i].push_back(match); | ||||
1635 | intersects = true; | ||||
1636 | break; | ||||
1637 | } | ||||
1638 | // Add match to the bucket if the match root encloses the bucket root. The | ||||
1639 | // match root becomes the new bucket root. | ||||
1640 | if (matchRoot->isAncestor(bucketRoot)) { | ||||
1641 | bucketRoots[i] = matchRoot; | ||||
1642 | intersectionBuckets[i].push_back(match); | ||||
1643 | intersects = true; | ||||
1644 | break; | ||||
1645 | } | ||||
1646 | } | ||||
1647 | |||||
1648 | // Match doesn't intersect with any existing bucket. Create a new bucket for | ||||
1649 | // it. | ||||
1650 | if (!intersects) { | ||||
1651 | bucketRoots.push_back(matchRoot); | ||||
1652 | intersectionBuckets.emplace_back(); | ||||
1653 | intersectionBuckets.back().push_back(match); | ||||
1654 | } | ||||
1655 | } | ||||
1656 | } | ||||
1657 | |||||
1658 | /// Internal implementation to vectorize affine loops in 'loops' using the n-D | ||||
1659 | /// vectorization factors in 'vectorSizes'. By default, each vectorization | ||||
1660 | /// factor is applied inner-to-outer to the loops of each loop nest. | ||||
1661 | /// 'fastestVaryingPattern' can be optionally used to provide a different loop | ||||
1662 | /// vectorization order. `reductionLoops` can be provided to specify loops which | ||||
1663 | /// can be vectorized along the reduction dimension. | ||||
1664 | static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops, | ||||
1665 | ArrayRef<int64_t> vectorSizes, | ||||
1666 | ArrayRef<int64_t> fastestVaryingPattern, | ||||
1667 | const ReductionLoopMap &reductionLoops) { | ||||
1668 | assert((reductionLoops.empty() || vectorSizes.size() == 1) && | ||||
1669 | "Vectorizing reductions is supported only for 1-D vectors"); | ||||
1670 | |||||
1671 | // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops. | ||||
1672 | std::optional<NestedPattern> pattern = | ||||
1673 | makePattern(loops, vectorSizes.size(), fastestVaryingPattern); | ||||
1674 | if (!pattern) { | ||||
1675 | LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n"); | ||||
1676 | return; | ||||
1677 | } | ||||
1678 | |||||
1679 | LLVM_DEBUG(dbgs() << "\n******************************************"); | ||||
1680 | LLVM_DEBUG(dbgs() << "\n******************************************"); | ||||
1681 | LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n"); | ||||
1682 | LLVM_DEBUG(dbgs() << *parentOp << "\n"); | ||||
1683 | |||||
1684 | unsigned patternDepth = pattern->getDepth(); | ||||
1685 | |||||
1686 | // Compute all the pattern matches and classify them into buckets of | ||||
1687 | // intersecting matches. | ||||
1688 | SmallVector<NestedMatch, 32> allMatches; | ||||
1689 | pattern->match(parentOp, &allMatches); | ||||
1690 | std::vector<SmallVector<NestedMatch, 8>> intersectionBuckets; | ||||
1691 | computeIntersectionBuckets(allMatches, intersectionBuckets); | ||||
1692 | |||||
1693 | // Iterate over all buckets and vectorize the matches eagerly. We can only | ||||
1694 | // vectorize one match from each bucket since all the matches within a bucket | ||||
1695 | // intersect. | ||||
1696 | for (auto &intersectingMatches : intersectionBuckets) { | ||||
1697 | for (NestedMatch &match : intersectingMatches) { | ||||
1698 | VectorizationStrategy strategy; | ||||
1699 | // TODO: depending on profitability, elect to reduce the vector size. | ||||
1700 | strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end()); | ||||
1701 | strategy.reductionLoops = reductionLoops; | ||||
1702 | if (failed(analyzeProfitability(match.getMatchedChildren(), 1, | ||||
1703 | patternDepth, &strategy))) { | ||||
1704 | continue; | ||||
1705 | } | ||||
1706 | vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth, | ||||
1707 | &strategy); | ||||
1708 | // Vectorize match. Skip the rest of intersecting matches in the bucket if | ||||
1709 | // vectorization succeeded. | ||||
1710 | // TODO: if pattern does not apply, report it; alter the cost/benefit. | ||||
1711 | // TODO: some diagnostics if failure to vectorize occurs. | ||||
1712 | if (succeeded(vectorizeRootMatch(match, strategy))) | ||||
1713 | break; | ||||
1714 | } | ||||
1715 | } | ||||
1716 | |||||
1717 | LLVM_DEBUG(dbgs() << "\n"); | ||||
1718 | } | ||||
1719 | |||||
1720 | /// Applies vectorization to the current function by searching over a bunch of | ||||
1721 | /// predetermined patterns. | ||||
1722 | void Vectorize::runOnOperation() { | ||||
1723 | func::FuncOp f = getOperation(); | ||||
1724 | if (!fastestVaryingPattern.empty() && | ||||
1725 | fastestVaryingPattern.size() != vectorSizes.size()) { | ||||
1726 | f.emitRemark("Fastest varying pattern specified with different size than " | ||||
1727 | "the vector size."); | ||||
1728 | return signalPassFailure(); | ||||
1729 | } | ||||
1730 | |||||
1731 | if (vectorizeReductions && vectorSizes.size() != 1) { | ||||
1732 | f.emitError("Vectorizing reductions is supported only for 1-D vectors."); | ||||
1733 | return signalPassFailure(); | ||||
1734 | } | ||||
1735 | |||||
1736 | DenseSet<Operation *> parallelLoops; | ||||
1737 | ReductionLoopMap reductionLoops; | ||||
1738 | |||||
1739 | // If 'vectorize-reduction=true' is provided, we also populate the | ||||
1740 | // `reductionLoops` map. | ||||
1741 | if (vectorizeReductions) { | ||||
1742 | f.walk([¶llelLoops, &reductionLoops](AffineForOp loop) { | ||||
1743 | SmallVector<LoopReduction, 2> reductions; | ||||
1744 | if (isLoopParallel(loop, &reductions)) { | ||||
1745 | parallelLoops.insert(loop); | ||||
1746 | // If it's not a reduction loop, adding it to the map is not necessary. | ||||
1747 | if (!reductions.empty()) | ||||
1748 | reductionLoops[loop] = reductions; | ||||
1749 | } | ||||
1750 | }); | ||||
1751 | } else { | ||||
1752 | f.walk([¶llelLoops](AffineForOp loop) { | ||||
1753 | if (isLoopParallel(loop)) | ||||
1754 | parallelLoops.insert(loop); | ||||
1755 | }); | ||||
1756 | } | ||||
1757 | |||||
1758 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1759 | NestedPatternContext mlContext; | ||||
1760 | vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern, | ||||
1761 | reductionLoops); | ||||
1762 | } | ||||
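In practice this pass is exercised through mlir-opt rather than called directly. As an illustration only (the textual option names below are assumed from the pass definition and are not shown in this file), a 1-D vectorization request with reductions enabled would look roughly like:

// mlir-opt --affine-super-vectorize="virtual-vector-size=128 \
//     vectorize-reductions=true" input.mlir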
1763 | |||||
1764 | /// Verify that affine loops in 'loops' meet the nesting criteria expected by | ||||
1765 | /// SuperVectorizer: | ||||
1766 | /// * There must be at least one loop. | ||||
1767 | /// * There must be a single root loop (nesting level 0). | ||||
1768 | /// * Each loop at a given nesting level must be nested in a loop from a | ||||
1769 | /// previous nesting level. | ||||
1770 | static LogicalResult | ||||
1771 | verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1772 | // Expected at least one loop. | ||||
1773 | if (loops.empty()) | ||||
1774 | return failure(); | ||||
1775 | |||||
1776 | // Expected only one root loop. | ||||
1777 | if (loops[0].size() != 1) | ||||
1778 | return failure(); | ||||
1779 | |||||
1780 | // Traverse loops outer-to-inner to check some invariants. | ||||
1781 | for (int i = 1, end = loops.size(); i < end; ++i) { | ||||
1782 | for (AffineForOp loop : loops[i]) { | ||||
1783 | // Check that each loop at this level is nested in one of the loops from | ||||
1784 | // the previous level. | ||||
1785 | if (none_of(loops[i - 1], [&](AffineForOp maybeParent) { | ||||
1786 | return maybeParent->isProperAncestor(loop); | ||||
1787 | })) | ||||
1788 | return failure(); | ||||
1789 | |||||
1790 | // Check that each loop at this level is not nested in another loop from | ||||
1791 | // this level. | ||||
1792 | for (AffineForOp sibling : loops[i]) { | ||||
1793 | if (sibling->isProperAncestor(loop)) | ||||
1794 | return failure(); | ||||
1795 | } | ||||
1796 | } | ||||
1797 | } | ||||
1798 | |||||
1799 | return success(); | ||||
1800 | } | ||||
1801 | |||||
1802 | |||||
1803 | /// External utility to vectorize affine loops in 'loops' using the n-D | ||||
1804 | /// vectorization factors in 'vectorSizes'. By default, each vectorization | ||||
1805 | /// factor is applied inner-to-outer to the loops of each loop nest. | ||||
1806 | /// 'fastestVaryingPattern' can be optionally used to provide a different loop | ||||
1807 | /// vectorization order. | ||||
1808 | /// If `reductionLoops` is not empty, the given reduction loops may be | ||||
1809 | /// vectorized along the reduction dimension. | ||||
1810 | /// TODO: Vectorizing reductions is supported only for 1-D vectorization. | ||||
1811 | void mlir::affine::vectorizeAffineLoops( | ||||
1812 | Operation *parentOp, DenseSet<Operation *> &loops, | ||||
1813 | ArrayRef<int64_t> vectorSizes, ArrayRef<int64_t> fastestVaryingPattern, | ||||
1814 | const ReductionLoopMap &reductionLoops) { | ||||
1815 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1816 | NestedPatternContext mlContext; | ||||
1817 | vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern, | ||||
1818 | reductionLoops); | ||||
1819 | } | ||||
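A minimal caller-side sketch of this utility (a hypothetical helper, assuming the same headers as this file; it mirrors the parallel-loop walk done in runOnOperation above):

static void vectorizeAllParallelLoops(func::FuncOp f) {
  DenseSet<Operation *> parallelLoops;
  f.walk([&](AffineForOp loop) {
    if (isLoopParallel(loop))
      parallelLoops.insert(loop);
  });
  // Request a virtual vector size of 128; no custom fastest-varying pattern
  // and no reduction loops.
  vectorizeAffineLoops(f, parallelLoops, /*vectorSizes=*/{128},
                       /*fastestVaryingPattern=*/{},
                       /*reductionLoops=*/ReductionLoopMap());
}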
1820 | |||||
1821 | /// External utility to vectorize affine loops from a single loop nest using an | ||||
1822 | /// n-D vectorization strategy (see doc in VectorizationStrategy definition). | ||||
1823 | /// Loops are provided in a 2D vector container. The first dimension represents | ||||
1824 | /// the nesting level relative to the loops to be vectorized. The second | ||||
1825 | /// dimension contains the loops. This means that: | ||||
1826 | /// a) every loop in 'loops[i]' must have a parent loop in 'loops[i-1]', | ||||
1827 | /// b) a loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'. | ||||
1828 | /// | ||||
1829 | /// For example, for the following loop nest: | ||||
1830 | /// | ||||
1831 | /// func @vec2d(%in0: memref<64x128x512xf32>, %in1: memref<64x128x128xf32>, | ||||
1832 | /// %out0: memref<64x128x512xf32>, | ||||
1833 | /// %out1: memref<64x128x128xf32>) { | ||||
1834 | /// affine.for %i0 = 0 to 64 { | ||||
1835 | /// affine.for %i1 = 0 to 128 { | ||||
1836 | /// affine.for %i2 = 0 to 512 { | ||||
1837 | /// %ld = affine.load %in0[%i0, %i1, %i2] : memref<64x128x512xf32> | ||||
1838 | /// affine.store %ld, %out0[%i0, %i1, %i2] : memref<64x128x512xf32> | ||||
1839 | /// } | ||||
1840 | /// affine.for %i3 = 0 to 128 { | ||||
1841 | /// %ld = affine.load %in1[%i0, %i1, %i3] : memref<64x128x128xf32> | ||||
1842 | /// affine.store %ld, %out1[%i0, %i1, %i3] : memref<64x128x128xf32> | ||||
1843 | /// } | ||||
1844 | /// } | ||||
1845 | /// } | ||||
1846 | /// return | ||||
1847 | /// } | ||||
1848 | /// | ||||
1849 | /// loops = {{%i0}, {%i2, %i3}}, to vectorize the outermost and the two | ||||
1850 | /// innermost loops; | ||||
1851 | /// loops = {{%i1}, {%i2, %i3}}, to vectorize the middle and the two innermost | ||||
1852 | /// loops; | ||||
1853 | /// loops = {{%i2}}, to vectorize only the first innermost loop; | ||||
1854 | /// loops = {{%i3}}, to vectorize only the second innermost loop; | ||||
1855 | /// loops = {{%i1}}, to vectorize only the middle loop. | ||||
1856 | LogicalResult mlir::affine::vectorizeAffineLoopNest( | ||||
1857 | std::vector<SmallVector<AffineForOp, 2>> &loops, | ||||
1858 | const VectorizationStrategy &strategy) { | ||||
1859 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1860 | NestedPatternContext mlContext; | ||||
1861 | if (failed(verifyLoopNesting(loops))) | ||||
1862 | return failure(); | ||||
1863 | return vectorizeLoopNest(loops, strategy); | ||||
1864 | } |
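A short sketch of driving this entry point for the vec2d example in the comment above (a hypothetical driver; `i1` and `i2` are assumed to be the %i1 and %i2 loops):

static LogicalResult vectorizeMiddleAndInner(AffineForOp i1, AffineForOp i2) {
  std::vector<SmallVector<AffineForOp, 2>> loops;
  loops.push_back({i1}); // Nesting level 0: the single root loop.
  loops.push_back({i2}); // Nesting level 1: nested inside i1.
  VectorizationStrategy strategy;
  strategy.vectorSizes = {32, 256}; // Request a 32x256 super-vector.
  return vectorizeAffineLoopNest(loops, strategy);
}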
1 | //===- Value.h - Base of the SSA Value hierarchy ----------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines generic Value type and manipulation utilities. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_VALUE_H |
14 | #define MLIR_IR_VALUE_H |
15 | |
16 | #include "mlir/IR/Types.h" |
17 | #include "mlir/IR/UseDefLists.h" |
18 | #include "mlir/Support/LLVM.h" |
19 | #include "llvm/Support/PointerLikeTypeTraits.h" |
20 | |
21 | namespace mlir { |
22 | class AsmState; |
23 | class Block; |
24 | class BlockArgument; |
25 | class Operation; |
26 | class OpOperand; |
27 | class OpPrintingFlags; |
28 | class OpResult; |
29 | class Region; |
30 | class Value; |
31 | |
32 | //===----------------------------------------------------------------------===// |
33 | // Value |
34 | //===----------------------------------------------------------------------===// |
35 | |
36 | namespace detail { |
37 | |
38 | /// The base class for all derived Value classes. It contains all of the |
39 | /// components that are shared across Value classes. |
40 | class alignas(8) ValueImpl : public IRObjectWithUseList<OpOperand> { |
41 | public: |
42 | /// This enumeration represents the various kinds of values the internal
43 | /// representation may take. We use all of the bits from Type that we can to
44 | /// store indices inline.
45 | enum class Kind { |
46 | /// The first N kinds are all inline operation results. An inline operation |
47 | /// result means that the kind represents the result number. This removes |
48 | /// the need to store an additional index value. The derived class here is |
49 | /// an `OpResultImpl`. |
50 | InlineOpResult = 0, |
51 | |
52 | /// The next kind represents an 'out-of-line' operation result. This is for
53 | /// results with numbers larger than we can represent inline. The derived |
54 | /// class here is an `OpResultImpl`. |
55 | OutOfLineOpResult = 6, |
56 | |
57 | /// The last kind represents a block argument. The derived class here is a
58 | /// `BlockArgumentImpl`. |
59 | BlockArgument = 7 |
60 | }; |
61 | |
62 | /// Return the type of this value. |
63 | Type getType() const { return typeAndKind.getPointer(); } |
64 | |
65 | /// Set the type of this value. |
66 | void setType(Type type) { return typeAndKind.setPointer(type); } |
67 | |
68 | /// Return the kind of this value. |
69 | Kind getKind() const { return typeAndKind.getInt(); } |
70 | |
71 | protected: |
72 | ValueImpl(Type type, Kind kind) : typeAndKind(type, kind) {} |
73 | |
74 | /// Expose a few methods explicitly for the debugger to call for |
75 | /// visualization. |
76 | #ifndef NDEBUG |
77 | LLVM_DUMP_METHOD Type debug_getType() const { return getType(); }
78 | LLVM_DUMP_METHOD Kind debug_getKind() const { return getKind(); }
79 | |
80 | #endif |
81 | |
82 | /// The type of this result and the kind. |
83 | llvm::PointerIntPair<Type, 3, Kind> typeAndKind; |
84 | }; |
85 | } // namespace detail |
86 | |
87 | /// This class represents an instance of an SSA value in the MLIR system, |
88 | /// representing a computable value that has a type and a set of users. An SSA |
89 | /// value is either a BlockArgument or the result of an operation. Note: This |
90 | /// class has value-type semantics and is just a simple wrapper around a |
91 | /// ValueImpl that is either owned by a block (in the case of a BlockArgument)
92 | /// or an Operation (in the case of an OpResult).
93 | class Value { |
94 | public: |
95 | constexpr Value(detail::ValueImpl *impl = nullptr) : impl(impl) {} |
96 | |
97 | template <typename U> |
98 | bool isa() const { |
99 | return llvm::isa<U>(*this); |
100 | } |
101 | |
102 | template <typename U> |
103 | U dyn_cast() const { |
104 | return llvm::dyn_cast<U>(*this); |
105 | } |
106 | |
107 | template <typename U> |
108 | U dyn_cast_or_null() const { |
109 | return llvm::dyn_cast_if_present<U>(*this); |
110 | } |
111 | |
112 | template <typename U> |
113 | U cast() const { |
114 | return llvm::cast<U>(*this); |
115 | } |
116 | |
117 | explicit operator bool() const { return impl; } |
118 | bool operator==(const Value &other) const { return impl == other.impl; } |
119 | bool operator!=(const Value &other) const { return !(*this == other); } |
120 | |
121 | /// Return the type of this value. |
122 | Type getType() const { return impl->getType(); } |
123 | |
124 | /// Utility to get the associated MLIRContext that this value is defined in. |
125 | MLIRContext *getContext() const { return getType().getContext(); } |
126 | |
127 | /// Mutate the type of this Value to be of the specified type. |
128 | /// |
129 | /// Note that this is an extremely dangerous operation which can create |
130 | /// completely invalid IR very easily. It is strongly recommended that you |
131 | /// recreate IR objects with the right types instead of mutating them in |
132 | /// place. |
133 | void setType(Type newType) { impl->setType(newType); } |
134 | |
135 | /// If this value is the result of an operation, return the operation that |
136 | /// defines it. |
137 | Operation *getDefiningOp() const; |
138 | |
139 | /// If this value is the result of an operation of type OpTy, return the |
140 | /// operation that defines it. |
141 | template <typename OpTy> |
142 | OpTy getDefiningOp() const { |
143 | return llvm::dyn_cast_or_null<OpTy>(getDefiningOp()); |
144 | } |
145 | |
146 | /// Return the location of this value. |
147 | Location getLoc() const; |
148 | void setLoc(Location loc); |
149 | |
150 | /// Return the Region in which this Value is defined. |
151 | Region *getParentRegion(); |
152 | |
153 | /// Return the Block in which this Value is defined. |
154 | Block *getParentBlock(); |
155 | |
156 | //===--------------------------------------------------------------------===// |
157 | // UseLists |
158 | //===--------------------------------------------------------------------===// |
159 | |
160 | /// Drop all uses of this object from their respective owners. |
161 | void dropAllUses() const { return impl->dropAllUses(); } |
162 | |
163 | /// Replace all uses of 'this' value with the new value, updating anything in |
164 | /// the IR that uses 'this' to use the other value instead. When this returns |
165 | /// there are zero uses of 'this'. |
166 | void replaceAllUsesWith(Value newValue) const { |
167 | impl->replaceAllUsesWith(newValue); |
168 | } |
169 | |
170 | /// Replace all uses of 'this' value with 'newValue', updating anything in the |
171 | /// IR that uses 'this' to use the other value instead except if the user is |
172 | /// listed in 'exceptions'.
173 | void |
174 | replaceAllUsesExcept(Value newValue, |
175 | const SmallPtrSetImpl<Operation *> &exceptions) const; |
176 | |
177 | /// Replace all uses of 'this' value with 'newValue', updating anything in the |
178 | /// IR that uses 'this' to use the other value instead except if the user is |
179 | /// 'exceptedUser'. |
180 | void replaceAllUsesExcept(Value newValue, Operation *exceptedUser) const; |
181 | |
182 | /// Replace all uses of 'this' value with 'newValue' if the given callback |
183 | /// returns true. |
184 | void replaceUsesWithIf(Value newValue, |
185 | function_ref<bool(OpOperand &)> shouldReplace); |
186 | |
187 | /// Returns true if the value is used outside of the given block. |
188 | bool isUsedOutsideOfBlock(Block *block); |
189 | |
190 | //===--------------------------------------------------------------------===// |
191 | // Uses |
192 | |
193 | /// This class implements an iterator over the uses of a value. |
194 | using use_iterator = ValueUseIterator<OpOperand>; |
195 | using use_range = iterator_range<use_iterator>; |
196 | |
197 | use_iterator use_begin() const { return impl->use_begin(); } |
198 | use_iterator use_end() const { return use_iterator(); } |
199 | |
200 | /// Returns a range of all uses, which is useful for iterating over all uses. |
201 | use_range getUses() const { return {use_begin(), use_end()}; } |
202 | |
203 | /// Returns true if this value has exactly one use. |
204 | bool hasOneUse() const { return impl->hasOneUse(); } |
205 | |
206 | /// Returns true if this value has no uses. |
207 | bool use_empty() const { return impl->use_empty(); } |
208 | |
209 | //===--------------------------------------------------------------------===// |
210 | // Users |
211 | |
212 | using user_iterator = ValueUserIterator<use_iterator, OpOperand>; |
213 | using user_range = iterator_range<user_iterator>; |
214 | |
215 | user_iterator user_begin() const { return use_begin(); } |
216 | user_iterator user_end() const { return use_end(); } |
217 | user_range getUsers() const { return {user_begin(), user_end()}; } |
218 | |
219 | //===--------------------------------------------------------------------===// |
220 | // Utilities |
221 | |
222 | void print(raw_ostream &os); |
223 | void print(raw_ostream &os, const OpPrintingFlags &flags); |
224 | void print(raw_ostream &os, AsmState &state); |
225 | void dump(); |
226 | |
227 | /// Print this value as if it were an operand. |
228 | void printAsOperand(raw_ostream &os, AsmState &state); |
229 | void printAsOperand(raw_ostream &os, const OpPrintingFlags &flags); |
230 | |
231 | /// Methods for supporting PointerLikeTypeTraits. |
232 | void *getAsOpaquePointer() const { return impl; } |
233 | static Value getFromOpaquePointer(const void *pointer) { |
234 | return reinterpret_cast<detail::ValueImpl *>(const_cast<void *>(pointer)); |
235 | } |
236 | detail::ValueImpl *getImpl() const { return impl; } |
237 | |
238 | friend ::llvm::hash_code hash_value(Value arg); |
239 | |
240 | protected: |
241 | /// A pointer to the internal implementation of the value. |
242 | detail::ValueImpl *impl; |
243 | }; |
244 | |
245 | inline raw_ostream &operator<<(raw_ostream &os, Value value) { |
246 | value.print(os); |
247 | return os; |
248 | } |
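A quick use-list illustration of the API above (a hedged sketch; it additionally assumes "mlir/IR/Operation.h" for Operation's diagnostics):

// Replace every use of `oldValue` with `newValue` and check the result.
static void replaceEverywhere(Value oldValue, Value newValue) {
  assert(oldValue.getType() == newValue.getType() && "expected matching types");
  if (Operation *def = oldValue.getDefiningOp())
    def->emitRemark("replacing all uses of this op's result");
  oldValue.replaceAllUsesWith(newValue);
  assert(oldValue.use_empty() && "all uses now point at newValue");
}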
249 | |
250 | //===----------------------------------------------------------------------===// |
251 | // OpOperand |
252 | //===----------------------------------------------------------------------===// |
253 | |
254 | /// This class represents an operand of an operation. Instances of this class |
255 | /// contain a reference to a specific `Value`. |
256 | class OpOperand : public IROperand<OpOperand, Value> { |
257 | public: |
258 | /// Provide the use list that is attached to the given value. |
259 | static IRObjectWithUseList<OpOperand> *getUseList(Value value) { |
260 | return value.getImpl(); |
261 | } |
262 | |
263 | /// Return which operand this is in the OpOperand list of the Operation. |
264 | unsigned getOperandNumber(); |
265 | |
266 | private: |
267 | /// Keep the constructor private and accessible to the OperandStorage class |
268 | /// only to avoid hard-to-debug typo/programming mistakes. |
269 | friend class OperandStorage; |
270 | using IROperand<OpOperand, Value>::IROperand; |
271 | }; |
272 | |
273 | //===----------------------------------------------------------------------===// |
274 | // BlockArgument |
275 | //===----------------------------------------------------------------------===// |
276 | |
277 | namespace detail { |
278 | /// The internal implementation of a BlockArgument. |
279 | class BlockArgumentImpl : public ValueImpl { |
280 | public: |
281 | static bool classof(const ValueImpl *value) { |
282 | return value->getKind() == ValueImpl::Kind::BlockArgument; |
283 | } |
284 | |
285 | private: |
286 | BlockArgumentImpl(Type type, Block *owner, int64_t index, Location loc) |
287 | : ValueImpl(type, Kind::BlockArgument), owner(owner), index(index), |
288 | loc(loc) {} |
289 | |
290 | /// The owner of this argument. |
291 | Block *owner; |
292 | |
293 | /// The position in the argument list. |
294 | int64_t index; |
295 | |
296 | /// The source location of this argument. |
297 | Location loc; |
298 | |
299 | /// Allow access to owner and constructor. |
300 | friend BlockArgument; |
301 | }; |
302 | } // namespace detail |
303 | |
304 | /// This class represents an argument of a Block. |
305 | class BlockArgument : public Value { |
306 | public: |
307 | using Value::Value; |
308 | |
309 | static bool classof(Value value) { |
310 | return llvm::isa<detail::BlockArgumentImpl>(value.getImpl()); |
311 | } |
312 | |
313 | /// Returns the block that owns this argument. |
314 | Block *getOwner() const { return getImpl()->owner; } |
315 | |
316 | /// Returns the number of this argument. |
317 | unsigned getArgNumber() const { return getImpl()->index; } |
318 | |
319 | /// Return the location for this argument. |
320 | Location getLoc() const { return getImpl()->loc; } |
321 | void setLoc(Location loc) { getImpl()->loc = loc; } |
322 | |
323 | private: |
324 | /// Allocate a new argument with the given type and owner. |
325 | static BlockArgument create(Type type, Block *owner, int64_t index, |
326 | Location loc) { |
327 | return new detail::BlockArgumentImpl(type, owner, index, loc); |
328 | } |
329 | |
330 | /// Destroy and deallocate this argument. |
331 | void destroy() { delete getImpl(); } |
332 | |
333 | /// Get a raw pointer to the internal implementation. |
334 | detail::BlockArgumentImpl *getImpl() const { |
335 | return reinterpret_cast<detail::BlockArgumentImpl *>(impl); |
336 | } |
337 | |
338 | /// Cache the position in the block argument list. |
339 | void setArgNumber(int64_t index) { getImpl()->index = index; } |
340 | |
341 | /// Allow access to `create`, `destroy` and `setArgNumber`. |
342 | friend Block; |
343 | |
344 | /// Allow access to 'getImpl'. |
345 | friend Value; |
346 | }; |
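A small sketch of the accessors above (assuming "mlir/IR/Block.h" is available for Block::isEntryBlock):

// Returns true if `arg` is an unused argument of its region's entry block.
static bool isUnusedEntryArgument(BlockArgument arg) {
  return arg.getOwner()->isEntryBlock() && arg.use_empty();
}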
347 | |
348 | //===----------------------------------------------------------------------===// |
349 | // OpResult |
350 | //===----------------------------------------------------------------------===// |
351 | |
352 | namespace detail { |
353 | /// This class provides the implementation for an operation result. |
354 | class alignas(8) OpResultImpl : public ValueImpl { |
355 | public: |
356 | using ValueImpl::ValueImpl; |
357 | |
358 | static bool classof(const ValueImpl *value) { |
359 | return value->getKind() != ValueImpl::Kind::BlockArgument; |
360 | } |
361 | |
362 | /// Returns the parent operation of this result. |
363 | Operation *getOwner() const; |
364 | |
365 | /// Returns the result number of this op result. |
366 | unsigned getResultNumber() const; |
367 | |
368 | /// Returns the next operation result at `offset` after this result. This |
369 | /// method is useful when indexing the result storage of an operation, given |
370 | /// that there is more than one kind of operation result (with the different |
371 | /// kinds having different sizes) and that operations are stored in reverse |
372 | /// order. |
373 | OpResultImpl *getNextResultAtOffset(intptr_t offset); |
374 | |
375 | /// Returns the maximum number of results that can be stored inline. |
376 | static unsigned getMaxInlineResults() { |
377 | return static_cast<unsigned>(Kind::OutOfLineOpResult); |
378 | } |
379 | }; |
380 | |
381 | /// This class provides the implementation for an operation result whose index |
382 | /// can be represented "inline" in the underlying ValueImpl. |
383 | struct InlineOpResult : public OpResultImpl { |
384 | public: |
385 | InlineOpResult(Type type, unsigned resultNo) |
386 | : OpResultImpl(type, static_cast<ValueImpl::Kind>(resultNo)) { |
387 | assert(resultNo < getMaxInlineResults());
388 | } |
389 | |
390 | /// Return the result number of this op result. |
391 | unsigned getResultNumber() const { return static_cast<unsigned>(getKind()); } |
392 | |
393 | static bool classof(const OpResultImpl *value) { |
394 | return value->getKind() != ValueImpl::Kind::OutOfLineOpResult; |
395 | } |
396 | }; |
397 | |
398 | /// This class provides the implementation for an operation result whose index |
399 | /// cannot be represented "inline", and thus requires an additional index field. |
400 | class OutOfLineOpResult : public OpResultImpl { |
401 | public: |
402 | OutOfLineOpResult(Type type, uint64_t outOfLineIndex) |
403 | : OpResultImpl(type, Kind::OutOfLineOpResult), |
404 | outOfLineIndex(outOfLineIndex) {} |
405 | |
406 | static bool classof(const OpResultImpl *value) { |
407 | return value->getKind() == ValueImpl::Kind::OutOfLineOpResult; |
408 | } |
409 | |
410 | /// Return the result number of this op result. |
411 | unsigned getResultNumber() const { |
412 | return outOfLineIndex + getMaxInlineResults(); |
413 | } |
414 | |
415 | /// The trailing result number, or the offset from the beginning of the |
416 | /// `OutOfLineOpResult` array. |
417 | uint64_t outOfLineIndex; |
418 | }; |
419 | |
420 | /// Return the result number of this op result. |
421 | inline unsigned OpResultImpl::getResultNumber() const { |
422 | if (const auto *outOfLineResult = dyn_cast<OutOfLineOpResult>(this)) |
423 | return outOfLineResult->getResultNumber(); |
424 | return cast<InlineOpResult>(this)->getResultNumber(); |
425 | } |
426 | |
427 | /// TypedValue is a Value with a statically known type.
428 | /// TypedValue can be null/empty.
429 | template <typename Ty> |
430 | struct TypedValue : Value { |
431 | using Value::Value; |
432 | |
433 | static bool classof(Value value) { return llvm::isa<Ty>(value.getType()); } |
434 | |
435 | /// Return the known Type |
436 | Ty getType() { return Value::getType().template cast<Ty>(); } |
437 | void setType(Ty ty) { Value::setType(ty); } |
438 | }; |
439 | |
440 | } // namespace detail |
441 | |
442 | /// This is a value defined by a result of an operation. |
443 | class OpResult : public Value { |
444 | public: |
445 | using Value::Value; |
446 | |
447 | static bool classof(Value value) { |
448 | return llvm::isa<detail::OpResultImpl>(value.getImpl()); |
449 | } |
450 | |
451 | /// Returns the operation that owns this result. |
452 | Operation *getOwner() const { return getImpl()->getOwner(); } |
453 | |
454 | /// Returns the number of this result. |
455 | unsigned getResultNumber() const { return getImpl()->getResultNumber(); } |
456 | |
457 | private: |
458 | /// Get a raw pointer to the internal implementation. |
459 | detail::OpResultImpl *getImpl() const { |
460 | return reinterpret_cast<detail::OpResultImpl *>(impl); |
461 | } |
462 | |
463 | /// Given a number of operation results, returns the number that need to be |
464 | /// stored inline. |
465 | static unsigned getNumInline(unsigned numResults); |
466 | |
467 | /// Given a number of operation results, returns the number that need to be |
468 | /// stored as trailing. |
469 | static unsigned getNumTrailing(unsigned numResults); |
470 | |
471 | /// Allow access to constructor. |
472 | friend Operation; |
473 | }; |
474 | |
475 | /// Make Value hashable. |
476 | inline ::llvm::hash_code hash_value(Value arg) { |
477 | return ::llvm::hash_value(arg.getImpl()); |
478 | } |
479 | |
480 | template <typename Ty, typename Value = mlir::Value> |
481 | /// If Ty is mlir::Type this will select `Value` instead of having a wrapper |
482 | /// around it. This helps resolve ambiguous conversion issues. |
483 | using TypedValue = std::conditional_t<std::is_same_v<Ty, mlir::Type>, |
484 | mlir::Value, detail::TypedValue<Ty>>; |
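A sketch of how the alias is typically used (illustrative only; it assumes "mlir/IR/BuiltinTypes.h" for MemRefType):

// The wrapper lets a callee state the expected type once, in its signature.
static int64_t rankOf(TypedValue<MemRefType> memref) {
  MemRefType type = memref.getType(); // No cast needed at the use site.
  return type.getRank();
}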
485 | |
486 | } // namespace mlir |
487 | |
488 | namespace llvm { |
489 | |
490 | template <> |
491 | struct DenseMapInfo<mlir::Value> { |
492 | static mlir::Value getEmptyKey() { |
493 | void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
494 | return mlir::Value::getFromOpaquePointer(pointer); |
495 | } |
496 | static mlir::Value getTombstoneKey() { |
497 | void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
498 | return mlir::Value::getFromOpaquePointer(pointer); |
499 | } |
500 | static unsigned getHashValue(mlir::Value val) { |
501 | return mlir::hash_value(val); |
502 | } |
503 | static bool isEqual(mlir::Value lhs, mlir::Value rhs) { return lhs == rhs; } |
504 | }; |
505 | template <> |
506 | struct DenseMapInfo<mlir::BlockArgument> : public DenseMapInfo<mlir::Value> { |
507 | static mlir::BlockArgument getEmptyKey() { |
508 | void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
509 | return reinterpret_cast<mlir::detail::BlockArgumentImpl *>(pointer); |
510 | } |
511 | static mlir::BlockArgument getTombstoneKey() { |
512 | void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
513 | return reinterpret_cast<mlir::detail::BlockArgumentImpl *>(pointer); |
514 | } |
515 | }; |
516 | template <> |
517 | struct DenseMapInfo<mlir::OpResult> : public DenseMapInfo<mlir::Value> { |
518 | static mlir::OpResult getEmptyKey() { |
519 | void *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); |
520 | return reinterpret_cast<mlir::detail::OpResultImpl *>(pointer); |
521 | } |
522 | static mlir::OpResult getTombstoneKey() { |
523 | void *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); |
524 | return reinterpret_cast<mlir::detail::OpResultImpl *>(pointer); |
525 | } |
526 | }; |
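These specializations are what make values usable as hashed-container keys. A small sketch (assuming "mlir/IR/Operation.h" for Operation::getResults):

// Assign a dense id to each result of `op`, keyed directly on the Value.
static llvm::DenseMap<mlir::Value, unsigned> numberResults(mlir::Operation *op) {
  llvm::DenseMap<mlir::Value, unsigned> ids;
  for (mlir::Value result : op->getResults())
    ids.try_emplace(result, ids.size());
  return ids;
}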
527 | |
528 | /// Allow stealing the low bits of a value. |
529 | template <> |
530 | struct PointerLikeTypeTraits<mlir::Value> { |
531 | public: |
532 | static inline void *getAsVoidPointer(mlir::Value value) { |
533 | return const_cast<void *>(value.getAsOpaquePointer()); |
534 | } |
535 | static inline mlir::Value getFromVoidPointer(void *pointer) { |
536 | return mlir::Value::getFromOpaquePointer(pointer); |
537 | } |
538 | enum { |
539 | NumLowBitsAvailable = |
540 | PointerLikeTypeTraits<mlir::detail::ValueImpl *>::NumLowBitsAvailable |
541 | }; |
542 | }; |
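With these traits Value can be packed like a pointer, e.g. in a PointerUnion (similar to how mlir::OpFoldResult is defined) or a PointerIntPair. Illustrative aliases, assuming "mlir/IR/Attributes.h" and the corresponding llvm/ADT headers:

using AttrOrValue = llvm::PointerUnion<mlir::Attribute, mlir::Value>;
using TaggedValue = llvm::PointerIntPair<mlir::Value, 1, bool>;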
543 | template <> |
544 | struct PointerLikeTypeTraits<mlir::BlockArgument> |
545 | : public PointerLikeTypeTraits<mlir::Value> { |
546 | public: |
547 | static inline mlir::BlockArgument getFromVoidPointer(void *pointer) { |
548 | return reinterpret_cast<mlir::detail::BlockArgumentImpl *>(pointer); |
549 | } |
550 | }; |
551 | template <> |
552 | struct PointerLikeTypeTraits<mlir::OpResult> |
553 | : public PointerLikeTypeTraits<mlir::Value> { |
554 | public: |
555 | static inline mlir::OpResult getFromVoidPointer(void *pointer) { |
556 | return reinterpret_cast<mlir::detail::OpResultImpl *>(pointer); |
557 | } |
558 | }; |
559 | |
560 | /// Add support for llvm style casts. We provide a cast between To and From if |
561 | /// From is mlir::Value or derives from it. |
562 | template <typename To, typename From> |
563 | struct CastInfo< |
564 | To, From, |
565 | std::enable_if_t<std::is_same_v<mlir::Value, std::remove_const_t<From>> || |
566 | std::is_base_of_v<mlir::Value, From>>> |
567 | : NullableValueCastFailed<To>, |
568 | DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { |
569 | /// Arguments are taken as mlir::Value here and not as `From`, because |
570 | /// when casting from an intermediate type of the hierarchy to one of its |
571 | /// children, the val.getKind() inside T::classof will use the static |
572 | /// getKind() of the parent instead of the non-static ValueImpl::getKind() |
573 | /// that returns the dynamic type. This means that T::classof would end up |
574 | /// comparing the static Kind of the children to the static Kind of its |
575 | /// parent, making it impossible to downcast from the parent to the child. |
576 | static inline bool isPossible(mlir::Value ty) { |
577 | /// Return a constant true instead of a dynamic true when casting to self or |
578 | /// up the hierarchy. |
579 | if constexpr (std::is_base_of_v<To, From>) { |
580 | (void)ty; |
581 | return true; |
582 | } else { |
583 | return To::classof(ty); |
584 | } |
585 | } |
586 | static inline To doCast(mlir::Value value) { return To(value.getImpl()); } |
587 | }; |
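This specialization is what lets clients use llvm-style casts directly on values; a short sketch:

// Dispatch on the concrete kind of an mlir::Value.
static void visit(mlir::Value value) {
  if (auto result = llvm::dyn_cast<mlir::OpResult>(value))
    (void)result.getOwner();      // Defined by an operation.
  else if (auto arg = llvm::dyn_cast<mlir::BlockArgument>(value))
    (void)arg.getOwner();         // Defined as a block argument.
}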
588 | |
589 | } // namespace llvm |
590 | |
591 | #endif |
1 | //===- llvm/Support/Casting.h - Allow flexible, checked, casts --*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the isa<X>(), cast<X>(), dyn_cast<X>(), |
10 | // cast_if_present<X>(), and dyn_cast_if_present<X>() templates. |
11 | // |
12 | //===----------------------------------------------------------------------===// |
13 | |
14 | #ifndef LLVM_SUPPORT_CASTING_H |
15 | #define LLVM_SUPPORT_CASTING_H |
16 | |
17 | #include "llvm/Support/Compiler.h" |
18 | #include "llvm/Support/type_traits.h" |
19 | #include <cassert> |
20 | #include <memory> |
21 | #include <optional> |
22 | #include <type_traits> |
23 | |
24 | namespace llvm { |
25 | |
26 | //===----------------------------------------------------------------------===// |
27 | // simplify_type |
28 | //===----------------------------------------------------------------------===// |
29 | |
30 | /// Define a template that can be specialized by smart pointers to reflect the |
31 | /// fact that they are automatically dereferenced, and are not involved with the |
32 | /// template selection process... the default implementation is a noop. |
33 | // TODO: rename this and/or replace it with other cast traits. |
34 | template <typename From> struct simplify_type { |
35 | using SimpleType = From; // The real type this represents... |
36 | |
37 | // An accessor to get the real value... |
38 | static SimpleType &getSimplifiedValue(From &Val) { return Val; } |
39 | }; |
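A hypothetical specialization (names are illustrative, not part of LLVM) showing how a smart-pointer-like wrapper can opt into being "seen through" by the cast machinery:

// A toy handle that merely wraps a raw pointer.
template <typename T> struct ToyHandle { T *ptr = nullptr; };

// Teach the cast machinery to simplify ToyHandle<T> down to T *.
template <typename T> struct simplify_type<ToyHandle<T>> {
  using SimpleType = T *;
  static SimpleType getSimplifiedValue(ToyHandle<T> &h) { return h.ptr; }
};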
40 | |
41 | template <typename From> struct simplify_type<const From> { |
42 | using NonConstSimpleType = typename simplify_type<From>::SimpleType; |
43 | using SimpleType = typename add_const_past_pointer<NonConstSimpleType>::type; |
44 | using RetType = |
45 | typename add_lvalue_reference_if_not_pointer<SimpleType>::type; |
46 | |
47 | static RetType getSimplifiedValue(const From &Val) { |
48 | return simplify_type<From>::getSimplifiedValue(const_cast<From &>(Val)); |
49 | } |
50 | }; |
51 | |
52 | // TODO: add this namespace once everyone is switched to using the new |
53 | // interface. |
54 | // namespace detail { |
55 | |
56 | //===----------------------------------------------------------------------===// |
57 | // isa_impl |
58 | //===----------------------------------------------------------------------===// |
59 | |
60 | // The core of the implementation of isa<X> is here; To and From should be |
61 | // the names of classes. This template can be specialized to customize the |
62 | // implementation of isa<> without rewriting it from scratch. |
63 | template <typename To, typename From, typename Enabler = void> struct isa_impl { |
64 | static inline bool doit(const From &Val) { return To::classof(&Val); } |
65 | }; |
66 | |
67 | // Always allow upcasts, and perform no dynamic check for them. |
68 | template <typename To, typename From> |
69 | struct isa_impl<To, From, std::enable_if_t<std::is_base_of_v<To, From>>> { |
70 | static inline bool doit(const From &) { return true; } |
71 | }; |
72 | |
73 | template <typename To, typename From> struct isa_impl_cl { |
74 | static inline bool doit(const From &Val) { |
75 | return isa_impl<To, From>::doit(Val); |
76 | } |
77 | }; |
78 | |
79 | template <typename To, typename From> struct isa_impl_cl<To, const From> { |
80 | static inline bool doit(const From &Val) { |
81 | return isa_impl<To, From>::doit(Val); |
82 | } |
83 | }; |
84 | |
85 | template <typename To, typename From> |
86 | struct isa_impl_cl<To, const std::unique_ptr<From>> { |
87 | static inline bool doit(const std::unique_ptr<From> &Val) { |
88 | assert(Val && "isa<> used on a null pointer");
89 | return isa_impl_cl<To, From>::doit(*Val); |
90 | } |
91 | }; |
92 | |
93 | template <typename To, typename From> struct isa_impl_cl<To, From *> { |
94 | static inline bool doit(const From *Val) { |
95 | assert(Val && "isa<> used on a null pointer");
96 | return isa_impl<To, From>::doit(*Val); |
97 | } |
98 | }; |
99 | |
100 | template <typename To, typename From> struct isa_impl_cl<To, From *const> { |
101 | static inline bool doit(const From *Val) { |
102 | assert(Val && "isa<> used on a null pointer");
103 | return isa_impl<To, From>::doit(*Val); |
104 | } |
105 | }; |
106 | |
107 | template <typename To, typename From> struct isa_impl_cl<To, const From *> { |
108 | static inline bool doit(const From *Val) { |
109 | assert(Val && "isa<> used on a null pointer");
110 | return isa_impl<To, From>::doit(*Val); |
111 | } |
112 | }; |
113 | |
114 | template <typename To, typename From> |
115 | struct isa_impl_cl<To, const From *const> { |
116 | static inline bool doit(const From *Val) { |
117 | assert(Val && "isa<> used on a null pointer");
118 | return isa_impl<To, From>::doit(*Val); |
119 | } |
120 | }; |
121 | |
122 | template <typename To, typename From, typename SimpleFrom> |
123 | struct isa_impl_wrap { |
124 | // When From != SimplifiedType, we can simplify the type some more by using |
125 | // the simplify_type template. |
126 | static bool doit(const From &Val) { |
127 | return isa_impl_wrap<To, SimpleFrom, |
128 | typename simplify_type<SimpleFrom>::SimpleType>:: |
129 | doit(simplify_type<const From>::getSimplifiedValue(Val)); |
130 | } |
131 | }; |
132 | |
133 | template <typename To, typename FromTy> |
134 | struct isa_impl_wrap<To, FromTy, FromTy> { |
135 | // When From == SimpleType, we are as simple as we are going to get. |
136 | static bool doit(const FromTy &Val) { |
137 | return isa_impl_cl<To, FromTy>::doit(Val); |
138 | } |
139 | }; |
140 | |
141 | //===----------------------------------------------------------------------===// |
142 | // cast_retty + cast_retty_impl |
143 | //===----------------------------------------------------------------------===// |
144 | |
145 | template <class To, class From> struct cast_retty; |
146 | |
147 | // Calculate what type the 'cast' function should return, based on a requested |
148 | // type of To and a source type of From. |
149 | template <class To, class From> struct cast_retty_impl { |
150 | using ret_type = To &; // Normal case, return Ty& |
151 | }; |
152 | template <class To, class From> struct cast_retty_impl<To, const From> { |
153 | using ret_type = const To &; // Normal case, return Ty& |
154 | }; |
155 | |
156 | template <class To, class From> struct cast_retty_impl<To, From *> { |
157 | using ret_type = To *; // Pointer arg case, return Ty* |
158 | }; |
159 | |
160 | template <class To, class From> struct cast_retty_impl<To, const From *> { |
161 | using ret_type = const To *; // Constant pointer arg case, return const Ty* |
162 | }; |
163 | |
164 | template <class To, class From> struct cast_retty_impl<To, const From *const> { |
165 | using ret_type = const To *; // Constant pointer arg case, return const Ty* |
166 | }; |
167 | |
168 | template <class To, class From> |
169 | struct cast_retty_impl<To, std::unique_ptr<From>> { |
170 | private: |
171 | using PointerType = typename cast_retty_impl<To, From *>::ret_type; |
172 | using ResultType = std::remove_pointer_t<PointerType>; |
173 | |
174 | public: |
175 | using ret_type = std::unique_ptr<ResultType>; |
176 | }; |
177 | |
178 | template <class To, class From, class SimpleFrom> struct cast_retty_wrap { |
179 | // When the simplified type and the from type are not the same, use the type |
180 | // simplifier to reduce the type, then reuse cast_retty_impl to get the |
181 | // resultant type. |
182 | using ret_type = typename cast_retty<To, SimpleFrom>::ret_type; |
183 | }; |
184 | |
185 | template <class To, class FromTy> struct cast_retty_wrap<To, FromTy, FromTy> { |
186 | // When the simplified type is equal to the from type, use it directly. |
187 | using ret_type = typename cast_retty_impl<To, FromTy>::ret_type; |
188 | }; |
189 | |
190 | template <class To, class From> struct cast_retty { |
191 | using ret_type = typename cast_retty_wrap< |
192 | To, From, typename simplify_type<From>::SimpleType>::ret_type; |
193 | }; |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // cast_convert_val |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | // Ensure the non-simple values are converted using the simplify_type template |
200 | // that may be specialized by smart pointers... |
201 | // |
202 | template <class To, class From, class SimpleFrom> struct cast_convert_val { |
203 | // This is not a simple type, use the template to simplify it... |
204 | static typename cast_retty<To, From>::ret_type doit(const From &Val) { |
205 | return cast_convert_val<To, SimpleFrom, |
206 | typename simplify_type<SimpleFrom>::SimpleType>:: |
207 | doit(simplify_type<From>::getSimplifiedValue(const_cast<From &>(Val))); |
208 | } |
209 | }; |
210 | |
211 | template <class To, class FromTy> struct cast_convert_val<To, FromTy, FromTy> { |
212 | // If it's a reference, switch to a pointer to do the cast and then deref it. |
213 | static typename cast_retty<To, FromTy>::ret_type doit(const FromTy &Val) { |
214 | return *(std::remove_reference_t<typename cast_retty<To, FromTy>::ret_type> |
215 | *)&const_cast<FromTy &>(Val); |
216 | } |
217 | }; |
218 | |
219 | template <class To, class FromTy> |
220 | struct cast_convert_val<To, FromTy *, FromTy *> { |
221 | // If it's a pointer, we can use c-style casting directly. |
222 | static typename cast_retty<To, FromTy *>::ret_type doit(const FromTy *Val) { |
223 | return (typename cast_retty<To, FromTy *>::ret_type) const_cast<FromTy *>( |
224 | Val); |
225 | } |
226 | }; |
227 | |
228 | //===----------------------------------------------------------------------===// |
229 | // is_simple_type |
230 | //===----------------------------------------------------------------------===// |
231 | |
232 | template <class X> struct is_simple_type { |
233 | static const bool value = |
234 | std::is_same_v<X, typename simplify_type<X>::SimpleType>; |
235 | }; |
236 | |
237 | // } // namespace detail |
238 | |
239 | //===----------------------------------------------------------------------===// |
240 | // CastIsPossible |
241 | //===----------------------------------------------------------------------===// |
242 | |
243 | /// This struct provides a way to check if a given cast is possible. It provides |
244 | /// a static function called isPossible that is used to check if a cast can be |
245 | /// performed. It should be overridden like this: |
246 | /// |
247 | /// template<> struct CastIsPossible<foo, bar> { |
248 | /// static inline bool isPossible(const bar &b) { |
249 | /// return b.isFoo();
250 | /// } |
251 | /// }; |
252 | template <typename To, typename From, typename Enable = void> |
253 | struct CastIsPossible { |
254 | static inline bool isPossible(const From &f) { |
255 | return isa_impl_wrap< |
256 | To, const From, |
257 | typename simplify_type<const From>::SimpleType>::doit(f); |
258 | } |
259 | }; |
260 | |
261 | // Needed for optional unwrapping. This could be implemented with isa_impl, but |
262 | // we want to implement things in the new method and move old implementations |
263 | // over. In fact, some of the isa_impl templates should be moved over to |
264 | // CastIsPossible. |
265 | template <typename To, typename From> |
266 | struct CastIsPossible<To, std::optional<From>> { |
267 | static inline bool isPossible(const std::optional<From> &f) { |
268 | assert(f && "CastIsPossible::isPossible called on a nullopt!");
269 | return isa_impl_wrap< |
270 | To, const From, |
271 | typename simplify_type<const From>::SimpleType>::doit(*f); |
272 | } |
273 | }; |
274 | |
275 | /// Upcasting (from derived to base) and casting from a type to itself should |
276 | /// always be possible. |
277 | template <typename To, typename From> |
278 | struct CastIsPossible<To, From, std::enable_if_t<std::is_base_of_v<To, From>>> { |
279 | static inline bool isPossible(const From &f) { return true; } |
280 | }; |
281 | |
282 | //===----------------------------------------------------------------------===// |
283 | // Cast traits |
284 | //===----------------------------------------------------------------------===// |
285 | |
286 | /// All of these cast traits are meant to be implementations for useful casts |
287 | /// that users may want to use that are outside the standard behavior. An |
288 | /// example of how to use a special cast called `CastTrait` is: |
289 | /// |
290 | /// template<> struct CastInfo<foo, bar> : public CastTrait<foo, bar> {}; |
291 | /// |
292 | /// Essentially, if your use case falls directly into one of the use cases |
293 | /// supported by a given cast trait, simply inherit your special CastInfo |
294 | /// directly from one of these to avoid having to reimplement the boilerplate |
295 | /// `isPossible/castFailed/doCast/doCastIfPossible`. A cast trait can also |
296 | /// provide a subset of those functions. |
297 | |
298 | /// This cast trait just provides castFailed for the specified `To` type to make |
299 | /// CastInfo specializations more declarative. In order to use this, the target |
300 | /// result type must be `To` and `To` must be constructible from `nullptr`. |
301 | template <typename To> struct NullableValueCastFailed { |
302 | static To castFailed() { return To(nullptr); } |
303 | }; |
304 | |
305 | /// This cast trait just provides the default implementation of doCastIfPossible |
306 | /// to make CastInfo specializations more declarative. The `Derived` template |
307 | /// parameter *must* be provided for forwarding castFailed and doCast. |
308 | template <typename To, typename From, typename Derived> |
309 | struct DefaultDoCastIfPossible { |
310 | static To doCastIfPossible(From f) { |
311 | if (!Derived::isPossible(f)) |
312 | return Derived::castFailed(); |
313 | return Derived::doCast(f); |
314 | } |
315 | }; |
316 | |
317 | namespace detail { |
318 | /// A helper to derive the type to use with `Self` for cast traits, when the |
319 | /// provided CRTP derived type is allowed to be void. |
320 | template <typename OptionalDerived, typename Default> |
321 | using SelfType = std::conditional_t<std::is_same_v<OptionalDerived, void>, |
322 | Default, OptionalDerived>; |
323 | } // namespace detail |
324 | |
325 | /// This cast trait provides casting for the specific case of casting to a |
326 | /// value-typed object from a pointer-typed object. Note that `To` must be |
327 | /// nullable/constructible from a pointer to `From` to use this cast. |
328 | template <typename To, typename From, typename Derived = void> |
329 | struct ValueFromPointerCast |
330 | : public CastIsPossible<To, From *>, |
331 | public NullableValueCastFailed<To>, |
332 | public DefaultDoCastIfPossible< |
333 | To, From *, |
334 | detail::SelfType<Derived, ValueFromPointerCast<To, From>>> { |
335 | static inline To doCast(From *f) { return To(f); } |
336 | }; |
337 | |
338 | /// This cast trait provides std::unique_ptr casting. It has the semantics of |
339 | /// moving the contents of the input unique_ptr into the output unique_ptr |
340 | /// during the cast. It's also a good example of how to implement a move-only |
341 | /// cast. |
342 | template <typename To, typename From, typename Derived = void> |
343 | struct UniquePtrCast : public CastIsPossible<To, From *> { |
344 | using Self = detail::SelfType<Derived, UniquePtrCast<To, From>>; |
345 | using CastResultType = std::unique_ptr< |
346 | std::remove_reference_t<typename cast_retty<To, From>::ret_type>>; |
347 | |
348 | static inline CastResultType doCast(std::unique_ptr<From> &&f) { |
349 | return CastResultType((typename CastResultType::element_type *)f.release()); |
350 | } |
351 | |
352 | static inline CastResultType castFailed() { return CastResultType(nullptr); } |
353 | |
354 | static inline CastResultType doCastIfPossible(std::unique_ptr<From> &&f) { |
355 | if (!Self::isPossible(f)) |
356 | return castFailed(); |
357 | return doCast(f); |
358 | } |
359 | }; |
360 | |
361 | /// This cast trait provides std::optional<T> casting. This means that if you |
362 | /// have a value type, you can cast it to another value type and have dyn_cast |
363 | /// return an std::optional<T>. |
364 | template <typename To, typename From, typename Derived = void> |
365 | struct OptionalValueCast |
366 | : public CastIsPossible<To, From>, |
367 | public DefaultDoCastIfPossible< |
368 | std::optional<To>, From, |
369 | detail::SelfType<Derived, OptionalValueCast<To, From>>> { |
370 | static inline std::optional<To> castFailed() { return std::optional<To>{}; } |
371 | |
372 | static inline std::optional<To> doCast(const From &f) { return To(f); } |
373 | }; |
374 | |
375 | /// Provides a cast trait that strips `const` from types to make it easier to |
376 | /// implement a const-version of a non-const cast. It just removes boilerplate |
377 | /// and reduces the amount of code you as the user need to implement. You can |
378 | /// use it like this: |
379 | /// |
380 | /// template<> struct CastInfo<foo, bar> { |
381 | /// ...verbose implementation... |
382 | /// }; |
383 | /// |
384 | /// template<> struct CastInfo<foo, const bar> : public |
385 | /// ConstStrippingForwardingCast<foo, const bar, CastInfo<foo, bar>> {}; |
386 | /// |
387 | template <typename To, typename From, typename ForwardTo> |
388 | struct ConstStrippingForwardingCast { |
389 | // Remove the pointer if it exists, then we can get rid of consts/volatiles. |
390 | using DecayedFrom = std::remove_cv_t<std::remove_pointer_t<From>>; |
391 | // Now if it's a pointer, add it back. Otherwise, we want a ref. |
392 | using NonConstFrom = |
393 | std::conditional_t<std::is_pointer_v<From>, DecayedFrom *, DecayedFrom &>; |
394 | |
395 | static inline bool isPossible(const From &f) { |
396 | return ForwardTo::isPossible(const_cast<NonConstFrom>(f)); |
397 | } |
398 | |
399 | static inline decltype(auto) castFailed() { return ForwardTo::castFailed(); } |
400 | |
401 | static inline decltype(auto) doCast(const From &f) { |
402 | return ForwardTo::doCast(const_cast<NonConstFrom>(f)); |
403 | } |
404 | |
405 | static inline decltype(auto) doCastIfPossible(const From &f) { |
406 | return ForwardTo::doCastIfPossible(const_cast<NonConstFrom>(f)); |
407 | } |
408 | }; |
409 | |
410 | /// Provides a cast trait that uses a defined pointer to pointer cast as a base |
411 | /// for reference-to-reference casts. Note that it does not provide castFailed |
412 | /// and doCastIfPossible because a pointer-to-pointer cast would likely just |
413 | /// return `nullptr` which could cause nullptr dereference. You can use it like |
414 | /// this: |
415 | /// |
416 | /// template <> struct CastInfo<foo, bar *> { ... verbose implementation... }; |
417 | /// |
418 | /// template <> |
419 | /// struct CastInfo<foo, bar> |
420 | /// : public ForwardToPointerCast<foo, bar, CastInfo<foo, bar *>> {}; |
421 | /// |
422 | template <typename To, typename From, typename ForwardTo> |
423 | struct ForwardToPointerCast { |
424 | static inline bool isPossible(const From &f) { |
425 | return ForwardTo::isPossible(&f); |
426 | } |
427 | |
428 | static inline decltype(auto) doCast(const From &f) { |
429 | return *ForwardTo::doCast(&f); |
430 | } |
431 | }; |
432 | |
433 | //===----------------------------------------------------------------------===// |
434 | // CastInfo |
435 | //===----------------------------------------------------------------------===// |
436 | |
437 | /// This struct provides a method for customizing the way a cast is performed. |
438 | /// It inherits from CastIsPossible, to support the case of declaring many |
439 | /// CastIsPossible specializations without having to specialize the full |
440 | /// CastInfo. |
441 | /// |
442 | /// In order to specialize different behaviors, specify different functions in |
443 | /// your CastInfo specialization. |
444 | /// For isa<> customization, provide: |
445 | /// |
446 | /// `static bool isPossible(const From &f)` |
447 | /// |
448 | /// For cast<> customization, provide: |
449 | /// |
450 | /// `static To doCast(const From &f)` |
451 | /// |
452 | /// For dyn_cast<> and the *_if_present<> variants' customization, provide: |
453 | /// |
454 | /// `static To castFailed()` and `static To doCastIfPossible(const From &f)` |
455 | /// |
456 | /// Your specialization might look something like this: |
457 | /// |
458 | /// template<> struct CastInfo<foo, bar> : public CastIsPossible<foo, bar> { |
459 | /// static inline foo doCast(const bar &b) { |
460 | /// return foo(const_cast<bar &>(b)); |
461 | /// } |
462 | /// static inline foo castFailed() { return foo(); } |
463 | /// static inline foo doCastIfPossible(const bar &b) { |
464 | /// if (!CastInfo<foo, bar>::isPossible(b)) |
465 | /// return castFailed(); |
466 | /// return doCast(b); |
467 | /// } |
468 | /// }; |
469 | |
470 | // The default implementations of CastInfo don't use cast traits for now because |
471 | // we need to specify types all over the place due to the current expected |
472 | // casting behavior and the way cast_retty works. New use cases can and should |
473 | // take advantage of the cast traits whenever possible! |
474 | |
475 | template <typename To, typename From, typename Enable = void> |
476 | struct CastInfo : public CastIsPossible<To, From> { |
477 | using Self = CastInfo<To, From, Enable>; |
478 | |
479 | using CastReturnType = typename cast_retty<To, From>::ret_type; |
480 | |
481 | static inline CastReturnType doCast(const From &f) { |
482 | return cast_convert_val< |
483 | To, From, |
484 | typename simplify_type<From>::SimpleType>::doit(const_cast<From &>(f)); |
485 | } |
486 | |
487 | // This assumes that you can construct the cast return type from `nullptr`. |
488 | // This is largely to support legacy use cases - if you don't want this |
489 | // behavior you should specialize CastInfo for your use case. |
490 | static inline CastReturnType castFailed() { return CastReturnType(nullptr); } |
491 | |
492 | static inline CastReturnType doCastIfPossible(const From &f) { |
493 | if (!Self::isPossible(f)) |
494 | return castFailed(); |
495 | return doCast(f); |
496 | } |
497 | }; |
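// Hedged usage sketch (hypothetical hierarchy): this default CastInfo is what
// makes classic `classof`-based pointer casts work with no specialization at
// all:
//
//   struct Shape {
//     enum Kind { CircleKind, SquareKind } kind;
//     Shape(Kind k) : kind(k) {}
//     virtual ~Shape() = default;
//   };
//   struct Circle : Shape {
//     Circle() : Shape(CircleKind) {}
//     static bool classof(const Shape *s) { return s->kind == CircleKind; }
//   };
//
//   Circle circle;
//   Shape *s = &circle;
//   if (isa<Circle>(s)) { /* ... */ }
//   Circle *c = dyn_cast<Circle>(s);  // nullptr if `s` were not a Circle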
498 | |
499 | /// This struct provides an overload for CastInfo where From has simplify_type |
500 | /// defined. This simply forwards to the appropriate CastInfo with the |
501 | /// simplified type/value, so you don't have to implement both. |
502 | template <typename To, typename From> |
503 | struct CastInfo<To, From, std::enable_if_t<!is_simple_type<From>::value>> { |
504 | using Self = CastInfo<To, From>; |
505 | using SimpleFrom = typename simplify_type<From>::SimpleType; |
506 | using SimplifiedSelf = CastInfo<To, SimpleFrom>; |
507 | |
508 | static inline bool isPossible(From &f) { |
509 | return SimplifiedSelf::isPossible( |
510 | simplify_type<From>::getSimplifiedValue(f)); |
511 | } |
512 | |
513 | static inline decltype(auto) doCast(From &f) { |
514 | return SimplifiedSelf::doCast(simplify_type<From>::getSimplifiedValue(f)); |
515 | } |
516 | |
517 | static inline decltype(auto) castFailed() { |
518 | return SimplifiedSelf::castFailed(); |
519 | } |
520 | |
521 | static inline decltype(auto) doCastIfPossible(From &f) { |
522 | return SimplifiedSelf::doCastIfPossible( |
523 | simplify_type<From>::getSimplifiedValue(f)); |
524 | } |
525 | }; |
526 | |
527 | //===----------------------------------------------------------------------===// |
528 | // Pre-specialized CastInfo |
529 | //===----------------------------------------------------------------------===// |
530 | |
531 | /// Provide a CastInfo specialized for std::unique_ptr. |
532 | template <typename To, typename From> |
533 | struct CastInfo<To, std::unique_ptr<From>> : public UniquePtrCast<To, From> {}; |
534 | |
535 | /// Provide a CastInfo specialized for std::optional<From>. It's assumed that if |
536 | /// the input is std::optional<From>, the output can be std::optional<To>.
537 | /// If that's not the case, specialize CastInfo for your use case. |
538 | template <typename To, typename From> |
539 | struct CastInfo<To, std::optional<From>> : public OptionalValueCast<To, From> { |
540 | }; |
541 | |
542 | /// isa<X> - Return true if the parameter to the template is an instance of one |
543 | /// of the template type arguments. Used like this: |
544 | /// |
545 | /// if (isa<Type>(myVal)) { ... } |
546 | /// if (isa<Type0, Type1, Type2>(myVal)) { ... } |
547 | template <typename To, typename From> |
548 | [[nodiscard]] inline bool isa(const From &Val) { |
549 | return CastInfo<To, const From>::isPossible(Val); |
550 | } |
551 | |
552 | template <typename First, typename Second, typename... Rest, typename From> |
553 | [[nodiscard]] inline bool isa(const From &Val) { |
554 | return isa<First>(Val) || isa<Second, Rest...>(Val); |
555 | } |
556 | |
557 | /// cast<X> - Return the argument parameter cast to the specified type. This |
558 | /// casting operator asserts that the type is correct, so it does not return |
559 | /// null on failure. It does not allow a null argument (use cast_if_present for |
560 | /// that). It is typically used like this: |
561 | /// |
562 | /// cast<Instruction>(myVal)->getParent() |
563 | |
564 | template <typename To, typename From> |
565 | [[nodiscard]] inline decltype(auto) cast(const From &Val) { |
566 | assert(isa<To>(Val) && "cast<Ty>() argument of incompatible type!");
567 | return CastInfo<To, const From>::doCast(Val); |
568 | } |
569 | |
570 | template <typename To, typename From> |
571 | [[nodiscard]] inline decltype(auto) cast(From &Val) { |
572 | assert(isa<To>(Val) && "cast<Ty>() argument of incompatible type!");
573 | return CastInfo<To, From>::doCast(Val); |
574 | } |
575 | |
576 | template <typename To, typename From> |
577 | [[nodiscard]] inline decltype(auto) cast(From *Val) { |
578 | assert(isa<To>(Val) && "cast<Ty>() argument of incompatible type!");
579 | return CastInfo<To, From *>::doCast(Val); |
580 | } |
581 | |
582 | template <typename To, typename From> |
583 | [[nodiscard]] inline decltype(auto) cast(std::unique_ptr<From> &&Val) { |
584 | assert(isa<To>(Val) && "cast<Ty>() argument of incompatible type!");
585 | return CastInfo<To, std::unique_ptr<From>>::doCast(std::move(Val)); |
586 | } |
587 | |
588 | //===----------------------------------------------------------------------===// |
589 | // ValueIsPresent |
590 | //===----------------------------------------------------------------------===// |
591 | |
592 | template <typename T> |
593 | constexpr bool IsNullable = |
594 | std::is_pointer_v<T> || std::is_constructible_v<T, std::nullptr_t>; |
595 | |
596 | /// ValueIsPresent provides a way to check if a value is, well, present. For |
597 | /// pointers, this is the equivalent of checking against nullptr; for Optionals,
598 | /// this is the equivalent of checking has_value(). It also provides a method for
599 | /// unwrapping a value (think calling .value() on an optional). |
600 | |
601 | // Generic values can't *not* be present. |
602 | template <typename T, typename Enable = void> struct ValueIsPresent { |
603 | using UnwrappedType = T; |
604 | static inline bool isPresent(const T &t) { return true; } |
605 | static inline decltype(auto) unwrapValue(T &t) { return t; } |
606 | }; |
607 | |
608 | // Optional provides its own way to check if something is present. |
609 | template <typename T> struct ValueIsPresent<std::optional<T>> { |
610 | using UnwrappedType = T; |
611 | static inline bool isPresent(const std::optional<T> &t) { |
612 | return t.has_value(); |
613 | } |
614 | static inline decltype(auto) unwrapValue(std::optional<T> &t) { return *t; } |
615 | }; |
616 | |
617 | // If something is "nullable" then we just compare it to nullptr to see if it |
618 | // exists. |
619 | template <typename T> |
620 | struct ValueIsPresent<T, std::enable_if_t<IsNullable<T>>> { |
621 | using UnwrappedType = T; |
622 | static inline bool isPresent(const T &t) { return t != T(nullptr); } |
623 | static inline decltype(auto) unwrapValue(T &t) { return t; } |
624 | }; |
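// Hedged sketch (hypothetical `Handle` type): any value type that is
// constructible from nullptr_t and comparable against an instance built from
// nullptr picks up this specialization automatically, which is what lets the
// *_if_present variants below accept it:
//
//   struct Handle {
//     void *ptr = nullptr;
//     Handle(std::nullptr_t) {}
//     Handle(void *p) : ptr(p) {}
//     bool operator!=(const Handle &o) const { return ptr != o.ptr; }
//   };
//   // ValueIsPresent<Handle>::isPresent(Handle(nullptr)) == false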
625 | |
626 | namespace detail { |
627 | // Convenience function we can use to check if a value is present. Because of |
628 | // simplify_type, we have to call it on the simplified type for now. |
629 | template <typename T> inline bool isPresent(const T &t) { |
630 | return ValueIsPresent<typename simplify_type<T>::SimpleType>::isPresent( |
631 | simplify_type<T>::getSimplifiedValue(const_cast<T &>(t))); |
632 | } |
633 | |
634 | // Convenience function we can use to unwrap a value. |
635 | template <typename T> inline decltype(auto) unwrapValue(T &t) { |
636 | return ValueIsPresent<T>::unwrapValue(t); |
637 | } |
638 | } // namespace detail |
639 | |
640 | /// dyn_cast<X> - Return the argument parameter cast to the specified type. This |
641 | /// casting operator returns null if the argument is of the wrong type, so it |
642 | /// can be used to test for a type as well as cast if successful. The value |
643 | /// passed in must be present; if not, use dyn_cast_if_present. This should be
644 | /// used in the context of an if statement like this: |
645 | /// |
646 | /// if (const Instruction *I = dyn_cast<Instruction>(myVal)) { ... } |
647 | |
648 | template <typename To, typename From> |
649 | [[nodiscard]] inline decltype(auto) dyn_cast(const From &Val) { |
650 | assert(detail::isPresent(Val) && "dyn_cast on a non-existent value");
651 | return CastInfo<To, const From>::doCastIfPossible(Val); |
652 | } |
653 | |
654 | template <typename To, typename From> |
655 | [[nodiscard]] inline decltype(auto) dyn_cast(From &Val) { |
656 | assert(detail::isPresent(Val) && "dyn_cast on a non-existent value");
657 | return CastInfo<To, From>::doCastIfPossible(Val); |
658 | } |
659 | |
660 | template <typename To, typename From> |
661 | [[nodiscard]] inline decltype(auto) dyn_cast(From *Val) { |
662 | assert(detail::isPresent(Val) && "dyn_cast on a non-existent value");
663 | return CastInfo<To, From *>::doCastIfPossible(Val); |
664 | } |
665 | |
666 | template <typename To, typename From> |
667 | [[nodiscard]] inline decltype(auto) dyn_cast(std::unique_ptr<From> &&Val) { |
668 | assert(detail::isPresent(Val) && "dyn_cast on a non-existent value");
669 | return CastInfo<To, std::unique_ptr<From>>::doCastIfPossible( |
670 | std::forward<std::unique_ptr<From> &&>(Val)); |
671 | } |
672 | |
673 | /// isa_and_present<X> - Functionally identical to isa, except that a null value |
674 | /// is accepted. |
675 | template <typename... X, class Y> |
676 | [[nodiscard]] inline bool isa_and_present(const Y &Val) { |
677 | if (!detail::isPresent(Val)) |
678 | return false; |
679 | return isa<X...>(Val); |
680 | } |
681 | |
682 | template <typename... X, class Y> |
683 | [[nodiscard]] inline bool isa_and_nonnull(const Y &Val) { |
684 | return isa_and_present<X...>(Val); |
685 | } |
686 | |
687 | /// cast_if_present<X> - Functionally identical to cast, except that a null |
688 | /// value is accepted. |
689 | template <class X, class Y> |
690 | [[nodiscard]] inline auto cast_if_present(const Y &Val) { |
691 | if (!detail::isPresent(Val)) |
692 | return CastInfo<X, const Y>::castFailed(); |
693 | assert(isa<X>(Val) && "cast_if_present<Ty>() argument of incompatible type!");
694 | return cast<X>(detail::unwrapValue(Val)); |
695 | } |
696 | |
697 | template <class X, class Y> [[nodiscard]] inline auto cast_if_present(Y &Val) { |
698 | if (!detail::isPresent(Val)) |
699 | return CastInfo<X, Y>::castFailed(); |
700 | assert(isa<X>(Val) && "cast_if_present<Ty>() argument of incompatible type!");
701 | return cast<X>(detail::unwrapValue(Val)); |
702 | } |
703 | |
704 | template <class X, class Y> [[nodiscard]] inline auto cast_if_present(Y *Val) { |
705 | if (!detail::isPresent(Val)) |
706 | return CastInfo<X, Y *>::castFailed(); |
707 | assert(isa<X>(Val) && "cast_if_present<Ty>() argument of incompatible type!");
708 | return cast<X>(detail::unwrapValue(Val)); |
709 | } |
710 | |
711 | template <class X, class Y> |
712 | [[nodiscard]] inline auto cast_if_present(std::unique_ptr<Y> &&Val) { |
713 | if (!detail::isPresent(Val)) |
714 | return UniquePtrCast<X, Y>::castFailed(); |
715 | return UniquePtrCast<X, Y>::doCast(std::move(Val)); |
716 | } |
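// Hedged usage sketch (reusing the hypothetical Shape/Circle hierarchy from
// the CastInfo example above): unlike cast<>, a null input is tolerated and
// simply mapped to the failure value instead of asserting:
//
//   Shape *maybeNull = nullptr;
//   Circle *c = cast_if_present<Circle>(maybeNull);  // c == nullptr, no assert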
717 | |
718 | // Provide a forwarding from cast_or_null to cast_if_present for current |
719 | // users. This is deprecated and will be removed in a future patch; use
720 | // cast_if_present instead. |
721 | template <class X, class Y> auto cast_or_null(const Y &Val) { |
722 | return cast_if_present<X>(Val); |
723 | } |
724 | |
725 | template <class X, class Y> auto cast_or_null(Y &Val) { |
726 | return cast_if_present<X>(Val); |
727 | } |
728 | |
729 | template <class X, class Y> auto cast_or_null(Y *Val) { |
730 | return cast_if_present<X>(Val); |
731 | } |
732 | |
733 | template <class X, class Y> auto cast_or_null(std::unique_ptr<Y> &&Val) { |
734 | return cast_if_present<X>(std::move(Val)); |
735 | } |
736 | |
737 | /// dyn_cast_if_present<X> - Functionally identical to dyn_cast, except that a |
738 | /// null (or none in the case of optionals) value is accepted. |
739 | template <class X, class Y> auto dyn_cast_if_present(const Y &Val) { |
740 | if (!detail::isPresent(Val)) |
741 | return CastInfo<X, const Y>::castFailed(); |
742 | return CastInfo<X, const Y>::doCastIfPossible(detail::unwrapValue(Val)); |
743 | } |
744 | |
745 | template <class X, class Y> auto dyn_cast_if_present(Y &Val) { |
746 | if (!detail::isPresent(Val)) |
747 | return CastInfo<X, Y>::castFailed(); |
748 | return CastInfo<X, Y>::doCastIfPossible(detail::unwrapValue(Val)); |
749 | } |
750 | |
751 | template <class X, class Y> auto dyn_cast_if_present(Y *Val) { |
752 | if (!detail::isPresent(Val)) |
753 | return CastInfo<X, Y *>::castFailed(); |
754 | return CastInfo<X, Y *>::doCastIfPossible(detail::unwrapValue(Val)); |
755 | } |
756 | |
757 | // Forwards to dyn_cast_if_present to avoid breaking current users. This is |
758 | // deprecated and will be removed in a future patch; use
759 | // dyn_cast_if_present instead.
760 | template <class X, class Y> auto dyn_cast_or_null(const Y &Val) { |
761 | return dyn_cast_if_present<X>(Val); |
762 | } |
763 | |
764 | template <class X, class Y> auto dyn_cast_or_null(Y &Val) { |
765 | return dyn_cast_if_present<X>(Val); |
766 | } |
767 | |
768 | template <class X, class Y> auto dyn_cast_or_null(Y *Val) { |
769 | return dyn_cast_if_present<X>(Val); |
770 | } |
771 | |
772 | /// unique_dyn_cast<X> - Given a unique_ptr<Y>, try to return a unique_ptr<X>, |
773 | /// taking ownership of the input pointer iff isa<X>(Val) is true. If the |
774 | /// cast is successful, `Val` refers to nullptr on exit and the casted value
775 | /// is returned. If the cast is unsuccessful, the function returns nullptr
776 | /// and `Val` is unchanged.
777 | template <class X, class Y> |
778 | [[nodiscard]] inline typename CastInfo<X, std::unique_ptr<Y>>::CastResultType |
779 | unique_dyn_cast(std::unique_ptr<Y> &Val) { |
780 | if (!isa<X>(Val)) |
781 | return nullptr; |
782 | return cast<X>(std::move(Val)); |
783 | } |
784 | |
785 | template <class X, class Y> |
786 | [[nodiscard]] inline auto unique_dyn_cast(std::unique_ptr<Y> &&Val) { |
787 | return unique_dyn_cast<X, Y>(Val); |
788 | } |
789 | |
790 | // unique_dyn_cast_or_null<X> - Functionally identical to unique_dyn_cast, |
791 | // except that a null value is accepted. |
792 | template <class X, class Y> |
793 | [[nodiscard]] inline typename CastInfo<X, std::unique_ptr<Y>>::CastResultType |
794 | unique_dyn_cast_or_null(std::unique_ptr<Y> &Val) { |
795 | if (!Val) |
796 | return nullptr; |
797 | return unique_dyn_cast<X, Y>(Val); |
798 | } |
799 | |
800 | template <class X, class Y> |
801 | [[nodiscard]] inline auto unique_dyn_cast_or_null(std::unique_ptr<Y> &&Val) { |
802 | return unique_dyn_cast_or_null<X, Y>(Val); |
803 | } |
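// Hedged usage sketch (reusing the hypothetical Shape/Circle hierarchy from
// the CastInfo example above): unique_dyn_cast transfers ownership only when
// the cast succeeds:
//
//   std::unique_ptr<Shape> s = std::make_unique<Circle>();
//   if (std::unique_ptr<Circle> c = unique_dyn_cast<Circle>(s)) {
//     // Success: `c` owns the object and `s` is now null.
//   } else {
//     // Failure: `s` still owns the object.
//   }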
804 | |
805 | } // end namespace llvm |
806 | |
807 | #endif // LLVM_SUPPORT_CASTING_H |
1 | //===- Operation.h - MLIR Operation Class -----------------------*- C++ -*-===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // This file defines the Operation class. |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #ifndef MLIR_IR_OPERATION_H |
14 | #define MLIR_IR_OPERATION_H |
15 | |
16 | #include "mlir/IR/Block.h" |
17 | #include "mlir/IR/BuiltinAttributes.h" |
18 | #include "mlir/IR/Diagnostics.h" |
19 | #include "mlir/IR/OperationSupport.h" |
20 | #include "mlir/IR/Region.h" |
21 | #include "llvm/ADT/Twine.h" |
22 | #include <optional> |
23 | |
24 | namespace mlir { |
25 | namespace detail { |
26 | /// This is a "tag" used for mapping the properties storage in |
27 | /// llvm::TrailingObjects. |
28 | enum class OpProperties : char {}; |
29 | } // namespace detail |
30 | |
31 | /// Operation is the basic unit of execution within MLIR. |
32 | /// |
33 | /// The following documentation is recommended to understand this class:
34 | /// - https://mlir.llvm.org/docs/LangRef/#operations |
35 | /// - https://mlir.llvm.org/docs/Tutorials/UnderstandingTheIRStructure/ |
36 | /// |
37 | /// An Operation is defined first by its name, which is a unique string. The |
38 | /// name is interpreted so that if it contains a '.' character, the part before |
39 | /// is the dialect name this operation belongs to, and everything that follows |
40 | /// is the operation name within the dialect.
41 | /// |
42 | /// An Operation defines zero or more SSA `Value` that we refer to as the |
43 | /// Operation results. This array of Value is actually stored in memory before |
44 | /// the Operation itself in reverse order. That is, for an Operation with 3
45 | /// results we allocate the following memory layout: |
46 | /// |
47 | /// [Result2, Result1, Result0, Operation] |
48 | /// ^ this is where `Operation*` pointer points to. |
49 | /// |
50 | /// A consequence of this is that this class must be heap allocated, which is |
51 | /// handled by the various `create` methods. Each result contains: |
52 | /// - one pointer to the first use (see `OpOperand`) |
53 | /// - the type of the SSA Value this result defines. |
54 | /// - the index for this result in the array. |
55 | /// The results are defined as subclasses of `ValueImpl`, and more precisely as
56 | /// the only two subclasses of `OpResultImpl`: `InlineOpResult` and |
57 | /// `OutOfLineOpResult`. The former is used for the first 5 results and the |
58 | /// latter for the subsequent ones. They differ in how they store their index: |
59 | /// the first 5 results only need 3 bits and thus are packed with the Type |
60 | /// pointer, while the subsequent ones have an extra `unsigned` value and thus
61 | /// need more space. |
62 | /// |
63 | /// An Operation also has zero or more operands: these are uses of SSA Value, |
64 | /// which can be the results of other operations or Block arguments. Each of |
65 | /// these uses is an instance of `OpOperand`. This optional array is initially |
66 | /// tail allocated with the operation class itself, but can be moved
67 | /// out-of-line into a dynamic allocation as needed.
68 | /// |
69 | /// An Operation may optionally contain one or multiple Regions, stored in a
70 | /// tail allocated array. Each `Region` is a list of Blocks. Each `Block` is |
71 | /// itself a list of Operations. This structure is effectively forming a tree. |
72 | /// |
73 | /// Some operations like branches also refer to other Blocks, in which case they
74 | /// would have an array of `BlockOperand`. |
75 | /// |
76 | /// An Operation may optionally contain a "Properties" object: this is a
77 | /// pre-defined C++ object with a fixed size. This object is owned by the |
78 | /// operation and deleted with the operation. It can be converted to an |
79 | /// Attribute on demand, or loaded from an Attribute. |
80 | /// |
81 | /// |
82 | /// Finally, an Operation also contains an optional `DictionaryAttr`, a Location,
83 | /// and a pointer to its parent Block (if any). |
84 | class alignas(8) Operation final |
85 | : public llvm::ilist_node_with_parent<Operation, Block>, |
86 | private llvm::TrailingObjects<Operation, detail::OperandStorage, |
87 | detail::OpProperties, BlockOperand, Region, |
88 | OpOperand> { |
89 | public: |
90 | /// Create a new Operation with the specific fields. This constructor |
91 | /// populates the provided attribute list with default attributes if |
92 | /// necessary. |
93 | static Operation *create(Location location, OperationName name, |
94 | TypeRange resultTypes, ValueRange operands, |
95 | NamedAttrList &&attributes, |
96 | OpaqueProperties properties, BlockRange successors, |
97 | unsigned numRegions); |
98 | |
99 | /// Create a new Operation with the specific fields. This constructor uses an |
100 | /// existing attribute dictionary to avoid uniquing a list of attributes. |
101 | static Operation *create(Location location, OperationName name, |
102 | TypeRange resultTypes, ValueRange operands, |
103 | DictionaryAttr attributes, |
104 | OpaqueProperties properties, BlockRange successors, |
105 | unsigned numRegions); |
106 | |
107 | /// Create a new Operation from the fields stored in `state`. |
108 | static Operation *create(const OperationState &state); |
109 | |
110 | /// Create a new Operation with the specific fields. |
111 | static Operation *create(Location location, OperationName name, |
112 | TypeRange resultTypes, ValueRange operands, |
113 | NamedAttrList &&attributes, |
114 | OpaqueProperties properties, |
115 | BlockRange successors = {}, |
116 | RegionRange regions = {}); |
117 | |
118 | /// The name of an operation is the key identifier for it. |
119 | OperationName getName() { return name; } |
120 | |
121 | /// If this operation has a registered operation description, return it. |
122 | /// Otherwise return std::nullopt. |
123 | std::optional<RegisteredOperationName> getRegisteredInfo() { |
124 | return getName().getRegisteredInfo(); |
125 | } |
126 | |
127 | /// Returns true if this operation has a registered operation description, |
128 | /// otherwise false. |
129 | bool isRegistered() { return getName().isRegistered(); } |
130 | |
131 | /// Remove this operation from its parent block and delete it. |
132 | void erase(); |
133 | |
134 | /// Remove the operation from its parent block, but don't delete it. |
135 | void remove(); |
136 | |
137 | /// Class encompassing various options related to cloning an operation. Users |
138 | /// of this class should pass it to Operation's 'clone' methods. |
139 | /// Current options include: |
140 | /// * Whether cloning should recursively traverse into the regions of the |
141 | /// operation or not. |
142 | /// * Whether cloning should also clone the operands of the operation. |
143 | class CloneOptions { |
144 | public: |
145 | /// Default constructs an option with all flags set to false. That means all |
146 | /// parts of an operation that may optionally be cloned are not cloned.
147 | CloneOptions(); |
148 | |
149 | /// Constructs an instance with the clone regions and clone operands flags |
150 | /// set accordingly. |
151 | CloneOptions(bool cloneRegions, bool cloneOperands); |
152 | |
153 | /// Returns an instance with all flags set to true. This is the default |
154 | /// when using the clone method and clones all parts of the operation. |
155 | static CloneOptions all(); |
156 | |
157 | /// Configures whether cloning should traverse into any of the regions of |
158 | /// the operation. If set to true, the operation's regions are recursively |
159 | /// cloned. If set to false, cloned operations will have the same number of |
160 | /// regions, but they will be empty. |
161 | /// Cloning of nested operations in the operation's regions is currently
162 | /// unaffected by other flags. |
163 | CloneOptions &cloneRegions(bool enable = true); |
164 | |
165 | /// Returns whether regions of the operation should be cloned as well. |
166 | bool shouldCloneRegions() const { return cloneRegionsFlag; } |
167 | |
168 | /// Configures whether the operation's operands should be cloned. Otherwise the
169 | /// resulting clones will simply have zero operands. |
170 | CloneOptions &cloneOperands(bool enable = true); |
171 | |
172 | /// Returns whether operands should be cloned as well. |
173 | bool shouldCloneOperands() const { return cloneOperandsFlag; } |
174 | |
175 | private: |
176 | /// Whether regions should be cloned. |
177 | bool cloneRegionsFlag : 1; |
178 | /// Whether operands should be cloned. |
179 | bool cloneOperandsFlag : 1; |
180 | }; |
181 | |
182 | /// Create a deep copy of this operation, remapping any operands that use |
183 | /// values outside of the operation using the map that is provided (leaving |
184 | /// them alone if no entry is present). Replaces references to cloned |
185 | /// sub-operations to the corresponding operation that is copied, and adds |
186 | /// those mappings to the map. |
187 | /// Optionally, one may configure what parts of the operation to clone using |
188 | /// the options parameter. |
189 | /// |
190 | /// Calling this method from multiple threads is generally safe if, through the
191 | /// process of cloning, no new uses of 'Value's from outside the operation are
192 | /// created. Cloning an isolated-from-above operation with no operands, such |
193 | /// as top level function operations, is therefore always safe. Using the |
194 | /// mapper, it is possible to avoid adding uses to outside operands by |
195 | /// remapping them to 'Value's owned by the caller thread. |
196 | Operation *clone(IRMapping &mapper, |
197 | CloneOptions options = CloneOptions::all()); |
198 | Operation *clone(CloneOptions options = CloneOptions::all()); |
199 | |
200 | /// Create a partial copy of this operation without traversing into attached |
201 | /// regions. The new operation will have the same number of regions as the |
202 | /// original one, but they will be left empty. |
203 | /// Operands are remapped using `mapper` (if present), and `mapper` is updated |
204 | /// to contain the results. |
205 | Operation *cloneWithoutRegions(IRMapping &mapper); |
206 | |
207 | /// Create a partial copy of this operation without traversing into attached |
208 | /// regions. The new operation will have the same number of regions as the |
209 | /// original one, but they will be left empty. |
210 | Operation *cloneWithoutRegions(); |
211 | |
212 | /// Returns the operation block that contains this operation. |
213 | Block *getBlock() { return block; } |
214 | |
215 | /// Return the context this operation is associated with. |
216 | MLIRContext *getContext() { return location->getContext(); } |
217 | |
218 | /// Return the dialect this operation is associated with, or nullptr if the |
219 | /// associated dialect is not loaded. |
220 | Dialect *getDialect() { return getName().getDialect(); } |
221 | |
222 | /// The source location the operation was defined or derived from. |
223 | Location getLoc() { return location; } |
224 | |
225 | /// Set the source location the operation was defined or derived from. |
226 | void setLoc(Location loc) { location = loc; } |
227 | |
228 | /// Returns the region to which the operation belongs. Returns nullptr if
229 | /// the operation is unlinked.
230 | Region *getParentRegion() { return block ? block->getParent() : nullptr; } |
231 | |
232 | /// Returns the closest surrounding operation that contains this operation |
233 | /// or nullptr if this is a top-level operation. |
234 | Operation *getParentOp() { return block ? block->getParentOp() : nullptr; } |
235 | |
236 | /// Return the closest surrounding parent operation that is of type 'OpTy'. |
237 | template <typename OpTy> |
238 | OpTy getParentOfType() { |
239 | auto *op = this; |
240 | while ((op = op->getParentOp())) |
241 | if (auto parentOp = dyn_cast<OpTy>(op)) |
242 | return parentOp; |
243 | return OpTy(); |
244 | } |
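// Hedged usage sketch (assuming the func dialect is available and `someOp` is
// an Operation*): find the enclosing function of an operation:
//
//   if (auto funcOp = someOp->getParentOfType<func::FuncOp>())
//     llvm::outs() << "enclosing function: " << funcOp.getName() << "\n";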
245 | |
246 | /// Returns the closest surrounding parent operation with trait `Trait`. |
247 | template <template <typename T> class Trait> |
248 | Operation *getParentWithTrait() { |
249 | Operation *op = this; |
250 | while ((op = op->getParentOp())) |
251 | if (op->hasTrait<Trait>()) |
252 | return op; |
253 | return nullptr; |
254 | } |
255 | |
256 | /// Return true if this operation is a proper ancestor of the `other` |
257 | /// operation. |
258 | bool isProperAncestor(Operation *other); |
259 | |
260 | /// Return true if this operation is an ancestor of the `other` operation. An |
261 | /// operation is considered as its own ancestor, use `isProperAncestor` to |
262 | /// avoid this. |
263 | bool isAncestor(Operation *other) { |
264 | return this == other || isProperAncestor(other); |
265 | } |
266 | |
267 | /// Replace any uses of 'from' with 'to' within this operation. |
268 | void replaceUsesOfWith(Value from, Value to); |
269 | |
270 | /// Replace all uses of results of this operation with the provided 'values'. |
271 | template <typename ValuesT> |
272 | void replaceAllUsesWith(ValuesT &&values) { |
273 | getResults().replaceAllUsesWith(std::forward<ValuesT>(values)); |
274 | } |
275 | |
276 | /// Replace uses of results of this operation with the provided `values` if |
277 | /// the given callback returns true. |
278 | template <typename ValuesT> |
279 | void replaceUsesWithIf(ValuesT &&values, |
280 | function_ref<bool(OpOperand &)> shouldReplace) { |
281 | getResults().replaceUsesWithIf(std::forward<ValuesT>(values), |
282 | shouldReplace); |
283 | } |
284 | |
285 | /// Destroys this operation and its subclass data. |
286 | void destroy(); |
287 | |
288 | /// This drops all operand uses from this operation, which is an essential |
289 | /// step in breaking cyclic dependences between references when they are to |
290 | /// be deleted. |
291 | void dropAllReferences(); |
292 | |
293 | /// Drop uses of all values defined by this operation or its nested regions. |
294 | void dropAllDefinedValueUses(); |
295 | |
296 | /// Unlink this operation from its current block and insert it right before |
297 | /// `existingOp` which may be in the same or another block in the same |
298 | /// function. |
299 | void moveBefore(Operation *existingOp); |
300 | |
301 | /// Unlink this operation from its current block and insert it right before |
302 | /// `iterator` in the specified block. |
303 | void moveBefore(Block *block, llvm::iplist<Operation>::iterator iterator); |
304 | |
305 | /// Unlink this operation from its current block and insert it right after |
306 | /// `existingOp` which may be in the same or another block in the same |
307 | /// function. |
308 | void moveAfter(Operation *existingOp); |
309 | |
310 | /// Unlink this operation from its current block and insert it right after |
311 | /// `iterator` in the specified block. |
312 | void moveAfter(Block *block, llvm::iplist<Operation>::iterator iterator); |
313 | |
314 | /// Given an operation 'other' that is within the same parent block, return |
315 | /// whether the current operation is before 'other' in the operation list |
316 | /// of the parent block. |
317 | /// Note: This function has an average complexity of O(1), but worst case may |
318 | /// take O(N) where N is the number of operations within the parent block. |
319 | bool isBeforeInBlock(Operation *other); |
320 | |
321 | void print(raw_ostream &os, const OpPrintingFlags &flags = std::nullopt); |
322 | void print(raw_ostream &os, AsmState &state); |
323 | void dump(); |
324 | |
325 | //===--------------------------------------------------------------------===// |
326 | // Operands |
327 | //===--------------------------------------------------------------------===// |
328 | |
329 | /// Replace the current operands of this operation with the ones provided in |
330 | /// 'operands'. |
331 | void setOperands(ValueRange operands); |
332 | |
333 | /// Replace the operands beginning at 'start' and ending at 'start' + 'length' |
334 | /// with the ones provided in 'operands'. 'operands' may be smaller or larger |
335 | /// than the range pointed to by 'start'+'length'. |
336 | void setOperands(unsigned start, unsigned length, ValueRange operands); |
337 | |
338 | /// Insert the given operands into the operand list at the given 'index'. |
339 | void insertOperands(unsigned index, ValueRange operands); |
340 | |
341 | unsigned getNumOperands() { |
342 | return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().size() : 0;
343 | } |
344 | |
345 | Value getOperand(unsigned idx) { return getOpOperand(idx).get(); } |
346 | void setOperand(unsigned idx, Value value) { |
347 | return getOpOperand(idx).set(value); |
348 | } |
349 | |
350 | /// Erase the operand at position `idx`. |
351 | void eraseOperand(unsigned idx) { eraseOperands(idx); } |
352 | |
353 | /// Erase the operands starting at position `idx` and ending at position |
354 | /// 'idx'+'length'. |
355 | void eraseOperands(unsigned idx, unsigned length = 1) { |
356 | getOperandStorage().eraseOperands(idx, length); |
357 | } |
358 | |
359 | /// Erases the operands that have their corresponding bit set in |
360 | /// `eraseIndices` and removes them from the operand list. |
361 | void eraseOperands(const BitVector &eraseIndices) { |
362 | getOperandStorage().eraseOperands(eraseIndices); |
363 | } |
364 | |
365 | // Support operand iteration. |
366 | using operand_range = OperandRange; |
367 | using operand_iterator = operand_range::iterator; |
368 | |
369 | operand_iterator operand_begin() { return getOperands().begin(); } |
370 | operand_iterator operand_end() { return getOperands().end(); } |
371 | |
372 | /// Returns a range over the underlying Values.
373 | operand_range getOperands() { |
374 | MutableArrayRef<OpOperand> operands = getOpOperands(); |
375 | return OperandRange(operands.data(), operands.size()); |
376 | } |
377 | |
378 | MutableArrayRef<OpOperand> getOpOperands() { |
379 | return LLVM_LIKELY(hasOperandStorage) ? getOperandStorage().getOperands()
380 | : MutableArrayRef<OpOperand>(); |
381 | } |
382 | |
383 | OpOperand &getOpOperand(unsigned idx) { |
384 | return getOperandStorage().getOperands()[idx]; |
385 | } |
386 | |
387 | // Support operand type iteration. |
388 | using operand_type_iterator = operand_range::type_iterator; |
389 | using operand_type_range = operand_range::type_range; |
390 | operand_type_iterator operand_type_begin() { return operand_begin(); } |
391 | operand_type_iterator operand_type_end() { return operand_end(); } |
392 | operand_type_range getOperandTypes() { return getOperands().getTypes(); } |
393 | |
394 | //===--------------------------------------------------------------------===// |
395 | // Results |
396 | //===--------------------------------------------------------------------===// |
397 | |
398 | /// Return the number of results held by this operation. |
399 | unsigned getNumResults() { return numResults; } |
400 | |
401 | /// Get the 'idx'th result of this operation. |
402 | OpResult getResult(unsigned idx) { return OpResult(getOpResultImpl(idx)); } |
403 | |
404 | /// Support result iteration. |
405 | using result_range = ResultRange; |
406 | using result_iterator = result_range::iterator; |
407 | |
408 | result_iterator result_begin() { return getResults().begin(); } |
409 | result_iterator result_end() { return getResults().end(); } |
410 | result_range getResults() { |
411 | return numResults == 0 ? result_range(nullptr, 0) |
412 | : result_range(getInlineOpResult(0), numResults); |
413 | } |
414 | |
415 | result_range getOpResults() { return getResults(); } |
416 | OpResult getOpResult(unsigned idx) { return getResult(idx); } |
417 | |
418 | /// Support result type iteration. |
419 | using result_type_iterator = result_range::type_iterator; |
420 | using result_type_range = result_range::type_range; |
421 | result_type_iterator result_type_begin() { return getResultTypes().begin(); } |
422 | result_type_iterator result_type_end() { return getResultTypes().end(); } |
423 | result_type_range getResultTypes() { return getResults().getTypes(); } |
424 | |
425 | //===--------------------------------------------------------------------===// |
426 | // Attributes |
427 | //===--------------------------------------------------------------------===// |
428 | |
429 | // Operations may optionally carry a list of attributes that associate |
430 | // constants to names. Attributes may be dynamically added and removed over |
431 | // the lifetime of an operation. |
432 | |
433 | /// Access an inherent attribute by name: returns an empty optional if there |
434 | /// is no inherent attribute with this name. |
435 | /// |
436 | /// This method is available as a transient facility in the migration process |
437 | /// to use Properties instead. |
438 | std::optional<Attribute> getInherentAttr(StringRef name); |
439 | |
440 | /// Set an inherent attribute by name. |
441 | /// |
442 | /// This method is available as a transient facility in the migration process |
443 | /// to use Properties instead. |
444 | void setInherentAttr(StringAttr name, Attribute value); |
445 | |
446 | /// Access a discardable attribute by name, returns a null Attribute if the
447 | /// discardable attribute does not exist. |
448 | Attribute getDiscardableAttr(StringRef name) { return attrs.get(name); } |
449 | |
450 | /// Access a discardable attribute by name, returns a null Attribute if the
451 | /// discardable attribute does not exist. |
452 | Attribute getDiscardableAttr(StringAttr name) { return attrs.get(name); } |
453 | |
454 | /// Set a discardable attribute by name. |
455 | void setDiscardableAttr(StringAttr name, Attribute value) { |
456 | NamedAttrList attributes(attrs); |
457 | if (attributes.set(name, value) != value) |
458 | attrs = attributes.getDictionary(getContext()); |
459 | } |
460 | |
461 | /// Return all of the discardable attributes on this operation. |
462 | ArrayRef<NamedAttribute> getDiscardableAttrs() { return attrs.getValue(); } |
463 | |
464 | /// Return all of the discardable attributes on this operation as a |
465 | /// DictionaryAttr. |
466 | DictionaryAttr getDiscardableAttrDictionary() { return attrs; } |
467 | |
468 | /// Return all of the attributes on this operation. |
469 | ArrayRef<NamedAttribute> getAttrs() { |
470 | if (!getPropertiesStorage()) |
471 | return getDiscardableAttrs(); |
472 | return getAttrDictionary().getValue(); |
473 | } |
474 | |
475 | /// Return all of the attributes on this operation as a DictionaryAttr. |
476 | DictionaryAttr getAttrDictionary(); |
477 | |
478 | /// Set the attributes from a dictionary on this operation. |
479 | /// These methods are expensive: if the dictionary only contains discardable
480 | /// attributes, `setDiscardableAttrs` is more efficient. |
481 | void setAttrs(DictionaryAttr newAttrs); |
482 | void setAttrs(ArrayRef<NamedAttribute> newAttrs); |
483 | /// Set the discardable attribute dictionary on this operation. |
484 | void setDiscardableAttrs(DictionaryAttr newAttrs) { |
485 | assert(newAttrs && "expected valid attribute dictionary");
486 | attrs = newAttrs; |
487 | } |
488 | void setDiscardableAttrs(ArrayRef<NamedAttribute> newAttrs) { |
489 | setDiscardableAttrs(DictionaryAttr::get(getContext(), newAttrs)); |
490 | } |
491 | |
492 | /// Return the specified attribute if present, null otherwise. |
493 | /// These methods are expensive: if the dictionary only contains discardable
494 | /// attributes, `getDiscardableAttr` is more efficient. |
495 | Attribute getAttr(StringAttr name) { |
496 | if (getPropertiesStorageSize()) { |
497 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) |
498 | return *inherentAttr; |
499 | } |
500 | return attrs.get(name); |
501 | } |
502 | Attribute getAttr(StringRef name) { |
503 | if (getPropertiesStorageSize()) { |
504 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) |
505 | return *inherentAttr; |
506 | } |
507 | return attrs.get(name); |
508 | } |
509 | |
510 | template <typename AttrClass> |
511 | AttrClass getAttrOfType(StringAttr name) { |
512 | return getAttr(name).dyn_cast_or_null<AttrClass>(); |
513 | } |
514 | template <typename AttrClass> |
515 | AttrClass getAttrOfType(StringRef name) { |
516 | return getAttr(name).dyn_cast_or_null<AttrClass>(); |
517 | } |
518 | |
519 | /// Return true if the operation has an attribute with the provided name, |
520 | /// false otherwise. |
521 | bool hasAttr(StringAttr name) { |
522 | if (getPropertiesStorageSize()) { |
523 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) |
524 | return (bool)*inherentAttr; |
525 | } |
526 | return attrs.contains(name); |
527 | } |
528 | bool hasAttr(StringRef name) { |
529 | if (getPropertiesStorageSize()) { |
530 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) |
531 | return (bool)*inherentAttr; |
532 | } |
533 | return attrs.contains(name); |
534 | } |
535 | template <typename AttrClass, typename NameT> |
536 | bool hasAttrOfType(NameT &&name) { |
537 | return static_cast<bool>( |
538 | getAttrOfType<AttrClass>(std::forward<NameT>(name))); |
539 | } |
540 | |
541 | /// If an attribute exists with the specified name, change it to the new
542 | /// value. Otherwise, add a new attribute with the specified name/value. |
543 | void setAttr(StringAttr name, Attribute value) { |
544 | if (getPropertiesStorageSize()) { |
545 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) { |
546 | setInherentAttr(name, value); |
547 | return; |
548 | } |
549 | } |
550 | NamedAttrList attributes(attrs); |
551 | if (attributes.set(name, value) != value) |
552 | attrs = attributes.getDictionary(getContext()); |
553 | } |
554 | void setAttr(StringRef name, Attribute value) { |
555 | setAttr(StringAttr::get(getContext(), name), value); |
556 | } |
557 | |
558 | /// Remove the attribute with the specified name if it exists. Return the |
559 | /// attribute that was erased, or nullptr if there was no attribute with such |
560 | /// name. |
561 | Attribute removeAttr(StringAttr name) { |
562 | if (getPropertiesStorageSize()) { |
563 | if (std::optional<Attribute> inherentAttr = getInherentAttr(name)) { |
564 | setInherentAttr(name, {}); |
565 | return *inherentAttr; |
566 | } |
567 | } |
568 | NamedAttrList attributes(attrs); |
569 | Attribute removedAttr = attributes.erase(name); |
570 | if (removedAttr) |
571 | attrs = attributes.getDictionary(getContext()); |
572 | return removedAttr; |
573 | } |
574 | Attribute removeAttr(StringRef name) { |
575 | return removeAttr(StringAttr::get(getContext(), name)); |
576 | } |
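// Hedged usage sketch (illustrative attribute name, `someOp` is an
// Operation*): setting, reading, and removing a discardable attribute:
//
//   someOp->setAttr("my.tag", StringAttr::get(someOp->getContext(), "hot"));
//   if (auto tag = someOp->getAttrOfType<StringAttr>("my.tag"))
//     llvm::outs() << tag.getValue() << "\n";
//   someOp->removeAttr("my.tag");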
577 | |
578 | /// A utility iterator that filters out non-dialect attributes. |
579 | class dialect_attr_iterator |
580 | : public llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator, |
581 | bool (*)(NamedAttribute)> { |
582 | static bool filter(NamedAttribute attr) { |
583 | // Dialect attributes are prefixed by the dialect name, like operations. |
584 | return attr.getName().strref().count('.'); |
585 | } |
586 | |
587 | explicit dialect_attr_iterator(ArrayRef<NamedAttribute>::iterator it, |
588 | ArrayRef<NamedAttribute>::iterator end) |
589 | : llvm::filter_iterator<ArrayRef<NamedAttribute>::iterator, |
590 | bool (*)(NamedAttribute)>(it, end, &filter) {} |
591 | |
592 | // Allow access to the constructor. |
593 | friend Operation; |
594 | }; |
595 | using dialect_attr_range = iterator_range<dialect_attr_iterator>; |
596 | |
597 | /// Return a range corresponding to the dialect attributes for this operation. |
598 | dialect_attr_range getDialectAttrs() { |
599 | auto attrs = getAttrs(); |
600 | return {dialect_attr_iterator(attrs.begin(), attrs.end()), |
601 | dialect_attr_iterator(attrs.end(), attrs.end())}; |
602 | } |
603 | dialect_attr_iterator dialect_attr_begin() { |
604 | auto attrs = getAttrs(); |
605 | return dialect_attr_iterator(attrs.begin(), attrs.end()); |
606 | } |
607 | dialect_attr_iterator dialect_attr_end() { |
608 | auto attrs = getAttrs(); |
609 | return dialect_attr_iterator(attrs.end(), attrs.end()); |
610 | } |
611 | |
612 | /// Set the dialect attributes for this operation, and preserve all non-dialect attributes.
613 | template <typename DialectAttrT> |
614 | void setDialectAttrs(DialectAttrT &&dialectAttrs) { |
615 | NamedAttrList attrs; |
616 | attrs.append(std::begin(dialectAttrs), std::end(dialectAttrs)); |
617 | for (auto attr : getAttrs()) |
618 | if (!attr.getName().strref().contains('.')) |
619 | attrs.push_back(attr); |
620 | setAttrs(attrs.getDictionary(getContext())); |
621 | } |
622 | |
623 | /// Sets default attributes on unset attributes. |
624 | void populateDefaultAttrs() { |
625 | NamedAttrList attrs(getAttrDictionary()); |
626 | name.populateDefaultAttrs(attrs); |
627 | setAttrs(attrs.getDictionary(getContext())); |
628 | } |
629 | |
630 | //===--------------------------------------------------------------------===// |
631 | // Blocks |
632 | //===--------------------------------------------------------------------===// |
633 | |
634 | /// Returns the number of regions held by this operation. |
635 | unsigned getNumRegions() { return numRegions; } |
636 | |
637 | /// Returns the regions held by this operation. |
638 | MutableArrayRef<Region> getRegions() { |
639 | // Check the count first, as computing the trailing objects can be slow. |
640 | if (numRegions == 0) |
641 | return MutableArrayRef<Region>(); |
642 | |
643 | auto *regions = getTrailingObjects<Region>(); |
644 | return {regions, numRegions}; |
645 | } |
646 | |
647 | /// Returns the region held by this operation at position 'index'. |
648 | Region &getRegion(unsigned index) { |
649 | assert(index < numRegions && "invalid region index");
650 | return getRegions()[index]; |
651 | } |
652 | |
653 | //===--------------------------------------------------------------------===// |
654 | // Successors |
655 | //===--------------------------------------------------------------------===// |
656 | |
657 | MutableArrayRef<BlockOperand> getBlockOperands() { |
658 | return {getTrailingObjects<BlockOperand>(), numSuccs}; |
659 | } |
660 | |
661 | // Successor iteration. |
662 | using succ_iterator = SuccessorRange::iterator; |
663 | succ_iterator successor_begin() { return getSuccessors().begin(); } |
664 | succ_iterator successor_end() { return getSuccessors().end(); } |
665 | SuccessorRange getSuccessors() { return SuccessorRange(this); } |
666 | |
667 | bool hasSuccessors() { return numSuccs != 0; } |
668 | unsigned getNumSuccessors() { return numSuccs; } |
669 | |
670 | Block *getSuccessor(unsigned index) { |
671 | assert(index < getNumSuccessors());
672 | return getBlockOperands()[index].get(); |
673 | } |
674 | void setSuccessor(Block *block, unsigned index); |
675 | |
676 | //===--------------------------------------------------------------------===// |
677 | // Accessors for various properties of operations |
678 | //===--------------------------------------------------------------------===// |
679 | |
680 | /// Attempt to fold this operation with the specified constant operand values |
681 | /// - the elements in "operands" will correspond directly to the operands of |
682 | /// the operation, but may be null if non-constant. If folding is successful, |
683 | /// this fills in the `results` vector. If not, `results` is unspecified. |
684 | LogicalResult fold(ArrayRef<Attribute> operands, |
685 | SmallVectorImpl<OpFoldResult> &results); |
686 | |
687 | /// Returns true if the operation was registered with a particular trait, e.g. |
688 | /// hasTrait<OperandsAreSignlessIntegerLike>(). |
689 | template <template <typename T> class Trait> |
690 | bool hasTrait() { |
691 | return name.hasTrait<Trait>(); |
692 | } |
693 | |
694 | /// Returns true if the operation *might* have the provided trait. This |
695 | /// means that either the operation is unregistered, or it was registered with |
696 | /// the provided trait.
697 | template <template <typename T> class Trait> |
698 | bool mightHaveTrait() { |
699 | return name.mightHaveTrait<Trait>(); |
700 | } |
701 | |
702 | //===--------------------------------------------------------------------===// |
703 | // Operation Walkers |
704 | //===--------------------------------------------------------------------===// |
705 | |
706 | /// Walk the operation by calling the callback for each nested operation |
707 | /// (including this one), block or region, depending on the callback provided. |
708 | /// The order in which regions, blocks and operations at the same nesting |
709 | /// level are visited (e.g., lexicographical or reverse lexicographical order) |
710 | /// is determined by 'Iterator'. The walk order for enclosing regions, blocks |
711 | /// and operations with respect to their nested ones is specified by 'Order' |
712 | /// (post-order by default). A callback on a block or operation is allowed to |
713 | /// erase that block or operation if either: |
714 | /// * the walk is in post-order, or |
715 | /// * the walk is in pre-order and the walk is skipped after the erasure. |
716 | /// |
717 | /// The callback method can take any of the following forms: |
718 | /// void(Operation*) : Walk all operations opaquely. |
719 | /// * op->walk([](Operation *nestedOp) { ...}); |
720 | /// void(OpT) : Walk all operations of the given derived type. |
721 | /// * op->walk([](ReturnOp returnOp) { ...}); |
722 | /// WalkResult(Operation*|OpT) : Walk operations, but allow for |
723 | /// interruption/skipping. |
724 | /// * op->walk([](... op) { |
725 | /// // Skip the walk of this op based on some invariant. |
726 | /// if (some_invariant) |
727 | /// return WalkResult::skip(); |
728 | ///       // Interrupt, i.e., cancel, the walk based on some invariant.
729 | /// if (another_invariant) |
730 | /// return WalkResult::interrupt(); |
731 | /// return WalkResult::advance(); |
732 | /// }); |
733 | template <WalkOrder Order = WalkOrder::PostOrder, |
734 | typename Iterator = ForwardIterator, typename FnT, |
735 | typename RetT = detail::walkResultType<FnT>> |
736 | std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 1, |
737 | RetT> |
738 | walk(FnT &&callback) { |
739 | return detail::walk<Order, Iterator>(this, std::forward<FnT>(callback)); |
740 | } |
741 | |
742 | /// Generic walker with a stage aware callback. Walk the operation by calling |
743 | /// the callback for each nested operation (including this one) N+1 times, |
744 | /// where N is the number of regions attached to that operation. |
745 | /// |
746 | /// The callback method can take any of the following forms: |
747 | ///   void(Operation *, const WalkStage &) : Walk all operations opaquely
748 | /// * op->walk([](Operation *nestedOp, const WalkStage &stage) { ...}); |
749 | /// void(OpT, const WalkStage &) : Walk all operations of the given derived |
750 | /// type. |
751 | /// * op->walk([](ReturnOp returnOp, const WalkStage &stage) { ...}); |
752 | /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations, |
753 | /// but allow for interruption/skipping. |
754 | /// * op->walk([](... op, const WalkStage &stage) { |
755 | /// // Skip the walk of this op based on some invariant. |
756 | /// if (some_invariant) |
757 | /// return WalkResult::skip(); |
758 | /// // Interrupt, i.e., cancel, the walk based on some invariant.
759 | /// if (another_invariant) |
760 | /// return WalkResult::interrupt(); |
761 | /// return WalkResult::advance(); |
762 | /// }); |
763 | template <typename FnT, typename RetT = detail::walkResultType<FnT>> |
764 | std::enable_if_t<llvm::function_traits<std::decay_t<FnT>>::num_args == 2, |
765 | RetT> |
766 | walk(FnT &&callback) { |
767 | return detail::walk(this, std::forward<FnT>(callback)); |
768 | } |
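// Example (hedged sketch): a stage-aware walk. The WalkStage query names
// used below are assumptions about that interface and may differ slightly.
//   op->walk([](Operation *nested, const WalkStage &stage) {
//     if (stage.isBeforeAllRegions()) {
//       // First visit, before any attached region has been walked.
//     }
//     if (stage.isAfterAllRegions()) {
//       // Final visit, after every attached region has been walked.
//     }
//   });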
769 | |
770 | //===--------------------------------------------------------------------===// |
771 | // Uses |
772 | //===--------------------------------------------------------------------===// |
773 | |
774 | /// Drop all uses of results of this operation. |
775 | void dropAllUses() { |
776 | for (OpResult result : getOpResults()) |
777 | result.dropAllUses(); |
778 | } |
779 | |
780 | using use_iterator = result_range::use_iterator; |
781 | using use_range = result_range::use_range; |
782 | |
783 | use_iterator use_begin() { return getResults().use_begin(); } |
784 | use_iterator use_end() { return getResults().use_end(); } |
785 | |
786 | /// Returns a range of all uses, which is useful for iterating over them.
787 | use_range getUses() { return getResults().getUses(); } |
788 | |
789 | /// Returns true if this operation has exactly one use. |
790 | bool hasOneUse() { return llvm::hasSingleElement(getUses()); } |
791 | |
792 | /// Returns true if this operation has no uses. |
793 | bool use_empty() { return getResults().use_empty(); } |
794 | |
795 | /// Returns true if the results of this operation are used outside of the |
796 | /// given block. |
797 | bool isUsedOutsideOfBlock(Block *block) { |
798 | return llvm::any_of(getOpResults(), [block](OpResult result) { |
799 | return result.isUsedOutsideOfBlock(block); |
800 | }); |
801 | } |
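// Example (hedged sketch): typical use-list queries on a hypothetical
// `Operation *op` and `Block *someBlock`.
//   if (op->use_empty())
//     op->erase();                  // No result is used; safe to remove.
//   for (OpOperand &use : op->getUses())
//     llvm::errs() << "used as operand #" << use.getOperandNumber() << "\n";
//   bool escapes = op->isUsedOutsideOfBlock(someBlock);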
802 | |
803 | //===--------------------------------------------------------------------===// |
804 | // Users |
805 | //===--------------------------------------------------------------------===// |
806 | |
807 | using user_iterator = ValueUserIterator<use_iterator, OpOperand>; |
808 | using user_range = iterator_range<user_iterator>; |
809 | |
810 | user_iterator user_begin() { return user_iterator(use_begin()); } |
811 | user_iterator user_end() { return user_iterator(use_end()); } |
812 | |
813 | /// Returns a range of all users. |
814 | user_range getUsers() { return {user_begin(), user_end()}; } |
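// Example (hedged sketch): iterate over the operations that use this op's
// results; a user appears once per use, so duplicates are possible.
//   for (Operation *user : op->getUsers())
//     user->dump();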
815 | |
816 | //===--------------------------------------------------------------------===// |
817 | // Other |
818 | //===--------------------------------------------------------------------===// |
819 | |
820 | /// Emit an error with the op name prefixed, like "'dim' op " which is |
821 | /// convenient for verifiers. |
822 | InFlightDiagnostic emitOpError(const Twine &message = {}); |
823 | |
824 | /// Emit an error about fatal conditions with this operation, reporting up to |
825 | /// any diagnostic handlers that may be listening. |
826 | InFlightDiagnostic emitError(const Twine &message = {}); |
827 | |
828 | /// Emit a warning about this operation, reporting up to any diagnostic |
829 | /// handlers that may be listening. |
830 | InFlightDiagnostic emitWarning(const Twine &message = {}); |
831 | |
832 | /// Emit a remark about this operation, reporting up to any diagnostic |
833 | /// handlers that may be listening. |
834 | InFlightDiagnostic emitRemark(const Twine &message = {}); |
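// Example (hedged sketch): a verifier-style check inside a function that
// returns LogicalResult, using the op-prefixed error form.
//   if (op->getNumOperands() == 0)
//     return op->emitOpError("expected at least one operand");
//   return success();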
835 | |
836 | /// Returns the properties storage size. |
837 | int getPropertiesStorageSize() const { |
838 | return ((int)propertiesStorageSize) * 8; |
839 | } |
840 | /// Returns the properties storage. |
841 | OpaqueProperties getPropertiesStorage() { |
842 | if (propertiesStorageSize) |
843 | return { |
844 | reinterpret_cast<void *>(getTrailingObjects<detail::OpProperties>())}; |
845 | return {nullptr}; |
846 | } |
847 | OpaqueProperties getPropertiesStorage() const { |
848 | if (propertiesStorageSize) |
849 | return {reinterpret_cast<void *>(const_cast<detail::OpProperties *>( |
850 | getTrailingObjects<detail::OpProperties>()))}; |
851 | return {nullptr}; |
852 | } |
853 | |
854 | /// Return the properties converted to an attribute. |
855 | /// This is expensive, and mostly useful when dealing with unregistered |
856 | /// operations. Returns an empty attribute if no properties are present.
857 | Attribute getPropertiesAsAttribute(); |
858 | |
859 | /// Set the properties from the provided attribute. |
860 | /// This is an expensive operation that can fail if the attribute does not
861 | /// match the expectations of the properties for this operation. It is
862 | /// mostly useful for unregistered operations or when parsing the generic
863 | /// format. An optional diagnostic can be passed in for richer errors.
864 | LogicalResult setPropertiesFromAttribute(Attribute attr, |
865 | InFlightDiagnostic *diagnostic); |
866 | |
867 | /// Copy properties from an existing other properties object. The two objects |
868 | /// must be the same type. |
869 | void copyProperties(OpaqueProperties rhs); |
870 | |
871 | /// Compute a hash for the op properties (if any). |
872 | llvm::hash_code hashProperties(); |
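// Example (hedged sketch): round-tripping properties through an attribute on
// a hypothetical `Operation *op`; passing a null diagnostic is assumed to be
// acceptable since the parameter is documented as optional.
//   Attribute props = op->getPropertiesAsAttribute();
//   if (props)
//     (void)op->setPropertiesFromAttribute(props, /*diagnostic=*/nullptr);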
873 | |
874 | private: |
875 | //===--------------------------------------------------------------------===// |
876 | // Ordering |
877 | //===--------------------------------------------------------------------===// |
878 | |
879 | /// This value represents an invalid index ordering for an operation within a |
880 | /// block. |
881 | static constexpr unsigned kInvalidOrderIdx = -1; |
882 | |
883 | /// This value represents the stride to use when computing a new order for an |
884 | /// operation. |
885 | static constexpr unsigned kOrderStride = 5; |
886 | |
887 | /// Update the order index of this operation if necessary,
888 | /// potentially recomputing the order of the parent block. |
889 | void updateOrderIfNecessary(); |
890 | |
891 | /// Returns true if this operation has a valid order. |
892 | bool hasValidOrder() { return orderIndex != kInvalidOrderIdx; } |
893 | |
894 | private: |
895 | Operation(Location location, OperationName name, unsigned numResults, |
896 | unsigned numSuccessors, unsigned numRegions, |
897 | int propertiesStorageSize, DictionaryAttr attributes, |
898 | OpaqueProperties properties, bool hasOperandStorage); |
899 | |
900 | // Operations are deleted through the destroy() member because they are |
901 | // allocated with malloc. |
902 | ~Operation(); |
903 | |
904 | /// Returns the additional size necessary for allocating the given objects |
905 | /// before an Operation in memory.
906 | static size_t prefixAllocSize(unsigned numOutOfLineResults, |
907 | unsigned numInlineResults) { |
908 | return sizeof(detail::OutOfLineOpResult) * numOutOfLineResults + |
909 | sizeof(detail::InlineOpResult) * numInlineResults; |
910 | } |
911 | /// Returns the additional size allocated before this Operation in memory.
912 | size_t prefixAllocSize() { |
913 | unsigned numResults = getNumResults(); |
914 | unsigned numOutOfLineResults = OpResult::getNumTrailing(numResults); |
915 | unsigned numInlineResults = OpResult::getNumInline(numResults); |
916 | return prefixAllocSize(numOutOfLineResults, numInlineResults); |
917 | } |
918 | |
919 | /// Returns the operand storage object. |
920 | detail::OperandStorage &getOperandStorage() { |
921 | assert(hasOperandStorage && "expected operation to have operand storage");
922 | return *getTrailingObjects<detail::OperandStorage>(); |
923 | } |
924 | |
925 | /// Returns a pointer to the use list for the given out-of-line result. |
926 | detail::OutOfLineOpResult *getOutOfLineOpResult(unsigned resultNumber) { |
927 | // Out-of-line results are stored in reverse order after (before in memory) |
928 | // the inline results. |
929 | return reinterpret_cast<detail::OutOfLineOpResult *>(getInlineOpResult( |
930 | detail::OpResultImpl::getMaxInlineResults() - 1)) - |
931 | ++resultNumber; |
932 | } |
933 | |
934 | /// Returns a pointer to the use list for the given inline result. |
935 | detail::InlineOpResult *getInlineOpResult(unsigned resultNumber) { |
936 | // Inline results are stored in reverse order before the operation in |
937 | // memory. |
938 | return reinterpret_cast<detail::InlineOpResult *>(this) - ++resultNumber; |
939 | } |
940 | |
941 | /// Returns a pointer to the use list for the given result, which may be |
942 | /// either inline or out-of-line. |
943 | detail::OpResultImpl *getOpResultImpl(unsigned resultNumber) { |
944 | assert(resultNumber < getNumResults() &&
945 | "Result number is out of range for operation");
946 | unsigned maxInlineResults = detail::OpResultImpl::getMaxInlineResults(); |
947 | if (resultNumber < maxInlineResults) |
948 | return getInlineOpResult(resultNumber); |
949 | return getOutOfLineOpResult(resultNumber - maxInlineResults); |
950 | } |
951 | |
952 | /// Provide a 'getParent' method for ilist_node_with_parent methods. |
953 | /// We mark it as a const function because ilist_node_with_parent specifically |
954 | /// requires a 'getParent() const' method. Once ilist_node removes this |
955 | /// constraint, we should drop the const to fit the rest of the MLIR const |
956 | /// model. |
957 | Block *getParent() const { return block; } |
958 | |
959 | /// Expose a few methods explicitly for the debugger to call for |
960 | /// visualization. |
961 | #ifndef NDEBUG |
962 | LLVM_DUMP_METHOD operand_range debug_getOperands() { return getOperands(); }
963 | LLVM_DUMP_METHOD result_range debug_getResults() { return getResults(); }
964 | LLVM_DUMP_METHOD SuccessorRange debug_getSuccessors() {
965 | return getSuccessors();
966 | }
967 | LLVM_DUMP_METHOD MutableArrayRef<Region> debug_getRegions() {
968 | return getRegions();
969 | }
970 | #endif |
971 | |
972 | /// The operation block that contains this operation. |
973 | Block *block = nullptr; |
974 | |
975 | /// This holds information about the source location the operation was defined |
976 | /// or derived from. |
977 | Location location; |
978 | |
979 | /// Relative order of this operation in its parent block. Used for |
980 | /// O(1) local dominance checks between operations. |
981 | mutable unsigned orderIndex = 0; |
982 | |
983 | const unsigned numResults; |
984 | const unsigned numSuccs; |
985 | const unsigned numRegions : 23; |
986 | |
987 | /// This bit signals whether this operation has operand storage. The
988 | /// operand storage may be elided for operations that are known to never have |
989 | /// operands. |
990 | bool hasOperandStorage : 1; |
991 | |
992 | /// The size of the storage for properties (if any), divided by 8: since the |
993 | /// Properties storage will always be rounded up to the next multiple of 8 we |
994 | /// save some bits here. |
995 | unsigned char propertiesStorageSize : 8; |
996 | /// This is the maximum size we support to allocate properties inline with an |
997 | /// operation: this must match the bitwidth above. |
998 | static constexpr int64_t propertiesCapacity = 8 * 256; |
999 | |
1000 | /// This holds the name of the operation. |
1001 | OperationName name; |
1002 | |
1003 | /// This holds general named attributes for the operation. |
1004 | DictionaryAttr attrs; |
1005 | |
1006 | // allow ilist_traits access to 'block' field. |
1007 | friend struct llvm::ilist_traits<Operation>; |
1008 | |
1009 | // allow block to access the 'orderIndex' field. |
1010 | friend class Block; |
1011 | |
1012 | // allow value to access the 'ResultStorage' methods. |
1013 | friend class Value; |
1014 | |
1015 | // allow ilist_node_with_parent to access the 'getParent' method. |
1016 | friend class llvm::ilist_node_with_parent<Operation, Block>; |
1017 | |
1018 | // This stuff is used by the TrailingObjects template. |
1019 | friend llvm::TrailingObjects<Operation, detail::OperandStorage, |
1020 | detail::OpProperties, BlockOperand, Region, |
1021 | OpOperand>; |
1022 | size_t numTrailingObjects(OverloadToken<detail::OperandStorage>) const { |
1023 | return hasOperandStorage ? 1 : 0; |
1024 | } |
1025 | size_t numTrailingObjects(OverloadToken<BlockOperand>) const { |
1026 | return numSuccs; |
1027 | } |
1028 | size_t numTrailingObjects(OverloadToken<Region>) const { return numRegions; } |
1029 | size_t numTrailingObjects(OverloadToken<detail::OpProperties>) const { |
1030 | return getPropertiesStorageSize(); |
1031 | } |
1032 | }; |
1033 | |
1034 | inline raw_ostream &operator<<(raw_ostream &os, const Operation &op) { |
1035 | const_cast<Operation &>(op).print(os, OpPrintingFlags().useLocalScope()); |
1036 | return os; |
1037 | } |
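// Example (hedged sketch): stream an operation to stderr using the operator
// defined above; printing uses local-scope flags.
//   llvm::errs() << *op << "\n";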
1038 | |
1039 | } // namespace mlir |
1040 | |
1041 | namespace llvm { |
1042 | /// Cast from a (const) Operation * to a derived operation type.
1043 | template <typename T> |
1044 | struct CastInfo<T, ::mlir::Operation *> |
1045 | : public ValueFromPointerCast<T, ::mlir::Operation, |
1046 | CastInfo<T, ::mlir::Operation *>> { |
1047 | static bool isPossible(::mlir::Operation *op) { return T::classof(op); } |
1048 | }; |
1049 | template <typename T> |
1050 | struct CastInfo<T, const ::mlir::Operation *> |
1051 | : public ConstStrippingForwardingCast<T, const ::mlir::Operation *, |
1052 | CastInfo<T, ::mlir::Operation *>> {}; |
1053 | |
1054 | /// Cast from a (const) Operation & to a derived operation type.
1055 | template <typename T> |
1056 | struct CastInfo<T, ::mlir::Operation> |
1057 | : public NullableValueCastFailed<T>, |
1058 | public DefaultDoCastIfPossible<T, ::mlir::Operation &, |
1059 | CastInfo<T, ::mlir::Operation>> { |
1060 | // Provide isPossible here because we have the const-stripping from
1061 | // ConstStrippingCast. |
1062 | static bool isPossible(::mlir::Operation &val) { return T::classof(&val); } |
1063 | static T doCast(::mlir::Operation &val) { return T(&val); } |
1064 | }; |
1065 | template <typename T> |
1066 | struct CastInfo<T, const ::mlir::Operation> |
1067 | : public ConstStrippingForwardingCast<T, const ::mlir::Operation, |
1068 | CastInfo<T, ::mlir::Operation>> {}; |
1069 | |
1070 | /// Cast (const) Operation * to itself. This is helpful to avoid SFINAE in |
1071 | /// templated implementations that should work on both base and derived |
1072 | /// operation types. |
1073 | template <> |
1074 | struct CastInfo<::mlir::Operation *, ::mlir::Operation *> |
1075 | : public NullableValueCastFailed<::mlir::Operation *>, |
1076 | public DefaultDoCastIfPossible< |
1077 | ::mlir::Operation *, ::mlir::Operation *, |
1078 | CastInfo<::mlir::Operation *, ::mlir::Operation *>> { |
1079 | static bool isPossible(::mlir::Operation *op) { return true; } |
1080 | static ::mlir::Operation *doCast(::mlir::Operation *op) { return op; } |
1081 | }; |
1082 | template <> |
1083 | struct CastInfo<const ::mlir::Operation *, const ::mlir::Operation *> |
1084 | : public ConstStrippingForwardingCast< |
1085 | const ::mlir::Operation *, const ::mlir::Operation *, |
1086 | CastInfo<::mlir::Operation *, ::mlir::Operation *>> {}; |
1087 | } // namespace llvm |
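// Example (hedged sketch): these specializations are what let the usual LLVM
// casting utilities operate on Operation pointers, e.g. for a hypothetical
// `mlir::Operation *op`:
//   if (auto addf = llvm::dyn_cast<mlir::arith::AddFOp>(op))
//     mlir::Value lhs = addf.getLhs();
//   bool isAdd = llvm::isa<mlir::arith::AddIOp>(op);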
1088 | |
1089 | #endif // MLIR_IR_OPERATION_H |
1 | /*===- TableGen'erated file -------------------------------------*- C++ -*-===*\ |
2 | |* *| |
3 | |* Op Declarations *| |
4 | |* *| |
5 | |* Automatically generated file, do not edit! *| |
6 | |* *| |
7 | \*===----------------------------------------------------------------------===*/ |
8 | |
9 | #if defined(GET_OP_CLASSES) || defined(GET_OP_FWD_DEFINES) |
10 | #undef GET_OP_FWD_DEFINES |
11 | namespace mlir { |
12 | namespace arith { |
13 | class AddFOp; |
14 | } // namespace arith |
15 | } // namespace mlir |
16 | namespace mlir { |
17 | namespace arith { |
18 | class AddIOp; |
19 | } // namespace arith |
20 | } // namespace mlir |
21 | namespace mlir { |
22 | namespace arith { |
23 | class AddUIExtendedOp; |
24 | } // namespace arith |
25 | } // namespace mlir |
26 | namespace mlir { |
27 | namespace arith { |
28 | class AndIOp; |
29 | } // namespace arith |
30 | } // namespace mlir |
31 | namespace mlir { |
32 | namespace arith { |
33 | class BitcastOp; |
34 | } // namespace arith |
35 | } // namespace mlir |
36 | namespace mlir { |
37 | namespace arith { |
38 | class CeilDivSIOp; |
39 | } // namespace arith |
40 | } // namespace mlir |
41 | namespace mlir { |
42 | namespace arith { |
43 | class CeilDivUIOp; |
44 | } // namespace arith |
45 | } // namespace mlir |
46 | namespace mlir { |
47 | namespace arith { |
48 | class CmpFOp; |
49 | } // namespace arith |
50 | } // namespace mlir |
51 | namespace mlir { |
52 | namespace arith { |
53 | class CmpIOp; |
54 | } // namespace arith |
55 | } // namespace mlir |
56 | namespace mlir { |
57 | namespace arith { |
58 | class ConstantOp; |
59 | } // namespace arith |
60 | } // namespace mlir |
61 | namespace mlir { |
62 | namespace arith { |
63 | class DivFOp; |
64 | } // namespace arith |
65 | } // namespace mlir |
66 | namespace mlir { |
67 | namespace arith { |
68 | class DivSIOp; |
69 | } // namespace arith |
70 | } // namespace mlir |
71 | namespace mlir { |
72 | namespace arith { |
73 | class DivUIOp; |
74 | } // namespace arith |
75 | } // namespace mlir |
76 | namespace mlir { |
77 | namespace arith { |
78 | class ExtFOp; |
79 | } // namespace arith |
80 | } // namespace mlir |
81 | namespace mlir { |
82 | namespace arith { |
83 | class ExtSIOp; |
84 | } // namespace arith |
85 | } // namespace mlir |
86 | namespace mlir { |
87 | namespace arith { |
88 | class ExtUIOp; |
89 | } // namespace arith |
90 | } // namespace mlir |
91 | namespace mlir { |
92 | namespace arith { |
93 | class FPToSIOp; |
94 | } // namespace arith |
95 | } // namespace mlir |
96 | namespace mlir { |
97 | namespace arith { |
98 | class FPToUIOp; |
99 | } // namespace arith |
100 | } // namespace mlir |
101 | namespace mlir { |
102 | namespace arith { |
103 | class FloorDivSIOp; |
104 | } // namespace arith |
105 | } // namespace mlir |
106 | namespace mlir { |
107 | namespace arith { |
108 | class IndexCastOp; |
109 | } // namespace arith |
110 | } // namespace mlir |
111 | namespace mlir { |
112 | namespace arith { |
113 | class IndexCastUIOp; |
114 | } // namespace arith |
115 | } // namespace mlir |
116 | namespace mlir { |
117 | namespace arith { |
118 | class MaxFOp; |
119 | } // namespace arith |
120 | } // namespace mlir |
121 | namespace mlir { |
122 | namespace arith { |
123 | class MaxSIOp; |
124 | } // namespace arith |
125 | } // namespace mlir |
126 | namespace mlir { |
127 | namespace arith { |
128 | class MaxUIOp; |
129 | } // namespace arith |
130 | } // namespace mlir |
131 | namespace mlir { |
132 | namespace arith { |
133 | class MinFOp; |
134 | } // namespace arith |
135 | } // namespace mlir |
136 | namespace mlir { |
137 | namespace arith { |
138 | class MinSIOp; |
139 | } // namespace arith |
140 | } // namespace mlir |
141 | namespace mlir { |
142 | namespace arith { |
143 | class MinUIOp; |
144 | } // namespace arith |
145 | } // namespace mlir |
146 | namespace mlir { |
147 | namespace arith { |
148 | class MulFOp; |
149 | } // namespace arith |
150 | } // namespace mlir |
151 | namespace mlir { |
152 | namespace arith { |
153 | class MulIOp; |
154 | } // namespace arith |
155 | } // namespace mlir |
156 | namespace mlir { |
157 | namespace arith { |
158 | class MulSIExtendedOp; |
159 | } // namespace arith |
160 | } // namespace mlir |
161 | namespace mlir { |
162 | namespace arith { |
163 | class MulUIExtendedOp; |
164 | } // namespace arith |
165 | } // namespace mlir |
166 | namespace mlir { |
167 | namespace arith { |
168 | class NegFOp; |
169 | } // namespace arith |
170 | } // namespace mlir |
171 | namespace mlir { |
172 | namespace arith { |
173 | class OrIOp; |
174 | } // namespace arith |
175 | } // namespace mlir |
176 | namespace mlir { |
177 | namespace arith { |
178 | class RemFOp; |
179 | } // namespace arith |
180 | } // namespace mlir |
181 | namespace mlir { |
182 | namespace arith { |
183 | class RemSIOp; |
184 | } // namespace arith |
185 | } // namespace mlir |
186 | namespace mlir { |
187 | namespace arith { |
188 | class RemUIOp; |
189 | } // namespace arith |
190 | } // namespace mlir |
191 | namespace mlir { |
192 | namespace arith { |
193 | class SIToFPOp; |
194 | } // namespace arith |
195 | } // namespace mlir |
196 | namespace mlir { |
197 | namespace arith { |
198 | class ShLIOp; |
199 | } // namespace arith |
200 | } // namespace mlir |
201 | namespace mlir { |
202 | namespace arith { |
203 | class ShRSIOp; |
204 | } // namespace arith |
205 | } // namespace mlir |
206 | namespace mlir { |
207 | namespace arith { |
208 | class ShRUIOp; |
209 | } // namespace arith |
210 | } // namespace mlir |
211 | namespace mlir { |
212 | namespace arith { |
213 | class SubFOp; |
214 | } // namespace arith |
215 | } // namespace mlir |
216 | namespace mlir { |
217 | namespace arith { |
218 | class SubIOp; |
219 | } // namespace arith |
220 | } // namespace mlir |
221 | namespace mlir { |
222 | namespace arith { |
223 | class TruncFOp; |
224 | } // namespace arith |
225 | } // namespace mlir |
226 | namespace mlir { |
227 | namespace arith { |
228 | class TruncIOp; |
229 | } // namespace arith |
230 | } // namespace mlir |
231 | namespace mlir { |
232 | namespace arith { |
233 | class UIToFPOp; |
234 | } // namespace arith |
235 | } // namespace mlir |
236 | namespace mlir { |
237 | namespace arith { |
238 | class XOrIOp; |
239 | } // namespace arith |
240 | } // namespace mlir |
241 | namespace mlir { |
242 | namespace arith { |
243 | class SelectOp; |
244 | } // namespace arith |
245 | } // namespace mlir |
246 | #endif |
247 | |
248 | #ifdef GET_OP_CLASSES |
249 | #undef GET_OP_CLASSES |
250 | |
251 | |
252 | //===----------------------------------------------------------------------===// |
253 | // Local Utility Method Definitions |
254 | //===----------------------------------------------------------------------===// |
255 | |
256 | namespace mlir { |
257 | namespace arith { |
258 | |
259 | //===----------------------------------------------------------------------===// |
260 | // ::mlir::arith::AddFOp declarations |
261 | //===----------------------------------------------------------------------===// |
262 | |
263 | namespace detail { |
264 | class AddFOpGenericAdaptorBase { |
265 | public: |
266 | struct Properties { |
267 | using fastmathTy = ::mlir::arith::FastMathFlagsAttr; |
268 | fastmathTy fastmath; |
269 | |
270 | auto getFastmath() { |
271 | auto &propStorage = this->fastmath; |
272 | return propStorage.dyn_cast_or_null<::mlir::arith::FastMathFlagsAttr>(); |
273 | } |
274 | void setFastmath(const ::mlir::arith::FastMathFlagsAttr &propValue) { |
275 | this->fastmath = propValue; |
276 | } |
277 | }; |
278 | protected: |
279 | ::mlir::DictionaryAttr odsAttrs; |
280 | ::std::optional<::mlir::OperationName> odsOpName; |
281 | Properties properties; |
282 | ::mlir::RegionRange odsRegions; |
283 | public: |
284 | AddFOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}); |
285 | |
286 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
287 | const Properties &getProperties() { |
288 | return properties; |
289 | } |
290 | |
291 | ::mlir::DictionaryAttr getAttributes(); |
292 | ::mlir::arith::FastMathFlagsAttr getFastmathAttr(); |
293 | ::mlir::arith::FastMathFlags getFastmath(); |
294 | }; |
295 | } // namespace detail |
296 | template <typename RangeT> |
297 | class AddFOpGenericAdaptor : public detail::AddFOpGenericAdaptorBase { |
298 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
299 | using Base = detail::AddFOpGenericAdaptorBase; |
300 | public: |
301 | AddFOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const Properties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
302 | |
303 | AddFOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddFOpGenericAdaptor(values, attrs, (properties ? *properties.as<Properties *>() : Properties{}), regions) {} |
304 | |
305 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
306 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
307 | } |
308 | |
309 | RangeT getODSOperands(unsigned index) { |
310 | auto valueRange = getODSOperandIndexAndLength(index); |
311 | return {std::next(odsOperands.begin(), valueRange.first), |
312 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
313 | } |
314 | |
315 | ValueT getLhs() { |
316 | return (*getODSOperands(0).begin()); |
317 | } |
318 | |
319 | ValueT getRhs() { |
320 | return (*getODSOperands(1).begin()); |
321 | } |
322 | |
323 | RangeT getOperands() { |
324 | return odsOperands; |
325 | } |
326 | |
327 | private: |
328 | RangeT odsOperands; |
329 | }; |
330 | class AddFOpAdaptor : public AddFOpGenericAdaptor<::mlir::ValueRange> { |
331 | public: |
332 | using AddFOpGenericAdaptor::AddFOpGenericAdaptor; |
333 | AddFOpAdaptor(AddFOp op); |
334 | |
335 | ::mlir::LogicalResult verify(::mlir::Location loc); |
336 | }; |
337 | class AddFOp : public ::mlir::Op<AddFOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::arith::ArithFastMathInterface::Trait, ::mlir::OpTrait::IsCommutative, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { |
338 | public: |
339 | using Op::Op; |
340 | using Op::print; |
341 | using Adaptor = AddFOpAdaptor; |
342 | template <typename RangeT> |
343 | using GenericAdaptor = AddFOpGenericAdaptor<RangeT>; |
344 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
345 | using Properties = FoldAdaptor::Properties; |
346 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
347 | static ::llvm::StringRef attrNames[] = {::llvm::StringRef("fastmath")}; |
348 | return ::llvm::ArrayRef(attrNames); |
349 | } |
350 | |
351 | ::mlir::StringAttr getFastmathAttrName() { |
352 | return getAttributeNameForIndex(0); |
353 | } |
354 | |
355 | static ::mlir::StringAttr getFastmathAttrName(::mlir::OperationName name) { |
356 | return getAttributeNameForIndex(name, 0); |
357 | } |
358 | |
359 | static constexpr ::llvm::StringLiteral getOperationName() { |
360 | return ::llvm::StringLiteral("arith.addf"); |
361 | } |
362 | |
363 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
364 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
365 | ::mlir::Value getLhs(); |
366 | ::mlir::Value getRhs(); |
367 | ::mlir::MutableOperandRange getLhsMutable(); |
368 | ::mlir::MutableOperandRange getRhsMutable(); |
369 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
370 | ::mlir::Operation::result_range getODSResults(unsigned index); |
371 | ::mlir::Value getResult(); |
372 | static ::mlir::LogicalResult setPropertiesFromAttr(Properties &prop, ::mlir::Attribute attr, ::mlir::InFlightDiagnostic *diag); |
373 | static ::mlir::Attribute getPropertiesAsAttr(::mlir::MLIRContext *ctx, const Properties &prop); |
374 | static llvm::hash_code computePropertiesHash(const Properties &prop); |
375 | static std::optional<mlir::Attribute> getInherentAttr(const Properties &prop, llvm::StringRef name); |
376 | static void setInherentAttr(Properties &prop, llvm::StringRef name, mlir::Attribute value); |
377 | static void populateInherentAttrs(const Properties &prop, ::mlir::NamedAttrList &attrs); |
378 | static ::mlir::LogicalResult verifyInherentAttrs(::mlir::OperationName opName, ::mlir::NamedAttrList &attrs, llvm::function_ref<::mlir::InFlightDiagnostic()> getDiag); |
379 | ::mlir::arith::FastMathFlagsAttr getFastmathAttr(); |
380 | ::mlir::arith::FastMathFlags getFastmath(); |
381 | void setFastmathAttr(::mlir::arith::FastMathFlagsAttr attr); |
382 | void setFastmath(::mlir::arith::FastMathFlags attrValue); |
383 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlagsAttr fastmath); |
384 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlagsAttr fastmath); |
385 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlagsAttr fastmath); |
386 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlags fastmath = ::mlir::arith::FastMathFlags::none); |
387 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlags fastmath = ::mlir::arith::FastMathFlags::none); |
388 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs, ::mlir::arith::FastMathFlags fastmath = ::mlir::arith::FastMathFlags::none); |
389 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
390 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
391 | static void populateDefaultProperties(::mlir::OperationName opName, Properties &properties); |
392 | ::mlir::LogicalResult verifyInvariantsImpl(); |
393 | ::mlir::LogicalResult verifyInvariants(); |
394 | ::mlir::OpFoldResult fold(FoldAdaptor adaptor); |
395 | static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes); |
396 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
397 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
398 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
399 | private: |
400 | ::mlir::StringAttr getAttributeNameForIndex(unsigned index) { |
401 | return getAttributeNameForIndex((*this)->getName(), index); |
402 | } |
403 | |
404 | static ::mlir::StringAttr getAttributeNameForIndex(::mlir::OperationName name, unsigned index) { |
405 | assert(index < 1 && "invalid attribute index");
406 | assert(name.getStringRef() == getOperationName() && "invalid operation name");
407 | return name.getAttributeNames()[index]; |
408 | } |
409 | |
410 | public: |
411 | }; |
412 | } // namespace arith |
413 | } // namespace mlir |
414 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddFOp)
415 | |
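// Example (hedged sketch): building an arith.addf with a hypothetical
// OpBuilder `b`, Location `loc`, and f32 Values `lhs`/`rhs`.
//   auto sum = b.create<mlir::arith::AddFOp>(loc, lhs, rhs);
//   sum.setFastmath(mlir::arith::FastMathFlags::fast);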
416 | namespace mlir { |
417 | namespace arith { |
418 | |
419 | //===----------------------------------------------------------------------===// |
420 | // ::mlir::arith::AddIOp declarations |
421 | //===----------------------------------------------------------------------===// |
422 | |
423 | namespace detail { |
424 | class AddIOpGenericAdaptorBase { |
425 | public: |
426 | protected: |
427 | ::mlir::DictionaryAttr odsAttrs; |
428 | ::std::optional<::mlir::OperationName> odsOpName; |
429 | ::mlir::RegionRange odsRegions; |
430 | public: |
431 | AddIOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
432 | |
433 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
434 | ::mlir::DictionaryAttr getAttributes(); |
435 | }; |
436 | } // namespace detail |
437 | template <typename RangeT> |
438 | class AddIOpGenericAdaptor : public detail::AddIOpGenericAdaptorBase { |
439 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
440 | using Base = detail::AddIOpGenericAdaptorBase; |
441 | public: |
442 | AddIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
443 | |
444 | AddIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddIOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
445 | |
446 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
447 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
448 | } |
449 | |
450 | RangeT getODSOperands(unsigned index) { |
451 | auto valueRange = getODSOperandIndexAndLength(index); |
452 | return {std::next(odsOperands.begin(), valueRange.first), |
453 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
454 | } |
455 | |
456 | ValueT getLhs() { |
457 | return (*getODSOperands(0).begin()); |
458 | } |
459 | |
460 | ValueT getRhs() { |
461 | return (*getODSOperands(1).begin()); |
462 | } |
463 | |
464 | RangeT getOperands() { |
465 | return odsOperands; |
466 | } |
467 | |
468 | private: |
469 | RangeT odsOperands; |
470 | }; |
471 | class AddIOpAdaptor : public AddIOpGenericAdaptor<::mlir::ValueRange> { |
472 | public: |
473 | using AddIOpGenericAdaptor::AddIOpGenericAdaptor; |
474 | AddIOpAdaptor(AddIOp op); |
475 | |
476 | ::mlir::LogicalResult verify(::mlir::Location loc); |
477 | }; |
478 | class AddIOp : public ::mlir::Op<AddIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { |
479 | public: |
480 | using Op::Op; |
481 | using Op::print; |
482 | using Adaptor = AddIOpAdaptor; |
483 | template <typename RangeT> |
484 | using GenericAdaptor = AddIOpGenericAdaptor<RangeT>; |
485 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
486 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
487 | return {}; |
488 | } |
489 | |
490 | static constexpr ::llvm::StringLiteral getOperationName() { |
491 | return ::llvm::StringLiteral("arith.addi"); |
492 | } |
493 | |
494 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
495 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
496 | ::mlir::Value getLhs(); |
497 | ::mlir::Value getRhs(); |
498 | ::mlir::MutableOperandRange getLhsMutable(); |
499 | ::mlir::MutableOperandRange getRhsMutable(); |
500 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
501 | ::mlir::Operation::result_range getODSResults(unsigned index); |
502 | ::mlir::Value getResult(); |
503 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs); |
504 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs); |
505 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); |
506 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
507 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
508 | ::mlir::LogicalResult verifyInvariantsImpl(); |
509 | ::mlir::LogicalResult verifyInvariants(); |
510 | static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); |
511 | ::mlir::OpFoldResult fold(FoldAdaptor adaptor); |
512 | static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes); |
513 | void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges); |
514 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
515 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
516 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
517 | public: |
518 | }; |
519 | } // namespace arith |
520 | } // namespace mlir |
521 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddIOp)
522 | |
523 | namespace mlir { |
524 | namespace arith { |
525 | |
526 | //===----------------------------------------------------------------------===// |
527 | // ::mlir::arith::AddUIExtendedOp declarations |
528 | //===----------------------------------------------------------------------===// |
529 | |
530 | namespace detail { |
531 | class AddUIExtendedOpGenericAdaptorBase { |
532 | public: |
533 | protected: |
534 | ::mlir::DictionaryAttr odsAttrs; |
535 | ::std::optional<::mlir::OperationName> odsOpName; |
536 | ::mlir::RegionRange odsRegions; |
537 | public: |
538 | AddUIExtendedOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
539 | |
540 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
541 | ::mlir::DictionaryAttr getAttributes(); |
542 | }; |
543 | } // namespace detail |
544 | template <typename RangeT> |
545 | class AddUIExtendedOpGenericAdaptor : public detail::AddUIExtendedOpGenericAdaptorBase { |
546 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
547 | using Base = detail::AddUIExtendedOpGenericAdaptorBase; |
548 | public: |
549 | AddUIExtendedOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
550 | |
551 | AddUIExtendedOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AddUIExtendedOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
552 | |
553 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
554 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
555 | } |
556 | |
557 | RangeT getODSOperands(unsigned index) { |
558 | auto valueRange = getODSOperandIndexAndLength(index); |
559 | return {std::next(odsOperands.begin(), valueRange.first), |
560 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
561 | } |
562 | |
563 | ValueT getLhs() { |
564 | return (*getODSOperands(0).begin()); |
565 | } |
566 | |
567 | ValueT getRhs() { |
568 | return (*getODSOperands(1).begin()); |
569 | } |
570 | |
571 | RangeT getOperands() { |
572 | return odsOperands; |
573 | } |
574 | |
575 | private: |
576 | RangeT odsOperands; |
577 | }; |
578 | class AddUIExtendedOpAdaptor : public AddUIExtendedOpGenericAdaptor<::mlir::ValueRange> { |
579 | public: |
580 | using AddUIExtendedOpGenericAdaptor::AddUIExtendedOpGenericAdaptor; |
581 | AddUIExtendedOpAdaptor(AddUIExtendedOp op); |
582 | |
583 | ::mlir::LogicalResult verify(::mlir::Location loc); |
584 | }; |
585 | class AddUIExtendedOp : public ::mlir::Op<AddUIExtendedOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::NResults<2>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::IsCommutative, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::OpAsmOpInterface::Trait> { |
586 | public: |
587 | using Op::Op; |
588 | using Op::print; |
589 | using Adaptor = AddUIExtendedOpAdaptor; |
590 | template <typename RangeT> |
591 | using GenericAdaptor = AddUIExtendedOpGenericAdaptor<RangeT>; |
592 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
593 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
594 | return {}; |
595 | } |
596 | |
597 | void getAsmResultNames(::mlir::OpAsmSetValueNameFn setNameFn); |
598 | static constexpr ::llvm::StringLiteral getOperationName() { |
599 | return ::llvm::StringLiteral("arith.addui_extended"); |
600 | } |
601 | |
602 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
603 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
604 | ::mlir::Value getLhs(); |
605 | ::mlir::Value getRhs(); |
606 | ::mlir::MutableOperandRange getLhsMutable(); |
607 | ::mlir::MutableOperandRange getRhsMutable(); |
608 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
609 | ::mlir::Operation::result_range getODSResults(unsigned index); |
610 | ::mlir::Value getSum(); |
611 | ::mlir::Value getOverflow(); |
612 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, Value lhs, Value rhs); |
613 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type sum, ::mlir::Type overflow, ::mlir::Value lhs, ::mlir::Value rhs); |
614 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); |
615 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
616 | ::mlir::LogicalResult verifyInvariantsImpl(); |
617 | ::mlir::LogicalResult verifyInvariants(); |
618 | static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); |
619 | ::mlir::LogicalResult fold(FoldAdaptor adaptor, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results); |
620 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
621 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
622 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
623 | public: |
624 | std::optional<SmallVector<int64_t, 4>> getShapeForUnroll(); |
625 | }; |
626 | } // namespace arith |
627 | } // namespace mlir |
628 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AddUIExtendedOp)
629 | |
630 | namespace mlir { |
631 | namespace arith { |
632 | |
633 | //===----------------------------------------------------------------------===// |
634 | // ::mlir::arith::AndIOp declarations |
635 | //===----------------------------------------------------------------------===// |
636 | |
637 | namespace detail { |
638 | class AndIOpGenericAdaptorBase { |
639 | public: |
640 | protected: |
641 | ::mlir::DictionaryAttr odsAttrs; |
642 | ::std::optional<::mlir::OperationName> odsOpName; |
643 | ::mlir::RegionRange odsRegions; |
644 | public: |
645 | AndIOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
646 | |
647 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
648 | ::mlir::DictionaryAttr getAttributes(); |
649 | }; |
650 | } // namespace detail |
651 | template <typename RangeT> |
652 | class AndIOpGenericAdaptor : public detail::AndIOpGenericAdaptorBase { |
653 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
654 | using Base = detail::AndIOpGenericAdaptorBase; |
655 | public: |
656 | AndIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
657 | |
658 | AndIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : AndIOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
659 | |
660 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
661 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
662 | } |
663 | |
664 | RangeT getODSOperands(unsigned index) { |
665 | auto valueRange = getODSOperandIndexAndLength(index); |
666 | return {std::next(odsOperands.begin(), valueRange.first), |
667 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
668 | } |
669 | |
670 | ValueT getLhs() { |
671 | return (*getODSOperands(0).begin()); |
672 | } |
673 | |
674 | ValueT getRhs() { |
675 | return (*getODSOperands(1).begin()); |
676 | } |
677 | |
678 | RangeT getOperands() { |
679 | return odsOperands; |
680 | } |
681 | |
682 | private: |
683 | RangeT odsOperands; |
684 | }; |
685 | class AndIOpAdaptor : public AndIOpGenericAdaptor<::mlir::ValueRange> { |
686 | public: |
687 | using AndIOpGenericAdaptor::AndIOpGenericAdaptor; |
688 | AndIOpAdaptor(AndIOp op); |
689 | |
690 | ::mlir::LogicalResult verify(::mlir::Location loc); |
691 | }; |
692 | class AndIOp : public ::mlir::Op<AndIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsCommutative, ::mlir::OpTrait::IsIdempotent, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { |
693 | public: |
694 | using Op::Op; |
695 | using Op::print; |
696 | using Adaptor = AndIOpAdaptor; |
697 | template <typename RangeT> |
698 | using GenericAdaptor = AndIOpGenericAdaptor<RangeT>; |
699 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
700 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
701 | return {}; |
702 | } |
703 | |
704 | static constexpr ::llvm::StringLiteral getOperationName() { |
705 | return ::llvm::StringLiteral("arith.andi"); |
706 | } |
707 | |
708 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
709 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
710 | ::mlir::Value getLhs(); |
711 | ::mlir::Value getRhs(); |
712 | ::mlir::MutableOperandRange getLhsMutable(); |
713 | ::mlir::MutableOperandRange getRhsMutable(); |
714 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
715 | ::mlir::Operation::result_range getODSResults(unsigned index); |
716 | ::mlir::Value getResult(); |
717 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs); |
718 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs); |
719 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); |
720 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
721 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
722 | ::mlir::LogicalResult verifyInvariantsImpl(); |
723 | ::mlir::LogicalResult verifyInvariants(); |
724 | static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); |
725 | ::mlir::OpFoldResult fold(FoldAdaptor adaptor); |
726 | static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes); |
727 | void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges); |
728 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
729 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
730 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
731 | public: |
732 | }; |
733 | } // namespace arith |
734 | } // namespace mlir |
735 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::AndIOp)
736 | |
737 | namespace mlir { |
738 | namespace arith { |
739 | |
740 | //===----------------------------------------------------------------------===// |
741 | // ::mlir::arith::BitcastOp declarations |
742 | //===----------------------------------------------------------------------===// |
743 | |
744 | namespace detail { |
745 | class BitcastOpGenericAdaptorBase { |
746 | public: |
747 | protected: |
748 | ::mlir::DictionaryAttr odsAttrs; |
749 | ::std::optional<::mlir::OperationName> odsOpName; |
750 | ::mlir::RegionRange odsRegions; |
751 | public: |
752 | BitcastOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
753 | |
754 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
755 | ::mlir::DictionaryAttr getAttributes(); |
756 | }; |
757 | } // namespace detail |
758 | template <typename RangeT> |
759 | class BitcastOpGenericAdaptor : public detail::BitcastOpGenericAdaptorBase { |
760 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
761 | using Base = detail::BitcastOpGenericAdaptorBase; |
762 | public: |
763 | BitcastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
764 | |
765 | BitcastOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : BitcastOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
766 | |
767 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
768 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
769 | } |
770 | |
771 | RangeT getODSOperands(unsigned index) { |
772 | auto valueRange = getODSOperandIndexAndLength(index); |
773 | return {std::next(odsOperands.begin(), valueRange.first), |
774 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
775 | } |
776 | |
777 | ValueT getIn() { |
778 | return (*getODSOperands(0).begin()); |
779 | } |
780 | |
781 | RangeT getOperands() { |
782 | return odsOperands; |
783 | } |
784 | |
785 | private: |
786 | RangeT odsOperands; |
787 | }; |
788 | class BitcastOpAdaptor : public BitcastOpGenericAdaptor<::mlir::ValueRange> { |
789 | public: |
790 | using BitcastOpGenericAdaptor::BitcastOpGenericAdaptor; |
791 | BitcastOpAdaptor(BitcastOp op); |
792 | |
793 | ::mlir::LogicalResult verify(::mlir::Location loc); |
794 | }; |
795 | class BitcastOp : public ::mlir::Op<BitcastOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::OneOperand, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::OpTrait::AlwaysSpeculatableImplTrait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultShape, ::mlir::CastOpInterface::Trait, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable> { |
796 | public: |
797 | using Op::Op; |
798 | using Op::print; |
799 | using Adaptor = BitcastOpAdaptor; |
800 | template <typename RangeT> |
801 | using GenericAdaptor = BitcastOpGenericAdaptor<RangeT>; |
802 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
803 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
804 | return {}; |
805 | } |
806 | |
807 | static constexpr ::llvm::StringLiteral getOperationName() { |
808 | return ::llvm::StringLiteral("arith.bitcast"); |
809 | } |
810 | |
811 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
812 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
813 | ::mlir::Value getIn(); |
814 | ::mlir::MutableOperandRange getInMutable(); |
815 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
816 | ::mlir::Operation::result_range getODSResults(unsigned index); |
817 | ::mlir::Value getOut(); |
818 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type out, ::mlir::Value in); |
819 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value in); |
820 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
821 | ::mlir::LogicalResult verifyInvariantsImpl(); |
822 | ::mlir::LogicalResult verifyInvariants(); |
823 | static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context); |
824 | ::mlir::OpFoldResult fold(FoldAdaptor adaptor); |
825 | static bool areCastCompatible(::mlir::TypeRange inputs, ::mlir::TypeRange outputs); |
826 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
827 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
828 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
829 | public: |
830 | }; |
831 | } // namespace arith |
832 | } // namespace mlir |
833 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::BitcastOp)
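// Annotation (not part of the generated header): a minimal usage sketch for
// the builders declared above, assuming an in-scope ::mlir::OpBuilder `b`, a
// ::mlir::Location `loc`, and a ::mlir::Value `src` (all hypothetical names):
//
//   mlir::Value asI32 =
//       b.create<mlir::arith::BitcastOp>(loc, b.getI32Type(), src);
//
// This resolves to the build(OpBuilder &, OperationState &, Type, Value)
// overload above; the CastOpInterface trait then validates the operand/result
// pair through areCastCompatible(), which for arith.bitcast is expected to
// require types of matching bit width.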
834 | |
835 | namespace mlir { |
836 | namespace arith { |
837 | |
838 | //===----------------------------------------------------------------------===// |
839 | // ::mlir::arith::CeilDivSIOp declarations |
840 | //===----------------------------------------------------------------------===// |
841 | |
842 | namespace detail { |
843 | class CeilDivSIOpGenericAdaptorBase { |
844 | public: |
845 | protected: |
846 | ::mlir::DictionaryAttr odsAttrs; |
847 | ::std::optional<::mlir::OperationName> odsOpName; |
848 | ::mlir::RegionRange odsRegions; |
849 | public: |
850 | CeilDivSIOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
851 | |
852 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
853 | ::mlir::DictionaryAttr getAttributes(); |
854 | }; |
855 | } // namespace detail |
856 | template <typename RangeT> |
857 | class CeilDivSIOpGenericAdaptor : public detail::CeilDivSIOpGenericAdaptorBase { |
858 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
859 | using Base = detail::CeilDivSIOpGenericAdaptorBase; |
860 | public: |
861 | CeilDivSIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
862 | |
863 | CeilDivSIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CeilDivSIOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
864 | |
865 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
866 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
867 | } |
868 | |
869 | RangeT getODSOperands(unsigned index) { |
870 | auto valueRange = getODSOperandIndexAndLength(index); |
871 | return {std::next(odsOperands.begin(), valueRange.first), |
872 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
873 | } |
874 | |
875 | ValueT getLhs() { |
876 | return (*getODSOperands(0).begin()); |
877 | } |
878 | |
879 | ValueT getRhs() { |
880 | return (*getODSOperands(1).begin()); |
881 | } |
882 | |
883 | RangeT getOperands() { |
884 | return odsOperands; |
885 | } |
886 | |
887 | private: |
888 | RangeT odsOperands; |
889 | }; |
890 | class CeilDivSIOpAdaptor : public CeilDivSIOpGenericAdaptor<::mlir::ValueRange> { |
891 | public: |
892 | using CeilDivSIOpGenericAdaptor::CeilDivSIOpGenericAdaptor; |
893 | CeilDivSIOpAdaptor(CeilDivSIOp op); |
894 | |
895 | ::mlir::LogicalResult verify(::mlir::Location loc); |
896 | }; |
897 | class CeilDivSIOp : public ::mlir::Op<CeilDivSIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { |
898 | public: |
899 | using Op::Op; |
900 | using Op::print; |
901 | using Adaptor = CeilDivSIOpAdaptor; |
902 | template <typename RangeT> |
903 | using GenericAdaptor = CeilDivSIOpGenericAdaptor<RangeT>; |
904 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
905 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
906 | return {}; |
907 | } |
908 | |
909 | static constexpr ::llvm::StringLiteral getOperationName() { |
910 | return ::llvm::StringLiteral("arith.ceildivsi"); |
911 | } |
912 | |
913 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
914 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
915 | ::mlir::Value getLhs(); |
916 | ::mlir::Value getRhs(); |
917 | ::mlir::MutableOperandRange getLhsMutable(); |
918 | ::mlir::MutableOperandRange getRhsMutable(); |
919 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
920 | ::mlir::Operation::result_range getODSResults(unsigned index); |
921 | ::mlir::Value getResult(); |
922 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs); |
923 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs); |
924 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); |
925 | static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
926 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); |
927 | ::mlir::LogicalResult verifyInvariantsImpl(); |
928 | ::mlir::LogicalResult verifyInvariants(); |
929 | ::mlir::OpFoldResult fold(FoldAdaptor adaptor); |
930 | static ::mlir::LogicalResult inferReturnTypes(::mlir::MLIRContext *context, ::std::optional<::mlir::Location> location, ::mlir::ValueRange operands, ::mlir::DictionaryAttr attributes, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions, ::llvm::SmallVectorImpl<::mlir::Type>&inferredReturnTypes); |
931 | void inferResultRanges(::llvm::ArrayRef<::mlir::ConstantIntRanges> argRanges, ::mlir::SetIntRangeFn setResultRanges); |
932 | static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); |
933 | void print(::mlir::OpAsmPrinter &_odsPrinter); |
934 | void getEffects(::llvm::SmallVectorImpl<::mlir::SideEffects::EffectInstance<::mlir::MemoryEffects::Effect>> &effects); |
935 | public: |
936 | /// Interface method for ConditionallySpeculatable. |
937 | Speculation::Speculatability getSpeculatability(); |
938 | }; |
939 | } // namespace arith |
940 | } // namespace mlir |
941 | MLIR_DECLARE_EXPLICIT_TYPE_ID(::mlir::arith::CeilDivSIOp)
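// Annotation (not part of the generated header): because CeilDivSIOp carries
// SameOperandsAndResultType and InferTypeOpInterface, the result type can be
// omitted when building. A minimal sketch, assuming an ::mlir::OpBuilder `b`,
// a ::mlir::Location `loc`, and integer-typed Values `lhs` and `rhs`
// (hypothetical names):
//
//   mlir::Value q = b.create<mlir::arith::CeilDivSIOp>(loc, lhs, rhs);
//
// Unlike BitcastOp, this op lists ConditionallySpeculatable rather than
// AlwaysSpeculatable and declares getSpeculatability(), presumably because
// signed division can trap (e.g. on a zero divisor), so it may only be
// hoisted when the divisor is known to be safe.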
942 | |
943 | namespace mlir { |
944 | namespace arith { |
945 | |
946 | //===----------------------------------------------------------------------===// |
947 | // ::mlir::arith::CeilDivUIOp declarations |
948 | //===----------------------------------------------------------------------===// |
949 | |
950 | namespace detail { |
951 | class CeilDivUIOpGenericAdaptorBase { |
952 | public: |
953 | protected: |
954 | ::mlir::DictionaryAttr odsAttrs; |
955 | ::std::optional<::mlir::OperationName> odsOpName; |
956 | ::mlir::RegionRange odsRegions; |
957 | public: |
958 | CeilDivUIOpGenericAdaptorBase(::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}); |
959 | |
960 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index, unsigned odsOperandsSize); |
961 | ::mlir::DictionaryAttr getAttributes(); |
962 | }; |
963 | } // namespace detail |
964 | template <typename RangeT> |
965 | class CeilDivUIOpGenericAdaptor : public detail::CeilDivUIOpGenericAdaptorBase { |
966 | using ValueT = ::llvm::detail::ValueOfRange<RangeT>; |
967 | using Base = detail::CeilDivUIOpGenericAdaptorBase; |
968 | public: |
969 | CeilDivUIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs = nullptr, const ::mlir::EmptyProperties &properties = {}, ::mlir::RegionRange regions = {}) : Base(attrs, properties, regions), odsOperands(values) {} |
970 | |
971 | CeilDivUIOpGenericAdaptor(RangeT values, ::mlir::DictionaryAttr attrs, ::mlir::OpaqueProperties properties, ::mlir::RegionRange regions = {}) : CeilDivUIOpGenericAdaptor(values, attrs, (properties ? *properties.as<::mlir::EmptyProperties *>() : ::mlir::EmptyProperties{}), regions) {} |
972 | |
973 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index) { |
974 | return Base::getODSOperandIndexAndLength(index, odsOperands.size()); |
975 | } |
976 | |
977 | RangeT getODSOperands(unsigned index) { |
978 | auto valueRange = getODSOperandIndexAndLength(index); |
979 | return {std::next(odsOperands.begin(), valueRange.first), |
980 | std::next(odsOperands.begin(), valueRange.first + valueRange.second)}; |
981 | } |
982 | |
983 | ValueT getLhs() { |
984 | return (*getODSOperands(0).begin()); |
985 | } |
986 | |
987 | ValueT getRhs() { |
988 | return (*getODSOperands(1).begin()); |
989 | } |
990 | |
991 | RangeT getOperands() { |
992 | return odsOperands; |
993 | } |
994 | |
995 | private: |
996 | RangeT odsOperands; |
997 | }; |
998 | class CeilDivUIOpAdaptor : public CeilDivUIOpGenericAdaptor<::mlir::ValueRange> { |
999 | public: |
1000 | using CeilDivUIOpGenericAdaptor::CeilDivUIOpGenericAdaptor; |
1001 | CeilDivUIOpAdaptor(CeilDivUIOp op); |
1002 | |
1003 | ::mlir::LogicalResult verify(::mlir::Location loc); |
1004 | }; |
1005 | class CeilDivUIOp : public ::mlir::Op<CeilDivUIOp, ::mlir::OpTrait::ZeroRegions, ::mlir::OpTrait::OneResult, ::mlir::OpTrait::OneTypedResult<::mlir::Type>::Impl, ::mlir::OpTrait::ZeroSuccessors, ::mlir::OpTrait::NOperands<2>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::ConditionallySpeculatable::Trait, ::mlir::InferIntRangeInterface::Trait, ::mlir::OpTrait::SameOperandsAndResultType, ::mlir::VectorUnrollOpInterface::Trait, ::mlir::MemoryEffectOpInterface::Trait, ::mlir::OpTrait::Elementwise, ::mlir::OpTrait::Scalarizable, ::mlir::OpTrait::Vectorizable, ::mlir::OpTrait::Tensorizable, ::mlir::InferTypeOpInterface::Trait> { |
1006 | public: |
1007 | using Op::Op; |
1008 | using Op::print; |
1009 | using Adaptor = CeilDivUIOpAdaptor; |
1010 | template <typename RangeT> |
1011 | using GenericAdaptor = CeilDivUIOpGenericAdaptor<RangeT>; |
1012 | using FoldAdaptor = GenericAdaptor<::llvm::ArrayRef<::mlir::Attribute>>; |
1013 | static ::llvm::ArrayRef<::llvm::StringRef> getAttributeNames() { |
1014 | return {}; |
1015 | } |
1016 | |
1017 | static constexpr ::llvm::StringLiteral getOperationName() { |
1018 | return ::llvm::StringLiteral("arith.ceildivui"); |
1019 | } |
1020 | |
1021 | std::pair<unsigned, unsigned> getODSOperandIndexAndLength(unsigned index); |
1022 | ::mlir::Operation::operand_range getODSOperands(unsigned index); |
1023 | ::mlir::Value getLhs(); |
1024 | ::mlir::Value getRhs(); |
1025 | ::mlir::MutableOperandRange getLhsMutable(); |
1026 | ::mlir::MutableOperandRange getRhsMutable(); |
1027 | std::pair<unsigned, unsigned> getODSResultIndexAndLength(unsigned index); |
1028 | ::mlir::Operation::result_range getODSResults(unsigned index); |
1029 | ::mlir::Value getResult(); |
1030 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type result, ::mlir::Value lhs, ::mlir::Value rhs); |
1031 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Value lhs, ::mlir::Value rhs); |
1032 | static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value lhs, ::mlir::Value rhs); |
1033 | static void build(::mlir::OpBuilder &, ::mlir::OperationState |