// NOTE(review): unrelated static-analyzer report captured with this source:
//   File: build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/include/mlir/IR/OpDefinition.h
//   Warning: line 98, column 5 — Called C++ object pointer is null
1 | //===- SuperVectorize.cpp - Vectorize Pass Impl ---------------------------===// | ||||
2 | // | ||||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||||
4 | // See https://llvm.org/LICENSE.txt for license information. | ||||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||||
6 | // | ||||
7 | //===----------------------------------------------------------------------===// | ||||
8 | // | ||||
9 | // This file implements vectorization of loops, operations and data types to | ||||
10 | // a target-independent, n-D super-vector abstraction. | ||||
11 | // | ||||
12 | //===----------------------------------------------------------------------===// | ||||
13 | |||||
14 | #include "PassDetail.h" | ||||
15 | #include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h" | ||||
16 | #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" | ||||
17 | #include "mlir/Dialect/Affine/Analysis/NestedMatcher.h" | ||||
18 | #include "mlir/Dialect/Affine/IR/AffineOps.h" | ||||
19 | #include "mlir/Dialect/Affine/Utils.h" | ||||
20 | #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" | ||||
21 | #include "mlir/Dialect/Vector/IR/VectorOps.h" | ||||
22 | #include "mlir/Dialect/Vector/Utils/VectorUtils.h" | ||||
23 | #include "mlir/IR/BlockAndValueMapping.h" | ||||
24 | #include "mlir/Support/LLVM.h" | ||||
25 | #include "llvm/ADT/STLExtras.h" | ||||
26 | #include "llvm/Support/Debug.h" | ||||
27 | |||||
28 | using namespace mlir; | ||||
29 | using namespace vector; | ||||
30 | |||||
31 | /// | ||||
32 | /// Implements a high-level vectorization strategy on a Function. | ||||
33 | /// The abstraction used is that of super-vectors, which provide a single, | ||||
34 | /// compact, representation in the vector types, information that is expected | ||||
35 | /// to reduce the impact of the phase ordering problem | ||||
36 | /// | ||||
37 | /// Vector granularity: | ||||
38 | /// =================== | ||||
39 | /// This pass is designed to perform vectorization at a super-vector | ||||
40 | /// granularity. A super-vector is loosely defined as a vector type that is a | ||||
41 | /// multiple of a "good" vector size so the HW can efficiently implement a set | ||||
42 | /// of high-level primitives. Multiple is understood along any dimension; e.g. | ||||
43 | /// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a | ||||
44 | /// vector<8xf32> HW vector. Note that a "good vector size so the HW can | ||||
45 | /// efficiently implement a set of high-level primitives" is not necessarily an | ||||
46 | /// integer multiple of actual hardware registers. We leave details of this | ||||
47 | /// distinction unspecified for now. | ||||
48 | /// | ||||
49 | /// Some may prefer the terminology a "tile of HW vectors". In this case, one | ||||
50 | /// should note that super-vectors implement an "always full tile" abstraction. | ||||
51 | /// They guarantee no partial-tile separation is necessary by relying on a | ||||
52 | /// high-level copy-reshape abstraction that we call vector.transfer. This | ||||
/// copy-reshape operation is also responsible for performing layout
54 | /// transposition if necessary. In the general case this will require a scoped | ||||
55 | /// allocation in some notional local memory. | ||||
56 | /// | ||||
57 | /// Whatever the mental model one prefers to use for this abstraction, the key | ||||
58 | /// point is that we burn into a single, compact, representation in the vector | ||||
59 | /// types, information that is expected to reduce the impact of the phase | ||||
60 | /// ordering problem. Indeed, a vector type conveys information that: | ||||
61 | /// 1. the associated loops have dependency semantics that do not prevent | ||||
62 | /// vectorization; | ||||
///   2. the associated loops have been sliced in chunks of static sizes that are
64 | /// compatible with vector sizes (i.e. similar to unroll-and-jam); | ||||
65 | /// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by | ||||
66 | /// the | ||||
67 | /// vector type and no vectorization hampering transformations can be | ||||
68 | /// applied to them anymore; | ||||
69 | /// 4. the underlying memrefs are accessed in some notional contiguous way | ||||
70 | /// that allows loading into vectors with some amount of spatial locality; | ||||
71 | /// In other words, super-vectorization provides a level of separation of | ||||
72 | /// concern by way of opacity to subsequent passes. This has the effect of | ||||
73 | /// encapsulating and propagating vectorization constraints down the list of | ||||
74 | /// passes until we are ready to lower further. | ||||
75 | /// | ||||
76 | /// For a particular target, a notion of minimal n-d vector size will be | ||||
77 | /// specified and vectorization targets a multiple of those. In the following | ||||
78 | /// paragraph, let "k ." represent "a multiple of", to be understood as a | ||||
79 | /// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes | ||||
80 | /// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc). | ||||
81 | /// | ||||
82 | /// Some non-exhaustive notable super-vector sizes of interest include: | ||||
83 | /// - CPU: vector<k . HW_vector_size>, | ||||
84 | /// vector<k' . core_count x k . HW_vector_size>, | ||||
85 | /// vector<socket_count x k' . core_count x k . HW_vector_size>; | ||||
86 | /// - GPU: vector<k . warp_size>, | ||||
87 | /// vector<k . warp_size x float2>, | ||||
88 | /// vector<k . warp_size x float4>, | ||||
///          vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
90 | /// | ||||
91 | /// Loops and operations are emitted that operate on those super-vector shapes. | ||||
92 | /// Subsequent lowering passes will materialize to actual HW vector sizes. These | ||||
93 | /// passes are expected to be (gradually) more target-specific. | ||||
94 | /// | ||||
95 | /// At a high level, a vectorized load in a loop will resemble: | ||||
96 | /// ```mlir | ||||
97 | /// affine.for %i = ? to ? step ? { | ||||
98 | /// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32> | ||||
99 | /// } | ||||
100 | /// ``` | ||||
101 | /// It is the responsibility of the implementation of vector.transfer_read to | ||||
102 | /// materialize vector registers from the original scalar memrefs. A later (more | ||||
103 | /// target-dependent) lowering pass will materialize to actual HW vector sizes. | ||||
/// This lowering may occur at different times:
105 | /// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp + | ||||
106 | /// DmaWaitOp + vectorized operations for data transformations and shuffle; | ||||
107 | /// thus opening opportunities for unrolling and pipelining. This is an | ||||
108 | /// instance of library call "whiteboxing"; or | ||||
///   2. later in a target-specific lowering pass or hand-written library
110 | /// call; achieving full separation of concerns. This is an instance of | ||||
111 | /// library call; or | ||||
112 | /// 3. a mix of both, e.g. based on a model. | ||||
113 | /// In the future, these operations will expose a contract to constrain the | ||||
114 | /// search on vectorization patterns and sizes. | ||||
115 | /// | ||||
116 | /// Occurrence of super-vectorization in the compiler flow: | ||||
117 | /// ======================================================= | ||||
118 | /// This is an active area of investigation. We start with 2 remarks to position | ||||
119 | /// super-vectorization in the context of existing ongoing work: LLVM VPLAN | ||||
120 | /// and LLVM SLP Vectorizer. | ||||
121 | /// | ||||
122 | /// LLVM VPLAN: | ||||
123 | /// ----------- | ||||
124 | /// The astute reader may have noticed that in the limit, super-vectorization | ||||
125 | /// can be applied at a similar time and with similar objectives than VPLAN. | ||||
126 | /// For instance, in the case of a traditional, polyhedral compilation-flow (for | ||||
127 | /// instance, the PPCG project uses ISL to provide dependence analysis, | ||||
128 | /// multi-level(scheduling + tiling), lifting footprint to fast memory, | ||||
129 | /// communication synthesis, mapping, register optimizations) and before | ||||
130 | /// unrolling. When vectorization is applied at this *late* level in a typical | ||||
131 | /// polyhedral flow, and is instantiated with actual hardware vector sizes, | ||||
132 | /// super-vectorization is expected to match (or subsume) the type of patterns | ||||
133 | /// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR | ||||
134 | /// is higher level and our implementation should be significantly simpler. Also | ||||
135 | /// note that in this mode, recursive patterns are probably a bit of an overkill | ||||
136 | /// although it is reasonable to expect that mixing a bit of outer loop and | ||||
137 | /// inner loop vectorization + unrolling will provide interesting choices to | ||||
138 | /// MLIR. | ||||
139 | /// | ||||
140 | /// LLVM SLP Vectorizer: | ||||
141 | /// -------------------- | ||||
142 | /// Super-vectorization however is not meant to be usable in a similar fashion | ||||
143 | /// to the SLP vectorizer. The main difference lies in the information that | ||||
144 | /// both vectorizers use: super-vectorization examines contiguity of memory | ||||
145 | /// references along fastest varying dimensions and loops with recursive nested | ||||
146 | /// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on | ||||
147 | /// the other hand, performs flat pattern matching inside a single unrolled loop | ||||
148 | /// body and stitches together pieces of load and store operations into full | ||||
149 | /// 1-D vectors. We envision that the SLP vectorizer is a good way to capture | ||||
150 | /// innermost loop, control-flow dependent patterns that super-vectorization may | ||||
151 | /// not be able to capture easily. In other words, super-vectorization does not | ||||
152 | /// aim at replacing the SLP vectorizer and the two solutions are complementary. | ||||
153 | /// | ||||
154 | /// Ongoing investigations: | ||||
155 | /// ----------------------- | ||||
156 | /// We discuss the following *early* places where super-vectorization is | ||||
/// applicable and touch on the expected benefits and risks. We list the
158 | /// opportunities in the context of the traditional polyhedral compiler flow | ||||
159 | /// described in PPCG. There are essentially 6 places in the MLIR pass pipeline | ||||
160 | /// we expect to experiment with super-vectorization: | ||||
161 | /// 1. Right after language lowering to MLIR: this is the earliest time where | ||||
162 | /// super-vectorization is expected to be applied. At this level, all the | ||||
163 | /// language/user/library-level annotations are available and can be fully | ||||
164 | /// exploited. Examples include loop-type annotations (such as parallel, | ||||
165 | /// reduction, scan, dependence distance vector, vectorizable) as well as | ||||
166 | /// memory access annotations (such as non-aliasing writes guaranteed, | ||||
///    indirect accesses that are permutations by construction), or
168 | /// that a particular operation is prescribed atomic by the user. At this | ||||
169 | /// level, anything that enriches what dependence analysis can do should be | ||||
170 | /// aggressively exploited. At this level we are close to having explicit | ||||
171 | /// vector types in the language, except we do not impose that burden on the | ||||
172 | /// programmer/library: we derive information from scalar code + annotations. | ||||
173 | /// 2. After dependence analysis and before polyhedral scheduling: the | ||||
174 | /// information that supports vectorization does not need to be supplied by a | ||||
175 | /// higher level of abstraction. Traditional dependence analysis is available | ||||
176 | /// in MLIR and will be used to drive vectorization and cost models. | ||||
177 | /// | ||||
178 | /// Let's pause here and remark that applying super-vectorization as described | ||||
179 | /// in 1. and 2. presents clear opportunities and risks: | ||||
180 | /// - the opportunity is that vectorization is burned in the type system and | ||||
181 | /// is protected from the adverse effect of loop scheduling, tiling, loop | ||||
182 | /// interchange and all passes downstream. Provided that subsequent passes are | ||||
183 | /// able to operate on vector types; the vector shapes, associated loop | ||||
184 | /// iterator properties, alignment, and contiguity of fastest varying | ||||
185 | /// dimensions are preserved until we lower the super-vector types. We expect | ||||
186 | /// this to significantly rein in on the adverse effects of phase ordering. | ||||
187 | /// - the risks are that a. all passes after super-vectorization have to work | ||||
188 | /// on elemental vector types (not that this is always true, wherever | ||||
189 | /// vectorization is applied) and b. that imposing vectorization constraints | ||||
190 | /// too early may be overall detrimental to loop fusion, tiling and other | ||||
191 | /// transformations because the dependence distances are coarsened when | ||||
192 | /// operating on elemental vector types. For this reason, the pattern | ||||
193 | /// profitability analysis should include a component that also captures the | ||||
194 | /// maximal amount of fusion available under a particular pattern. This is | ||||
195 | /// still at the stage of rough ideas but in this context, search is our | ||||
196 | /// friend as the Tensor Comprehensions and auto-TVM contributions | ||||
197 | /// demonstrated previously. | ||||
198 | /// Bottom-line is we do not yet have good answers for the above but aim at | ||||
199 | /// making it easy to answer such questions. | ||||
200 | /// | ||||
201 | /// Back to our listing, the last places where early super-vectorization makes | ||||
202 | /// sense are: | ||||
203 | /// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known | ||||
204 | /// to improve locality, parallelism and be configurable (e.g. max-fuse, | ||||
205 | /// smart-fuse etc). They can also have adverse effects on contiguity | ||||
206 | /// properties that are required for vectorization but the vector.transfer | ||||
207 | /// copy-reshape-pad-transpose abstraction is expected to help recapture | ||||
208 | /// these properties. | ||||
209 | /// 4. right after polyhedral-style scheduling+tiling; | ||||
210 | /// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent | ||||
211 | /// probably the most promising places because applying tiling achieves a | ||||
212 | /// separation of concerns that allows rescheduling to worry less about | ||||
213 | /// locality and more about parallelism and distribution (e.g. min-fuse). | ||||
214 | /// | ||||
215 | /// At these levels the risk-reward looks different: on one hand we probably | ||||
216 | /// lost a good deal of language/user/library-level annotation; on the other | ||||
217 | /// hand we gained parallelism and locality through scheduling and tiling. | ||||
218 | /// However we probably want to ensure tiling is compatible with the | ||||
219 | /// full-tile-only abstraction used in super-vectorization or suffer the | ||||
220 | /// consequences. It is too early to place bets on what will win but we expect | ||||
221 | /// super-vectorization to be the right abstraction to allow exploring at all | ||||
222 | /// these levels. And again, search is our friend. | ||||
223 | /// | ||||
224 | /// Lastly, we mention it again here: | ||||
225 | /// 6. as a MLIR-based alternative to VPLAN. | ||||
226 | /// | ||||
227 | /// Lowering, unrolling, pipelining: | ||||
228 | /// ================================ | ||||
229 | /// TODO: point to the proper places. | ||||
230 | /// | ||||
231 | /// Algorithm: | ||||
232 | /// ========== | ||||
233 | /// The algorithm proceeds in a few steps: | ||||
234 | /// 1. defining super-vectorization patterns and matching them on the tree of | ||||
235 | /// AffineForOp. A super-vectorization pattern is defined as a recursive | ||||
236 | /// data structures that matches and captures nested, imperfectly-nested | ||||
237 | /// loops that have a. conformable loop annotations attached (e.g. parallel, | ||||
238 | /// reduction, vectorizable, ...) as well as b. all contiguous load/store | ||||
239 | /// operations along a specified minor dimension (not necessarily the | ||||
240 | /// fastest varying) ; | ||||
241 | /// 2. analyzing those patterns for profitability (TODO: and | ||||
242 | /// interference); | ||||
243 | /// 3. then, for each pattern in order: | ||||
244 | /// a. applying iterative rewriting of the loops and all their nested | ||||
245 | /// operations in topological order. Rewriting is implemented by | ||||
246 | /// coarsening the loops and converting operations and operands to their | ||||
247 | /// vector forms. Processing operations in topological order is relatively | ||||
248 | /// simple due to the structured nature of the control-flow | ||||
249 | /// representation. This order ensures that all the operands of a given | ||||
250 | /// operation have been vectorized before the operation itself in a single | ||||
251 | /// traversal, except for operands defined outside of the loop nest. The | ||||
252 | /// algorithm can convert the following operations to their vector form: | ||||
253 | /// * Affine load and store operations are converted to opaque vector | ||||
254 | /// transfer read and write operations. | ||||
255 | /// * Scalar constant operations/operands are converted to vector | ||||
256 | /// constant operations (splat). | ||||
257 | /// * Uniform operands (only induction variables of loops not mapped to | ||||
258 | /// a vector dimension, or operands defined outside of the loop nest | ||||
259 | /// for now) are broadcasted to a vector. | ||||
260 | /// TODO: Support more uniform cases. | ||||
261 | /// * Affine for operations with 'iter_args' are vectorized by | ||||
262 | /// vectorizing their 'iter_args' operands and results. | ||||
263 | /// TODO: Support more complex loops with divergent lbs and/or ubs. | ||||
264 | /// * The remaining operations in the loop nest are vectorized by | ||||
265 | /// widening their scalar types to vector types. | ||||
266 | /// b. if everything under the root AffineForOp in the current pattern | ||||
267 | /// is vectorized properly, we commit that loop to the IR and remove the | ||||
268 | /// scalar loop. Otherwise, we discard the vectorized loop and keep the | ||||
269 | /// original scalar loop. | ||||
270 | /// c. vectorization is applied on the next pattern in the list. Because | ||||
271 | /// pattern interference avoidance is not yet implemented and that we do | ||||
272 | /// not support further vectorizing an already vector load we need to | ||||
273 | /// re-verify that the pattern is still vectorizable. This is expected to | ||||
274 | /// make cost models more difficult to write and is subject to improvement | ||||
275 | /// in the future. | ||||
276 | /// | ||||
277 | /// Choice of loop transformation to support the algorithm: | ||||
278 | /// ======================================================= | ||||
279 | /// The choice of loop transformation to apply for coarsening vectorized loops | ||||
280 | /// is still subject to exploratory tradeoffs. In particular, say we want to | ||||
281 | /// vectorize by a factor 128, we want to transform the following input: | ||||
282 | /// ```mlir | ||||
283 | /// affine.for %i = %M to %N { | ||||
284 | /// %a = affine.load %A[%i] : memref<?xf32> | ||||
285 | /// } | ||||
286 | /// ``` | ||||
287 | /// | ||||
288 | /// Traditionally, one would vectorize late (after scheduling, tiling, | ||||
289 | /// memory promotion etc) say after stripmining (and potentially unrolling in | ||||
290 | /// the case of LLVM's SLP vectorizer): | ||||
291 | /// ```mlir | ||||
292 | /// affine.for %i = floor(%M, 128) to ceil(%N, 128) { | ||||
293 | /// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) { | ||||
294 | /// %a = affine.load %A[%ii] : memref<?xf32> | ||||
295 | /// } | ||||
296 | /// } | ||||
297 | /// ``` | ||||
298 | /// | ||||
299 | /// Instead, we seek to vectorize early and freeze vector types before | ||||
300 | /// scheduling, so we want to generate a pattern that resembles: | ||||
301 | /// ```mlir | ||||
302 | /// affine.for %i = ? to ? step ? { | ||||
303 | /// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> | ||||
304 | /// } | ||||
305 | /// ``` | ||||
306 | /// | ||||
307 | /// i. simply dividing the lower / upper bounds by 128 creates issues | ||||
308 | /// when representing expressions such as ii + 1 because now we only | ||||
309 | /// have access to original values that have been divided. Additional | ||||
310 | /// information is needed to specify accesses at below-128 granularity; | ||||
311 | /// ii. another alternative is to coarsen the loop step but this may have | ||||
312 | /// consequences on dependence analysis and fusability of loops: fusable | ||||
313 | /// loops probably need to have the same step (because we don't want to | ||||
314 | /// stripmine/unroll to enable fusion). | ||||
315 | /// As a consequence, we choose to represent the coarsening using the loop | ||||
316 | /// step for now and reevaluate in the future. Note that we can renormalize | ||||
317 | /// loop steps later if/when we have evidence that they are problematic. | ||||
318 | /// | ||||
319 | /// For the simple strawman example above, vectorizing for a 1-D vector | ||||
320 | /// abstraction of size 128 returns code similar to: | ||||
321 | /// ```mlir | ||||
322 | /// affine.for %i = %M to %N step 128 { | ||||
323 | /// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32> | ||||
324 | /// } | ||||
325 | /// ``` | ||||
326 | /// | ||||
327 | /// Unsupported cases, extensions, and work in progress (help welcome :-) ): | ||||
328 | /// ======================================================================== | ||||
329 | /// 1. lowering to concrete vector types for various HW; | ||||
330 | /// 2. reduction support for n-D vectorization and non-unit steps; | ||||
331 | /// 3. non-effecting padding during vector.transfer_read and filter during | ||||
332 | /// vector.transfer_write; | ||||
333 | /// 4. misalignment support vector.transfer_read / vector.transfer_write | ||||
334 | /// (hopefully without read-modify-writes); | ||||
335 | /// 5. control-flow support; | ||||
336 | /// 6. cost-models, heuristics and search; | ||||
337 | /// 7. Op implementation, extensions and implication on memref views; | ||||
338 | /// 8. many TODOs left around. | ||||
339 | /// | ||||
340 | /// Examples: | ||||
341 | /// ========= | ||||
342 | /// Consider the following Function: | ||||
343 | /// ```mlir | ||||
344 | /// func @vector_add_2d(%M : index, %N : index) -> f32 { | ||||
345 | /// %A = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
346 | /// %B = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
347 | /// %C = alloc (%M, %N) : memref<?x?xf32, 0> | ||||
348 | /// %f1 = arith.constant 1.0 : f32 | ||||
349 | /// %f2 = arith.constant 2.0 : f32 | ||||
350 | /// affine.for %i0 = 0 to %M { | ||||
351 | /// affine.for %i1 = 0 to %N { | ||||
352 | /// // non-scoped %f1 | ||||
353 | /// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0> | ||||
354 | /// } | ||||
355 | /// } | ||||
356 | /// affine.for %i2 = 0 to %M { | ||||
357 | /// affine.for %i3 = 0 to %N { | ||||
358 | /// // non-scoped %f2 | ||||
359 | /// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0> | ||||
360 | /// } | ||||
361 | /// } | ||||
362 | /// affine.for %i4 = 0 to %M { | ||||
363 | /// affine.for %i5 = 0 to %N { | ||||
364 | /// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0> | ||||
365 | /// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0> | ||||
366 | /// %s5 = arith.addf %a5, %b5 : f32 | ||||
367 | /// // non-scoped %f1 | ||||
368 | /// %s6 = arith.addf %s5, %f1 : f32 | ||||
369 | /// // non-scoped %f2 | ||||
370 | /// %s7 = arith.addf %s5, %f2 : f32 | ||||
371 | /// // diamond dependency. | ||||
372 | /// %s8 = arith.addf %s7, %s6 : f32 | ||||
373 | /// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0> | ||||
374 | /// } | ||||
375 | /// } | ||||
376 | /// %c7 = arith.constant 7 : index | ||||
377 | /// %c42 = arith.constant 42 : index | ||||
378 | /// %res = load %C[%c7, %c42] : memref<?x?xf32, 0> | ||||
379 | /// return %res : f32 | ||||
380 | /// } | ||||
381 | /// ``` | ||||
382 | /// | ||||
383 | /// The -affine-super-vectorize pass with the following arguments: | ||||
384 | /// ``` | ||||
385 | /// -affine-super-vectorize="virtual-vector-size=256 test-fastest-varying=0" | ||||
386 | /// ``` | ||||
387 | /// | ||||
388 | /// produces this standard innermost-loop vectorized code: | ||||
389 | /// ```mlir | ||||
390 | /// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { | ||||
391 | /// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
392 | /// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
393 | /// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
394 | /// %cst = arith.constant 1.0 : f32 | ||||
395 | /// %cst_0 = arith.constant 2.0 : f32 | ||||
396 | /// affine.for %i0 = 0 to %arg0 { | ||||
397 | /// affine.for %i1 = 0 to %arg1 step 256 { | ||||
398 | /// %cst_1 = arith.constant dense<vector<256xf32>, 1.0> : | ||||
399 | /// vector<256xf32> | ||||
400 | /// vector.transfer_write %cst_1, %0[%i0, %i1] : | ||||
401 | /// vector<256xf32>, memref<?x?xf32> | ||||
402 | /// } | ||||
403 | /// } | ||||
404 | /// affine.for %i2 = 0 to %arg0 { | ||||
405 | /// affine.for %i3 = 0 to %arg1 step 256 { | ||||
406 | /// %cst_2 = arith.constant dense<vector<256xf32>, 2.0> : | ||||
407 | /// vector<256xf32> | ||||
408 | /// vector.transfer_write %cst_2, %1[%i2, %i3] : | ||||
409 | /// vector<256xf32>, memref<?x?xf32> | ||||
410 | /// } | ||||
411 | /// } | ||||
412 | /// affine.for %i4 = 0 to %arg0 { | ||||
413 | /// affine.for %i5 = 0 to %arg1 step 256 { | ||||
414 | /// %3 = vector.transfer_read %0[%i4, %i5] : | ||||
415 | /// memref<?x?xf32>, vector<256xf32> | ||||
416 | /// %4 = vector.transfer_read %1[%i4, %i5] : | ||||
417 | /// memref<?x?xf32>, vector<256xf32> | ||||
418 | /// %5 = arith.addf %3, %4 : vector<256xf32> | ||||
419 | /// %cst_3 = arith.constant dense<vector<256xf32>, 1.0> : | ||||
420 | /// vector<256xf32> | ||||
421 | /// %6 = arith.addf %5, %cst_3 : vector<256xf32> | ||||
422 | /// %cst_4 = arith.constant dense<vector<256xf32>, 2.0> : | ||||
423 | /// vector<256xf32> | ||||
424 | /// %7 = arith.addf %5, %cst_4 : vector<256xf32> | ||||
425 | /// %8 = arith.addf %7, %6 : vector<256xf32> | ||||
426 | /// vector.transfer_write %8, %2[%i4, %i5] : | ||||
427 | /// vector<256xf32>, memref<?x?xf32> | ||||
428 | /// } | ||||
429 | /// } | ||||
430 | /// %c7 = arith.constant 7 : index | ||||
431 | /// %c42 = arith.constant 42 : index | ||||
432 | /// %9 = load %2[%c7, %c42] : memref<?x?xf32> | ||||
433 | /// return %9 : f32 | ||||
434 | /// } | ||||
435 | /// ``` | ||||
436 | /// | ||||
437 | /// The -affine-super-vectorize pass with the following arguments: | ||||
438 | /// ``` | ||||
439 | /// -affine-super-vectorize="virtual-vector-size=32,256 \ | ||||
440 | /// test-fastest-varying=1,0" | ||||
441 | /// ``` | ||||
442 | /// | ||||
443 | /// produces this more interesting mixed outer-innermost-loop vectorized code: | ||||
444 | /// ```mlir | ||||
445 | /// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 { | ||||
446 | /// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
447 | /// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
448 | /// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32> | ||||
449 | /// %cst = arith.constant 1.0 : f32 | ||||
450 | /// %cst_0 = arith.constant 2.0 : f32 | ||||
451 | /// affine.for %i0 = 0 to %arg0 step 32 { | ||||
452 | /// affine.for %i1 = 0 to %arg1 step 256 { | ||||
453 | /// %cst_1 = arith.constant dense<vector<32x256xf32>, 1.0> : | ||||
454 | /// vector<32x256xf32> | ||||
455 | /// vector.transfer_write %cst_1, %0[%i0, %i1] : | ||||
456 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
457 | /// } | ||||
458 | /// } | ||||
459 | /// affine.for %i2 = 0 to %arg0 step 32 { | ||||
460 | /// affine.for %i3 = 0 to %arg1 step 256 { | ||||
461 | /// %cst_2 = arith.constant dense<vector<32x256xf32>, 2.0> : | ||||
462 | /// vector<32x256xf32> | ||||
463 | /// vector.transfer_write %cst_2, %1[%i2, %i3] : | ||||
464 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
465 | /// } | ||||
466 | /// } | ||||
467 | /// affine.for %i4 = 0 to %arg0 step 32 { | ||||
468 | /// affine.for %i5 = 0 to %arg1 step 256 { | ||||
469 | /// %3 = vector.transfer_read %0[%i4, %i5] : | ||||
///            memref<?x?xf32>, vector<32x256xf32>
471 | /// %4 = vector.transfer_read %1[%i4, %i5] : | ||||
472 | /// memref<?x?xf32>, vector<32x256xf32> | ||||
473 | /// %5 = arith.addf %3, %4 : vector<32x256xf32> | ||||
474 | /// %cst_3 = arith.constant dense<vector<32x256xf32>, 1.0> : | ||||
475 | /// vector<32x256xf32> | ||||
476 | /// %6 = arith.addf %5, %cst_3 : vector<32x256xf32> | ||||
477 | /// %cst_4 = arith.constant dense<vector<32x256xf32>, 2.0> : | ||||
478 | /// vector<32x256xf32> | ||||
479 | /// %7 = arith.addf %5, %cst_4 : vector<32x256xf32> | ||||
480 | /// %8 = arith.addf %7, %6 : vector<32x256xf32> | ||||
481 | /// vector.transfer_write %8, %2[%i4, %i5] : | ||||
482 | /// vector<32x256xf32>, memref<?x?xf32> | ||||
483 | /// } | ||||
484 | /// } | ||||
485 | /// %c7 = arith.constant 7 : index | ||||
486 | /// %c42 = arith.constant 42 : index | ||||
487 | /// %9 = load %2[%c7, %c42] : memref<?x?xf32> | ||||
488 | /// return %9 : f32 | ||||
489 | /// } | ||||
490 | /// ``` | ||||
491 | /// | ||||
492 | /// Of course, much more intricate n-D imperfectly-nested patterns can be | ||||
493 | /// vectorized too and specified in a fully declarative fashion. | ||||
494 | /// | ||||
495 | /// Reduction: | ||||
496 | /// ========== | ||||
497 | /// Vectorizing reduction loops along the reduction dimension is supported if: | ||||
498 | /// - the reduction kind is supported, | ||||
499 | /// - the vectorization is 1-D, and | ||||
500 | /// - the step size of the loop equals to one. | ||||
501 | /// | ||||
502 | /// Comparing to the non-vector-dimension case, two additional things are done | ||||
503 | /// during vectorization of such loops: | ||||
504 | /// - The resulting vector returned from the loop is reduced to a scalar using | ||||
505 | /// `vector.reduce`. | ||||
506 | /// - In some cases a mask is applied to the vector yielded at the end of the | ||||
507 | /// loop to prevent garbage values from being written to the accumulator. | ||||
508 | /// | ||||
509 | /// Reduction vectorization is switched off by default, it can be enabled by | ||||
510 | /// passing a map from loops to reductions to utility functions, or by passing | ||||
511 | /// `vectorize-reductions=true` to the vectorization pass. | ||||
512 | /// | ||||
513 | /// Consider the following example: | ||||
514 | /// ```mlir | ||||
515 | /// func @vecred(%in: memref<512xf32>) -> f32 { | ||||
516 | /// %cst = arith.constant 0.000000e+00 : f32 | ||||
517 | /// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) { | ||||
518 | /// %ld = affine.load %in[%i] : memref<512xf32> | ||||
519 | /// %cos = math.cos %ld : f32 | ||||
520 | /// %add = arith.addf %part_sum, %cos : f32 | ||||
521 | /// affine.yield %add : f32 | ||||
522 | /// } | ||||
523 | /// return %sum : f32 | ||||
524 | /// } | ||||
525 | /// ``` | ||||
526 | /// | ||||
527 | /// The -affine-super-vectorize pass with the following arguments: | ||||
528 | /// ``` | ||||
529 | /// -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 \ | ||||
530 | /// vectorize-reductions=true" | ||||
531 | /// ``` | ||||
532 | /// produces the following output: | ||||
533 | /// ```mlir | ||||
534 | /// #map = affine_map<(d0) -> (-d0 + 500)> | ||||
535 | /// func @vecred(%arg0: memref<512xf32>) -> f32 { | ||||
536 | /// %cst = arith.constant 0.000000e+00 : f32 | ||||
537 | /// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32> | ||||
538 | /// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0) | ||||
539 | /// -> (vector<128xf32>) { | ||||
540 | /// // %2 is the number of iterations left in the original loop. | ||||
541 | /// %2 = affine.apply #map(%arg1) | ||||
542 | /// %3 = vector.create_mask %2 : vector<128xi1> | ||||
543 | /// %cst_1 = arith.constant 0.000000e+00 : f32 | ||||
544 | /// %4 = vector.transfer_read %arg0[%arg1], %cst_1 : | ||||
545 | /// memref<512xf32>, vector<128xf32> | ||||
546 | /// %5 = math.cos %4 : vector<128xf32> | ||||
547 | /// %6 = arith.addf %arg2, %5 : vector<128xf32> | ||||
548 | /// // We filter out the effect of last 12 elements using the mask. | ||||
549 | /// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32> | ||||
550 | /// affine.yield %7 : vector<128xf32> | ||||
551 | /// } | ||||
552 | /// %1 = vector.reduction <add>, %0 : vector<128xf32> into f32 | ||||
553 | /// return %1 : f32 | ||||
554 | /// } | ||||
555 | /// ``` | ||||
556 | /// | ||||
557 | /// Note that because of loop misalignment we needed to apply a mask to prevent | ||||
558 | /// last 12 elements from affecting the final result. The mask is full of ones | ||||
559 | /// in every iteration except for the last one, in which it has the form | ||||
560 | /// `11...100...0` with 116 ones and 12 zeros. | ||||
561 | |||||
562 | #define DEBUG_TYPE"early-vect" "early-vect" | ||||
563 | |||||
564 | using llvm::dbgs; | ||||
565 | |||||
566 | /// Forward declaration. | ||||
567 | static FilterFunctionType | ||||
568 | isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, | ||||
569 | int fastestVaryingMemRefDimension); | ||||
570 | |||||
571 | /// Creates a vectorization pattern from the command line arguments. | ||||
572 | /// Up to 3-D patterns are supported. | ||||
573 | /// If the command line argument requests a pattern of higher order, returns an | ||||
574 | /// empty pattern list which will conservatively result in no vectorization. | ||||
575 | static Optional<NestedPattern> | ||||
576 | makePattern(const DenseSet<Operation *> ¶llelLoops, int vectorRank, | ||||
577 | ArrayRef<int64_t> fastestVaryingPattern) { | ||||
578 | using matcher::For; | ||||
579 | int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0]; | ||||
580 | int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1]; | ||||
581 | int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2]; | ||||
582 | switch (vectorRank) { | ||||
583 | case 1: | ||||
584 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0)); | ||||
585 | case 2: | ||||
586 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0), | ||||
587 | For(isVectorizableLoopPtrFactory(parallelLoops, d1))); | ||||
588 | case 3: | ||||
589 | return For(isVectorizableLoopPtrFactory(parallelLoops, d0), | ||||
590 | For(isVectorizableLoopPtrFactory(parallelLoops, d1), | ||||
591 | For(isVectorizableLoopPtrFactory(parallelLoops, d2)))); | ||||
592 | default: { | ||||
593 | return llvm::None; | ||||
594 | } | ||||
595 | } | ||||
596 | } | ||||
597 | |||||
598 | static NestedPattern &vectorTransferPattern() { | ||||
599 | static auto pattern = matcher::Op([](Operation &op) { | ||||
600 | return isa<vector::TransferReadOp, vector::TransferWriteOp>(op); | ||||
601 | }); | ||||
602 | return pattern; | ||||
603 | } | ||||
604 | |||||
605 | namespace { | ||||
606 | |||||
/// Base state for the vectorize pass.
/// Command line arguments are preempted by non-empty pass arguments.
struct Vectorize : public AffineVectorizeBase<Vectorize> {
  // Default construction: vector sizes come from the pass options / CLI.
  Vectorize() = default;
  // Explicit construction: `virtualVectorSize` overrides the CLI vector sizes.
  Vectorize(ArrayRef<int64_t> virtualVectorSize);
  void runOnOperation() override;
};
614 | |||||
615 | } // namespace | ||||
616 | |||||
617 | Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) { | ||||
618 | vectorSizes = virtualVectorSize; | ||||
619 | } | ||||
620 | |||||
621 | static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern, | ||||
622 | unsigned patternDepth, | ||||
623 | VectorizationStrategy *strategy) { | ||||
624 | assert(patternDepth > depthInPattern &&(static_cast <bool> (patternDepth > depthInPattern && "patternDepth is greater than depthInPattern") ? void (0) : __assert_fail ("patternDepth > depthInPattern && \"patternDepth is greater than depthInPattern\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 625 , __extension__ __PRETTY_FUNCTION__)) | ||||
625 | "patternDepth is greater than depthInPattern")(static_cast <bool> (patternDepth > depthInPattern && "patternDepth is greater than depthInPattern") ? void (0) : __assert_fail ("patternDepth > depthInPattern && \"patternDepth is greater than depthInPattern\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 625 , __extension__ __PRETTY_FUNCTION__)); | ||||
626 | if (patternDepth - depthInPattern > strategy->vectorSizes.size()) { | ||||
627 | // Don't vectorize this loop | ||||
628 | return; | ||||
629 | } | ||||
630 | strategy->loopToVectorDim[loop] = | ||||
631 | strategy->vectorSizes.size() - (patternDepth - depthInPattern); | ||||
632 | } | ||||
633 | |||||
634 | /// Implements a simple strawman strategy for vectorization. | ||||
635 | /// Given a matched pattern `matches` of depth `patternDepth`, this strategy | ||||
636 | /// greedily assigns the fastest varying dimension ** of the vector ** to the | ||||
637 | /// innermost loop in the pattern. | ||||
638 | /// When coupled with a pattern that looks for the fastest varying dimension in | ||||
639 | /// load/store MemRefs, this creates a generic vectorization strategy that works | ||||
640 | /// for any loop in a hierarchy (outermost, innermost or intermediate). | ||||
641 | /// | ||||
642 | /// TODO: In the future we should additionally increase the power of the | ||||
643 | /// profitability analysis along 3 directions: | ||||
644 | /// 1. account for loop extents (both static and parametric + annotations); | ||||
645 | /// 2. account for data layout permutations; | ||||
646 | /// 3. account for impact of vectorization on maximal loop fusion. | ||||
647 | /// Then we can quantify the above to build a cost model and search over | ||||
648 | /// strategies. | ||||
649 | static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches, | ||||
650 | unsigned depthInPattern, | ||||
651 | unsigned patternDepth, | ||||
652 | VectorizationStrategy *strategy) { | ||||
653 | for (auto m : matches) { | ||||
654 | if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1, | ||||
655 | patternDepth, strategy))) { | ||||
656 | return failure(); | ||||
657 | } | ||||
658 | vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern, | ||||
659 | patternDepth, strategy); | ||||
660 | } | ||||
661 | return success(); | ||||
662 | } | ||||
663 | |||||
664 | ///// end TODO: Hoist to a VectorizationStrategy.cpp when appropriate ///// | ||||
665 | |||||
666 | namespace { | ||||
667 | |||||
/// Carries all the bookkeeping required while vectorizing one matched loop
/// nest: the builder used to create new ops, the scalar->vector and
/// scalar->scalar value mappings, per-loop vector dimensions and masks, and
/// the strategy being applied.
struct VectorizationState {

  // The builder starts with no insertion point; it is positioned by the
  // vectorization driver as new ops are created.
  VectorizationState(MLIRContext *context) : builder(context) {}

  /// Registers the vector replacement of a scalar operation and its result
  /// values. Both operations must have the same number of results.
  ///
  /// This utility is used to register the replacement for the vast majority of
  /// the vectorized operations.
  ///
  /// Example:
  ///   * 'replaced': %0 = arith.addf %1, %2 : f32
  ///   * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
  void registerOpVectorReplacement(Operation *replaced, Operation *replacement);

  /// Registers the vector replacement of a scalar value. The replacement
  /// operation should have a single result, which replaces the scalar value.
  ///
  /// This utility is used to register the vector replacement of block arguments
  /// and operation results which are not directly vectorized (i.e., their
  /// scalar version still exists after vectorization), like uniforms.
  ///
  /// Example:
  ///   * 'replaced': block argument or operation outside of the vectorized
  ///     loop.
  ///   * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
  void registerValueVectorReplacement(Value replaced, Operation *replacement);

  /// Registers the vector replacement of a block argument (e.g., iter_args).
  ///
  /// Example:
  ///   * 'replaced': 'iter_arg' block argument.
  ///   * 'replacement': vectorized 'iter_arg' block argument.
  void registerBlockArgVectorReplacement(BlockArgument replaced,
                                         BlockArgument replacement);

  /// Registers the scalar replacement of a scalar value. 'replacement' must be
  /// scalar. Both values must be block arguments. Operation results should be
  /// replaced using the 'registerOp*' utilitites.
  ///
  /// This utility is used to register the replacement of block arguments
  /// that are within the loop to be vectorized and will continue being scalar
  /// within the vector loop.
  ///
  /// Example:
  ///   * 'replaced': induction variable of a loop to be vectorized.
  ///   * 'replacement': new induction variable in the new vector loop.
  void registerValueScalarReplacement(BlockArgument replaced,
                                      BlockArgument replacement);

  /// Registers the scalar replacement of a scalar result returned from a
  /// reduction loop. 'replacement' must be scalar.
  ///
  /// This utility is used to register the replacement for scalar results of
  /// vectorized reduction loops with iter_args.
  ///
  /// Example 2:
  ///   * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
  ///   * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into
  ///     f32
  void registerLoopResultScalarReplacement(Value replaced, Value replacement);

  /// Returns in 'replacedVals' the scalar replacement for values in
  /// 'inputVals'.
  void getScalarValueReplacementsFor(ValueRange inputVals,
                                     SmallVectorImpl<Value> &replacedVals);

  /// Erases the scalar loop nest after its successful vectorization.
  void finishVectorizationPattern(AffineForOp rootLoop);

  // Used to build and insert all the new operations created. The insertion
  // point is preserved and updated along the vectorization process.
  OpBuilder builder;

  // Maps input scalar operations to their vector counterparts.
  DenseMap<Operation *, Operation *> opVectorReplacement;
  // Maps input scalar values to their vector counterparts.
  BlockAndValueMapping valueVectorReplacement;
  // Maps input scalar values to their new scalar counterparts in the vector
  // loop nest.
  BlockAndValueMapping valueScalarReplacement;
  // Maps results of reduction loops to their new scalar counterparts.
  DenseMap<Value, Value> loopResultScalarReplacement;

  // Maps the newly created vector loops to their vector dimension.
  DenseMap<Operation *, unsigned> vecLoopToVecDim;

  // Maps the new vectorized loops to the corresponding vector masks if it is
  // required.
  DenseMap<Operation *, Value> vecLoopToMask;

  // The strategy drives which loop to vectorize by which amount.
  const VectorizationStrategy *strategy = nullptr;

private:
  /// Internal implementation to map input scalar values to new vector or scalar
  /// values.
  void registerValueVectorReplacementImpl(Value replaced, Value replacement);
  void registerValueScalarReplacementImpl(Value replaced, Value replacement);
};
768 | |||||
769 | } // namespace | ||||
770 | |||||
771 | /// Registers the vector replacement of a scalar operation and its result | ||||
772 | /// values. Both operations must have the same number of results. | ||||
773 | /// | ||||
774 | /// This utility is used to register the replacement for the vast majority of | ||||
775 | /// the vectorized operations. | ||||
776 | /// | ||||
777 | /// Example: | ||||
778 | /// * 'replaced': %0 = arith.addf %1, %2 : f32 | ||||
779 | /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32> | ||||
780 | void VectorizationState::registerOpVectorReplacement(Operation *replaced, | ||||
781 | Operation *replacement) { | ||||
782 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ commit vectorized op:\n" ; } } while (false); | ||||
783 | LLVM_DEBUG(dbgs() << *replaced << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << *replaced << "\n"; } } while (false); | ||||
784 | LLVM_DEBUG(dbgs() << "into\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "into\n"; } } while (false); | ||||
785 | LLVM_DEBUG(dbgs() << *replacement << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << *replacement << "\n"; } } while (false); | ||||
786 | |||||
787 | assert(replaced->getNumResults() == replacement->getNumResults() &&(static_cast <bool> (replaced->getNumResults() == replacement ->getNumResults() && "Unexpected replaced and replacement results" ) ? void (0) : __assert_fail ("replaced->getNumResults() == replacement->getNumResults() && \"Unexpected replaced and replacement results\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 788 , __extension__ __PRETTY_FUNCTION__)) | ||||
788 | "Unexpected replaced and replacement results")(static_cast <bool> (replaced->getNumResults() == replacement ->getNumResults() && "Unexpected replaced and replacement results" ) ? void (0) : __assert_fail ("replaced->getNumResults() == replacement->getNumResults() && \"Unexpected replaced and replacement results\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 788 , __extension__ __PRETTY_FUNCTION__)); | ||||
789 | assert(opVectorReplacement.count(replaced) == 0 && "already registered")(static_cast <bool> (opVectorReplacement.count(replaced ) == 0 && "already registered") ? void (0) : __assert_fail ("opVectorReplacement.count(replaced) == 0 && \"already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 789 , __extension__ __PRETTY_FUNCTION__)); | ||||
790 | opVectorReplacement[replaced] = replacement; | ||||
791 | |||||
792 | for (auto resultTuple : | ||||
793 | llvm::zip(replaced->getResults(), replacement->getResults())) | ||||
794 | registerValueVectorReplacementImpl(std::get<0>(resultTuple), | ||||
795 | std::get<1>(resultTuple)); | ||||
796 | } | ||||
797 | |||||
798 | /// Registers the vector replacement of a scalar value. The replacement | ||||
799 | /// operation should have a single result, which replaces the scalar value. | ||||
800 | /// | ||||
801 | /// This utility is used to register the vector replacement of block arguments | ||||
802 | /// and operation results which are not directly vectorized (i.e., their | ||||
803 | /// scalar version still exists after vectorization), like uniforms. | ||||
804 | /// | ||||
805 | /// Example: | ||||
806 | /// * 'replaced': block argument or operation outside of the vectorized loop. | ||||
807 | /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32> | ||||
808 | void VectorizationState::registerValueVectorReplacement( | ||||
809 | Value replaced, Operation *replacement) { | ||||
810 | assert(replacement->getNumResults() == 1 &&(static_cast <bool> (replacement->getNumResults() == 1 && "Expected single-result replacement") ? void (0 ) : __assert_fail ("replacement->getNumResults() == 1 && \"Expected single-result replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 811 , __extension__ __PRETTY_FUNCTION__)) | ||||
811 | "Expected single-result replacement")(static_cast <bool> (replacement->getNumResults() == 1 && "Expected single-result replacement") ? void (0 ) : __assert_fail ("replacement->getNumResults() == 1 && \"Expected single-result replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 811 , __extension__ __PRETTY_FUNCTION__)); | ||||
812 | if (Operation *defOp = replaced.getDefiningOp()) | ||||
813 | registerOpVectorReplacement(defOp, replacement); | ||||
814 | else | ||||
815 | registerValueVectorReplacementImpl(replaced, replacement->getResult(0)); | ||||
816 | } | ||||
817 | |||||
818 | /// Registers the vector replacement of a block argument (e.g., iter_args). | ||||
819 | /// | ||||
820 | /// Example: | ||||
821 | /// * 'replaced': 'iter_arg' block argument. | ||||
822 | /// * 'replacement': vectorized 'iter_arg' block argument. | ||||
823 | void VectorizationState::registerBlockArgVectorReplacement( | ||||
824 | BlockArgument replaced, BlockArgument replacement) { | ||||
825 | registerValueVectorReplacementImpl(replaced, replacement); | ||||
826 | } | ||||
827 | |||||
828 | void VectorizationState::registerValueVectorReplacementImpl(Value replaced, | ||||
829 | Value replacement) { | ||||
830 | assert(!valueVectorReplacement.contains(replaced) &&(static_cast <bool> (!valueVectorReplacement.contains(replaced ) && "Vector replacement already registered") ? void ( 0) : __assert_fail ("!valueVectorReplacement.contains(replaced) && \"Vector replacement already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 831 , __extension__ __PRETTY_FUNCTION__)) | ||||
831 | "Vector replacement already registered")(static_cast <bool> (!valueVectorReplacement.contains(replaced ) && "Vector replacement already registered") ? void ( 0) : __assert_fail ("!valueVectorReplacement.contains(replaced) && \"Vector replacement already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 831 , __extension__ __PRETTY_FUNCTION__)); | ||||
832 | assert(replacement.getType().isa<VectorType>() &&(static_cast <bool> (replacement.getType().isa<VectorType >() && "Expected vector type in vector replacement" ) ? void (0) : __assert_fail ("replacement.getType().isa<VectorType>() && \"Expected vector type in vector replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 833 , __extension__ __PRETTY_FUNCTION__)) | ||||
833 | "Expected vector type in vector replacement")(static_cast <bool> (replacement.getType().isa<VectorType >() && "Expected vector type in vector replacement" ) ? void (0) : __assert_fail ("replacement.getType().isa<VectorType>() && \"Expected vector type in vector replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 833 , __extension__ __PRETTY_FUNCTION__)); | ||||
834 | valueVectorReplacement.map(replaced, replacement); | ||||
835 | } | ||||
836 | |||||
837 | /// Registers the scalar replacement of a scalar value. 'replacement' must be | ||||
838 | /// scalar. Both values must be block arguments. Operation results should be | ||||
839 | /// replaced using the 'registerOp*' utilitites. | ||||
840 | /// | ||||
841 | /// This utility is used to register the replacement of block arguments | ||||
842 | /// that are within the loop to be vectorized and will continue being scalar | ||||
843 | /// within the vector loop. | ||||
844 | /// | ||||
845 | /// Example: | ||||
846 | /// * 'replaced': induction variable of a loop to be vectorized. | ||||
847 | /// * 'replacement': new induction variable in the new vector loop. | ||||
848 | void VectorizationState::registerValueScalarReplacement( | ||||
849 | BlockArgument replaced, BlockArgument replacement) { | ||||
850 | registerValueScalarReplacementImpl(replaced, replacement); | ||||
851 | } | ||||
852 | |||||
853 | /// Registers the scalar replacement of a scalar result returned from a | ||||
854 | /// reduction loop. 'replacement' must be scalar. | ||||
855 | /// | ||||
856 | /// This utility is used to register the replacement for scalar results of | ||||
857 | /// vectorized reduction loops with iter_args. | ||||
858 | /// | ||||
859 | /// Example 2: | ||||
860 | /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32) | ||||
861 | /// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into f32 | ||||
862 | void VectorizationState::registerLoopResultScalarReplacement( | ||||
863 | Value replaced, Value replacement) { | ||||
864 | assert(isa<AffineForOp>(replaced.getDefiningOp()))(static_cast <bool> (isa<AffineForOp>(replaced.getDefiningOp ())) ? void (0) : __assert_fail ("isa<AffineForOp>(replaced.getDefiningOp())" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 864 , __extension__ __PRETTY_FUNCTION__)); | ||||
865 | assert(loopResultScalarReplacement.count(replaced) == 0 &&(static_cast <bool> (loopResultScalarReplacement.count( replaced) == 0 && "already registered") ? void (0) : __assert_fail ("loopResultScalarReplacement.count(replaced) == 0 && \"already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 866 , __extension__ __PRETTY_FUNCTION__)) | ||||
866 | "already registered")(static_cast <bool> (loopResultScalarReplacement.count( replaced) == 0 && "already registered") ? void (0) : __assert_fail ("loopResultScalarReplacement.count(replaced) == 0 && \"already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 866 , __extension__ __PRETTY_FUNCTION__)); | ||||
867 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop " "with scalar: " << replacement; } } while (false) | ||||
868 | "with scalar: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop " "with scalar: " << replacement; } } while (false) | ||||
869 | << replacement)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ will replace a result of the loop " "with scalar: " << replacement; } } while (false); | ||||
870 | loopResultScalarReplacement[replaced] = replacement; | ||||
871 | } | ||||
872 | |||||
873 | void VectorizationState::registerValueScalarReplacementImpl(Value replaced, | ||||
874 | Value replacement) { | ||||
875 | assert(!valueScalarReplacement.contains(replaced) &&(static_cast <bool> (!valueScalarReplacement.contains(replaced ) && "Scalar value replacement already registered") ? void (0) : __assert_fail ("!valueScalarReplacement.contains(replaced) && \"Scalar value replacement already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 876 , __extension__ __PRETTY_FUNCTION__)) | ||||
876 | "Scalar value replacement already registered")(static_cast <bool> (!valueScalarReplacement.contains(replaced ) && "Scalar value replacement already registered") ? void (0) : __assert_fail ("!valueScalarReplacement.contains(replaced) && \"Scalar value replacement already registered\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 876 , __extension__ __PRETTY_FUNCTION__)); | ||||
877 | assert(!replacement.getType().isa<VectorType>() &&(static_cast <bool> (!replacement.getType().isa<VectorType >() && "Expected scalar type in scalar replacement" ) ? void (0) : __assert_fail ("!replacement.getType().isa<VectorType>() && \"Expected scalar type in scalar replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 878 , __extension__ __PRETTY_FUNCTION__)) | ||||
878 | "Expected scalar type in scalar replacement")(static_cast <bool> (!replacement.getType().isa<VectorType >() && "Expected scalar type in scalar replacement" ) ? void (0) : __assert_fail ("!replacement.getType().isa<VectorType>() && \"Expected scalar type in scalar replacement\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 878 , __extension__ __PRETTY_FUNCTION__)); | ||||
879 | valueScalarReplacement.map(replaced, replacement); | ||||
880 | } | ||||
881 | |||||
882 | /// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'. | ||||
883 | void VectorizationState::getScalarValueReplacementsFor( | ||||
884 | ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) { | ||||
885 | for (Value inputVal : inputVals) | ||||
886 | replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal)); | ||||
887 | } | ||||
888 | |||||
889 | /// Erases a loop nest, including all its nested operations. | ||||
890 | static void eraseLoopNest(AffineForOp forOp) { | ||||
891 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n"; } } while (false); | ||||
892 | forOp.erase(); | ||||
893 | } | ||||
894 | |||||
895 | /// Erases the scalar loop nest after its successful vectorization. | ||||
896 | void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) { | ||||
897 | LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect] Finalizing vectorization\n" ; } } while (false); | ||||
898 | eraseLoopNest(rootLoop); | ||||
899 | } | ||||
900 | |||||
901 | // Apply 'map' with 'mapOperands' returning resulting values in 'results'. | ||||
902 | static void computeMemoryOpIndices(Operation *op, AffineMap map, | ||||
903 | ValueRange mapOperands, | ||||
904 | VectorizationState &state, | ||||
905 | SmallVectorImpl<Value> &results) { | ||||
906 | for (auto resultExpr : map.getResults()) { | ||||
907 | auto singleResMap = | ||||
908 | AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr); | ||||
909 | auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap, | ||||
910 | mapOperands); | ||||
911 | results.push_back(afOp); | ||||
912 | } | ||||
913 | } | ||||
914 | |||||
915 | /// Returns a FilterFunctionType that can be used in NestedPattern to match a | ||||
916 | /// loop whose underlying load/store accesses are either invariant or all | ||||
917 | // varying along the `fastestVaryingMemRefDimension`. | ||||
918 | static FilterFunctionType | ||||
919 | isVectorizableLoopPtrFactory(const DenseSet<Operation *> ¶llelLoops, | ||||
920 | int fastestVaryingMemRefDimension) { | ||||
921 | return [¶llelLoops, fastestVaryingMemRefDimension](Operation &forOp) { | ||||
922 | auto loop = cast<AffineForOp>(forOp); | ||||
923 | auto parallelIt = parallelLoops.find(loop); | ||||
924 | if (parallelIt == parallelLoops.end()) | ||||
925 | return false; | ||||
926 | int memRefDim = -1; | ||||
927 | auto vectorizableBody = | ||||
928 | isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern()); | ||||
929 | if (!vectorizableBody) | ||||
930 | return false; | ||||
931 | return memRefDim == -1 || fastestVaryingMemRefDimension == -1 || | ||||
932 | memRefDim == fastestVaryingMemRefDimension; | ||||
933 | }; | ||||
934 | } | ||||
935 | |||||
936 | /// Returns the vector type resulting from applying the provided vectorization | ||||
937 | /// strategy on the scalar type. | ||||
938 | static VectorType getVectorType(Type scalarTy, | ||||
939 | const VectorizationStrategy *strategy) { | ||||
940 | assert(!scalarTy.isa<VectorType>() && "Expected scalar type")(static_cast <bool> (!scalarTy.isa<VectorType>() && "Expected scalar type") ? void (0) : __assert_fail ("!scalarTy.isa<VectorType>() && \"Expected scalar type\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 940 , __extension__ __PRETTY_FUNCTION__)); | ||||
941 | return VectorType::get(strategy->vectorSizes, scalarTy); | ||||
942 | } | ||||
943 | |||||
944 | /// Tries to transform a scalar constant into a vector constant. Returns the | ||||
945 | /// vector constant if the scalar type is valid vector element type. Returns | ||||
946 | /// nullptr, otherwise. | ||||
947 | static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp, | ||||
948 | VectorizationState &state) { | ||||
949 | Type scalarTy = constOp.getType(); | ||||
950 | if (!VectorType::isValidElementType(scalarTy)) | ||||
951 | return nullptr; | ||||
952 | |||||
953 | auto vecTy = getVectorType(scalarTy, state.strategy); | ||||
954 | auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue()); | ||||
955 | |||||
956 | OpBuilder::InsertionGuard guard(state.builder); | ||||
957 | Operation *parentOp = state.builder.getInsertionBlock()->getParentOp(); | ||||
958 | // Find the innermost vectorized ancestor loop to insert the vector constant. | ||||
959 | while (parentOp && !state.vecLoopToVecDim.count(parentOp)) | ||||
960 | parentOp = parentOp->getParentOp(); | ||||
961 | assert(parentOp && state.vecLoopToVecDim.count(parentOp) &&(static_cast <bool> (parentOp && state.vecLoopToVecDim .count(parentOp) && isa<AffineForOp>(parentOp) && "Expected a vectorized for op") ? void (0) : __assert_fail ( "parentOp && state.vecLoopToVecDim.count(parentOp) && isa<AffineForOp>(parentOp) && \"Expected a vectorized for op\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 962 , __extension__ __PRETTY_FUNCTION__)) | ||||
962 | isa<AffineForOp>(parentOp) && "Expected a vectorized for op")(static_cast <bool> (parentOp && state.vecLoopToVecDim .count(parentOp) && isa<AffineForOp>(parentOp) && "Expected a vectorized for op") ? void (0) : __assert_fail ( "parentOp && state.vecLoopToVecDim.count(parentOp) && isa<AffineForOp>(parentOp) && \"Expected a vectorized for op\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 962 , __extension__ __PRETTY_FUNCTION__)); | ||||
963 | auto vecForOp = cast<AffineForOp>(parentOp); | ||||
964 | state.builder.setInsertionPointToStart(vecForOp.getBody()); | ||||
965 | auto newConstOp = | ||||
966 | state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr); | ||||
967 | |||||
968 | // Register vector replacement for future uses in the scope. | ||||
969 | state.registerOpVectorReplacement(constOp, newConstOp); | ||||
970 | return newConstOp; | ||||
971 | } | ||||
972 | |||||
973 | /// Creates a constant vector filled with the neutral elements of the given | ||||
974 | /// reduction. The scalar type of vector elements will be taken from | ||||
975 | /// `oldOperand`. | ||||
976 | static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind, | ||||
977 | Value oldOperand, | ||||
978 | VectorizationState &state) { | ||||
979 | Type scalarTy = oldOperand.getType(); | ||||
980 | if (!VectorType::isValidElementType(scalarTy)) | ||||
981 | return nullptr; | ||||
982 | |||||
983 | Attribute valueAttr = getIdentityValueAttr( | ||||
984 | reductionKind, scalarTy, state.builder, oldOperand.getLoc()); | ||||
985 | auto vecTy = getVectorType(scalarTy, state.strategy); | ||||
986 | auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr); | ||||
987 | auto newConstOp = | ||||
988 | state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr); | ||||
989 | |||||
990 | return newConstOp; | ||||
991 | } | ||||
992 | |||||
/// Creates a mask used to filter out garbage elements in the last iteration
/// of unaligned loops. If a mask is not required then `nullptr` is returned.
/// The mask will be a vector of booleans representing meaningful vector
/// elements in the current iteration. It is filled with ones for each iteration
/// except for the last one, where it has the form `11...100...0` with the
/// number of ones equal to the number of meaningful elements (i.e. the number
/// of iterations that would be left in the original loop).
static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
  assert(state.strategy->vectorSizes.size() == 1 &&
         "Creating a mask non-1-D vectors is not supported.");
  assert(vecForOp.getStep() == state.strategy->vectorSizes[0] &&
         "Creating a mask for loops with non-unit original step size is not "
         "supported.");

  // Check if we have already created the mask.
  if (Value mask = state.vecLoopToMask.lookup(vecForOp))
    return mask;

  // If the loop has constant bounds and the original number of iterations is
  // divisable by the vector size then we don't need a mask.
  // Note: the step of `vecForOp` equals the vector size (asserted above), so
  // divisibility by the step is divisibility by the vector size.
  if (vecForOp.hasConstantBounds()) {
    int64_t originalTripCount =
        vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound();
    if (originalTripCount % vecForOp.getStep() == 0)
      return nullptr;
  }

  // The mask ops below are inserted at the top of the vectorized loop body;
  // the guard restores the builder's previous insertion point on return.
  OpBuilder::InsertionGuard guard(state.builder);
  state.builder.setInsertionPointToStart(vecForOp.getBody());

  // We generate the mask using the `vector.create_mask` operation which accepts
  // the number of meaningful elements (i.e. the length of the prefix of 1s).
  // To compute the number of meaningful elements we subtract the current value
  // of the iteration variable from the upper bound of the loop. Example:
  //
  //     // 500 is the upper bound of the loop
  //     #map = affine_map<(d0) -> (500 - d0)>
  //     %elems_left = affine.apply #map(%iv)
  //     %mask = vector.create_mask %elems_left : vector<128xi1>

  Location loc = vecForOp.getLoc();

  // First we get the upper bound of the loop using `affine.apply` or
  // `affine.min`. A multi-result upper bound map means the bound is the
  // minimum of several expressions, hence `affine.min`.
  AffineMap ubMap = vecForOp.getUpperBoundMap();
  Value ub;
  if (ubMap.getNumResults() == 1)
    ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
                                             vecForOp.getUpperBoundOperands());
  else
    ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
                                           vecForOp.getUpperBoundOperands());
  // Then we compute the number of (original) iterations left in the loop.
  // The composed apply computes `ub - iv` (dim0 - dim1).
  AffineExpr subExpr =
      state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
  Value itersLeft =
      makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr),
                              {ub, vecForOp.getInductionVar()});
  // If the affine maps were successfully composed then `ub` is unneeded.
  // Only erase it when composition actually folded it away (no uses left).
  if (ub.use_empty())
    ub.getDefiningOp()->erase();
  // Finally we create the mask.
  Type maskTy = VectorType::get(state.strategy->vectorSizes,
                                state.builder.getIntegerType(1));
  Value mask =
      state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);

  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"
                    << itersLeft << "\n"
                    << mask << "\n");

  // Cache the mask so subsequent calls for the same loop reuse it.
  state.vecLoopToMask[vecForOp] = mask;
  return mask;
}
1067 | |||||
1068 | /// Returns true if the provided value is vector uniform given the vectorization | ||||
1069 | /// strategy. | ||||
1070 | // TODO: For now, only values that are induction variables of loops not in | ||||
1071 | // `loopToVectorDim` or invariants to all the loops in the vectorization | ||||
1072 | // strategy are considered vector uniforms. | ||||
1073 | static bool isUniformDefinition(Value value, | ||||
1074 | const VectorizationStrategy *strategy) { | ||||
1075 | AffineForOp forOp = getForInductionVarOwner(value); | ||||
1076 | if (forOp && strategy->loopToVectorDim.count(forOp) == 0) | ||||
1077 | return true; | ||||
1078 | |||||
1079 | for (auto loopToDim : strategy->loopToVectorDim) { | ||||
1080 | auto loop = cast<AffineForOp>(loopToDim.first); | ||||
1081 | if (!loop.isDefinedOutsideOfLoop(value)) | ||||
1082 | return false; | ||||
1083 | } | ||||
1084 | return true; | ||||
1085 | } | ||||
1086 | |||||
1087 | /// Generates a broadcast op for the provided uniform value using the | ||||
1088 | /// vectorization strategy in 'state'. | ||||
1089 | static Operation *vectorizeUniform(Value uniformVal, | ||||
1090 | VectorizationState &state) { | ||||
1091 | OpBuilder::InsertionGuard guard(state.builder); | ||||
1092 | Value uniformScalarRepl = | ||||
1093 | state.valueScalarReplacement.lookupOrDefault(uniformVal); | ||||
1094 | state.builder.setInsertionPointAfterValue(uniformScalarRepl); | ||||
1095 | |||||
1096 | auto vectorTy = getVectorType(uniformVal.getType(), state.strategy); | ||||
1097 | auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(), | ||||
1098 | vectorTy, uniformScalarRepl); | ||||
1099 | state.registerValueVectorReplacement(uniformVal, bcastOp); | ||||
1100 | return bcastOp; | ||||
1101 | } | ||||
1102 | |||||
1103 | /// Tries to vectorize a given `operand` by applying the following logic: | ||||
1104 | /// 1. if the defining operation has been already vectorized, `operand` is | ||||
1105 | /// already in the proper vector form; | ||||
1106 | /// 2. if the `operand` is a constant, returns the vectorized form of the | ||||
1107 | /// constant; | ||||
1108 | /// 3. if the `operand` is uniform, returns a vector broadcast of the `op`; | ||||
1109 | /// 4. otherwise, the vectorization of `operand` is not supported. | ||||
1110 | /// Newly created vector operations are registered in `state` as replacement | ||||
1111 | /// for their scalar counterparts. | ||||
1112 | /// In particular this logic captures some of the use cases where definitions | ||||
1113 | /// that are not scoped under the current pattern are needed to vectorize. | ||||
1114 | /// One such example is top level function constants that need to be splatted. | ||||
1115 | /// | ||||
1116 | /// Returns an operand that has been vectorized to match `state`'s strategy if | ||||
1117 | /// vectorization is possible with the above logic. Returns nullptr otherwise. | ||||
1118 | /// | ||||
1119 | /// TODO: handle more complex cases. | ||||
1120 | static Value vectorizeOperand(Value operand, VectorizationState &state) { | ||||
1121 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand; } } while (false); | ||||
1122 | // If this value is already vectorized, we are done. | ||||
1123 | if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) { | ||||
1124 | LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << " -> already vectorized: " << vecRepl; } } while (false); | ||||
1125 | return vecRepl; | ||||
1126 | } | ||||
1127 | |||||
1128 | // An vector operand that is not in the replacement map should never reach | ||||
1129 | // this point. Reaching this point could mean that the code was already | ||||
1130 | // vectorized and we shouldn't try to vectorize already vectorized code. | ||||
1131 | assert(!operand.getType().isa<VectorType>() &&(static_cast <bool> (!operand.getType().isa<VectorType >() && "Vector op not found in replacement map") ? void (0) : __assert_fail ("!operand.getType().isa<VectorType>() && \"Vector op not found in replacement map\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1132 , __extension__ __PRETTY_FUNCTION__)) | ||||
1132 | "Vector op not found in replacement map")(static_cast <bool> (!operand.getType().isa<VectorType >() && "Vector op not found in replacement map") ? void (0) : __assert_fail ("!operand.getType().isa<VectorType>() && \"Vector op not found in replacement map\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1132 , __extension__ __PRETTY_FUNCTION__)); | ||||
1133 | |||||
1134 | // Vectorize constant. | ||||
1135 | if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) { | ||||
1136 | auto vecConstant = vectorizeConstant(constOp, state); | ||||
1137 | LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> constant: " << vecConstant; } } while (false); | ||||
1138 | return vecConstant.getResult(); | ||||
1139 | } | ||||
1140 | |||||
1141 | // Vectorize uniform values. | ||||
1142 | if (isUniformDefinition(operand, state.strategy)) { | ||||
1143 | Operation *vecUniform = vectorizeUniform(operand, state); | ||||
1144 | LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> uniform: " << * vecUniform; } } while (false); | ||||
1145 | return vecUniform->getResult(0); | ||||
1146 | } | ||||
1147 | |||||
1148 | // Check for unsupported block argument scenarios. A supported block argument | ||||
1149 | // should have been vectorized already. | ||||
1150 | if (!operand.getDefiningOp()) | ||||
1151 | LLVM_DEBUG(dbgs() << "-> unsupported block argument\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> unsupported block argument\n" ; } } while (false); | ||||
1152 | else | ||||
1153 | // Generic unsupported case. | ||||
1154 | LLVM_DEBUG(dbgs() << "-> non-vectorizable\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "-> non-vectorizable\n"; } } while (false); | ||||
1155 | |||||
1156 | return nullptr; | ||||
1157 | } | ||||
1158 | |||||
1159 | /// Vectorizes an affine load with the vectorization strategy in 'state' by | ||||
1160 | /// generating a 'vector.transfer_read' op with the proper permutation map | ||||
1161 | /// inferred from the indices of the load. The new 'vector.transfer_read' is | ||||
1162 | /// registered as replacement of the scalar load. Returns the newly created | ||||
1163 | /// 'vector.transfer_read' if vectorization was successful. Returns nullptr, | ||||
1164 | /// otherwise. | ||||
1165 | static Operation *vectorizeAffineLoad(AffineLoadOp loadOp, | ||||
1166 | VectorizationState &state) { | ||||
1167 | MemRefType memRefType = loadOp.getMemRefType(); | ||||
1168 | Type elementType = memRefType.getElementType(); | ||||
1169 | auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType); | ||||
1170 | |||||
1171 | // Replace map operands with operands from the vector loop nest. | ||||
1172 | SmallVector<Value, 8> mapOperands; | ||||
1173 | state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands); | ||||
1174 | |||||
1175 | // Compute indices for the transfer op. AffineApplyOp's may be generated. | ||||
1176 | SmallVector<Value, 8> indices; | ||||
1177 | indices.reserve(memRefType.getRank()); | ||||
1178 | if (loadOp.getAffineMap() != | ||||
1179 | state.builder.getMultiDimIdentityMap(memRefType.getRank())) | ||||
1180 | computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state, | ||||
1181 | indices); | ||||
1182 | else | ||||
1183 | indices.append(mapOperands.begin(), mapOperands.end()); | ||||
1184 | |||||
1185 | // Compute permutation map using the information of new vector loops. | ||||
1186 | auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(), | ||||
1187 | indices, state.vecLoopToVecDim); | ||||
1188 | if (!permutationMap) { | ||||
1189 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n" ; } } while (false); | ||||
1190 | return nullptr; | ||||
1191 | } | ||||
1192 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ permutationMap: " ; } } while (false); | ||||
1193 | LLVM_DEBUG(permutationMap.print(dbgs()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { permutationMap.print(dbgs()); } } while (false ); | ||||
1194 | |||||
1195 | auto transfer = state.builder.create<vector::TransferReadOp>( | ||||
1196 | loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap); | ||||
1197 | |||||
1198 | // Register replacement for future uses in the scope. | ||||
1199 | state.registerOpVectorReplacement(loadOp, transfer); | ||||
1200 | return transfer; | ||||
1201 | } | ||||
1202 | |||||
1203 | /// Vectorizes an affine store with the vectorization strategy in 'state' by | ||||
1204 | /// generating a 'vector.transfer_write' op with the proper permutation map | ||||
1205 | /// inferred from the indices of the store. The new 'vector.transfer_store' is | ||||
1206 | /// registered as replacement of the scalar load. Returns the newly created | ||||
1207 | /// 'vector.transfer_write' if vectorization was successful. Returns nullptr, | ||||
1208 | /// otherwise. | ||||
1209 | static Operation *vectorizeAffineStore(AffineStoreOp storeOp, | ||||
1210 | VectorizationState &state) { | ||||
1211 | MemRefType memRefType = storeOp.getMemRefType(); | ||||
1212 | Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state); | ||||
1213 | if (!vectorValue) | ||||
1214 | return nullptr; | ||||
1215 | |||||
1216 | // Replace map operands with operands from the vector loop nest. | ||||
1217 | SmallVector<Value, 8> mapOperands; | ||||
1218 | state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands); | ||||
1219 | |||||
1220 | // Compute indices for the transfer op. AffineApplyOp's may be generated. | ||||
1221 | SmallVector<Value, 8> indices; | ||||
1222 | indices.reserve(memRefType.getRank()); | ||||
1223 | if (storeOp.getAffineMap() != | ||||
1224 | state.builder.getMultiDimIdentityMap(memRefType.getRank())) | ||||
1225 | computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state, | ||||
1226 | indices); | ||||
1227 | else | ||||
1228 | indices.append(mapOperands.begin(), mapOperands.end()); | ||||
1229 | |||||
1230 | // Compute permutation map using the information of new vector loops. | ||||
1231 | auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(), | ||||
1232 | indices, state.vecLoopToVecDim); | ||||
1233 | if (!permutationMap) | ||||
1234 | return nullptr; | ||||
1235 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ permutationMap: " ; } } while (false); | ||||
1236 | LLVM_DEBUG(permutationMap.print(dbgs()))do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { permutationMap.print(dbgs()); } } while (false ); | ||||
1237 | |||||
1238 | auto transfer = state.builder.create<vector::TransferWriteOp>( | ||||
1239 | storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices, | ||||
1240 | permutationMap); | ||||
1241 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer; } } while (false); | ||||
1242 | |||||
1243 | // Register replacement for future uses in the scope. | ||||
1244 | state.registerOpVectorReplacement(storeOp, transfer); | ||||
1245 | return transfer; | ||||
1246 | } | ||||
1247 | |||||
1248 | /// Returns true if `value` is a constant equal to the neutral element of the | ||||
1249 | /// given vectorizable reduction. | ||||
1250 | static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind, | ||||
1251 | Value value, VectorizationState &state) { | ||||
1252 | Type scalarTy = value.getType(); | ||||
1253 | if (!VectorType::isValidElementType(scalarTy)) | ||||
1254 | return false; | ||||
1255 | Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy, | ||||
1256 | state.builder, value.getLoc()); | ||||
1257 | if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp())) | ||||
1258 | return constOp.getValue() == valueAttr; | ||||
1259 | return false; | ||||
1260 | } | ||||
1261 | |||||
/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
/// created and registered as replacement for the scalar loop. The builder's
/// insertion point is set to the new loop's body so that subsequent vectorized
/// operations are inserted into the new loop. If the loop is a vector
/// dimension, the step of the newly created loop will reflect the vectorization
/// factor used to vectorized that dimension.
static Operation *vectorizeAffineForOp(AffineForOp forOp,
                                       VectorizationState &state) {
  const VectorizationStrategy &strategy = *state.strategy;
  auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
  // Whether this loop is one of the dimensions being vectorized (as opposed
  // to a loop that is merely traversed while vectorizing a nest).
  bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();

  // TODO: Vectorization of reduction loops is not supported for non-unit steps.
  if (isLoopVecDim && forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) {
    LLVM_DEBUG(
        dbgs()
        << "\n[early-vect]+++++ unsupported step size for reduction loop: "
        << forOp.getStep() << "\n");
    return nullptr;
  }

  // If we are vectorizing a vector dimension, compute a new step for the new
  // vectorized loop using the vectorization factor for the vector dimension.
  // Otherwise, propagate the step of the scalar loop.
  unsigned newStep;
  if (isLoopVecDim) {
    unsigned vectorDim = loopToVecDimIt->second;
    assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
    int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
    newStep = forOp.getStep() * forOpVecFactor;
  } else {
    newStep = forOp.getStep();
  }

  // Get information about reduction kinds. Only needed when this vectorized
  // dimension carries iter_args (i.e. it is a reduction loop).
  ArrayRef<LoopReduction> reductions;
  if (isLoopVecDim && forOp.getNumIterOperands() > 0) {
    auto it = strategy.reductionLoops.find(forOp);
    assert(it != strategy.reductionLoops.end() &&
           "Reduction descriptors not found when vectorizing a reduction loop");
    reductions = it->second;
    assert(reductions.size() == forOp.getNumIterOperands() &&
           "The size of reductions array must match the number of iter_args");
  }

  // Vectorize 'iter_args'.
  SmallVector<Value, 8> vecIterOperands;
  if (!isLoopVecDim) {
    for (auto operand : forOp.getIterOperands())
      vecIterOperands.push_back(vectorizeOperand(operand, state));
  } else {
    // For reduction loops we need to pass a vector of neutral elements as an
    // initial value of the accumulator. We will add the original initial value
    // later.
    for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) {
      vecIterOperands.push_back(createInitialVector(
          std::get<0>(redAndOperand).kind, std::get<1>(redAndOperand), state));
    }
  }

  auto vecForOp = state.builder.create<AffineForOp>(
      forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
      forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
      vecIterOperands,
      /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
        // Make sure we don't create a default terminator in the loop body as
        // the proper terminator will be added during vectorization.
      });

  // Register loop-related replacements:
  //   1) The new vectorized loop is registered as vector replacement of the
  //      scalar loop.
  //   2) The new iv of the vectorized loop is registered as scalar replacement
  //      since a scalar copy of the iv will prevail in the vectorized loop.
  //      TODO: A vector replacement will also be added in the future when
  //      vectorization of linear ops is supported.
  //   3) The new 'iter_args' region arguments are registered as vector
  //      replacements since they have been vectorized.
  //   4) If the loop performs a reduction along the vector dimension, a
  //      `vector.reduction` or similar op is inserted for each resulting value
  //      of the loop and its scalar value replaces the corresponding scalar
  //      result of the loop.
  state.registerOpVectorReplacement(forOp, vecForOp);
  state.registerValueScalarReplacement(forOp.getInductionVar(),
                                       vecForOp.getInductionVar());
  for (auto iterTuple :
       llvm ::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs()))
    state.registerBlockArgVectorReplacement(std::get<0>(iterTuple),
                                            std::get<1>(iterTuple));

  if (isLoopVecDim) {
    for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) {
      // First, we reduce the vector returned from the loop into a scalar.
      Value reducedRes =
          getVectorReductionOp(reductions[i].kind, state.builder,
                               vecForOp.getLoc(), vecForOp.getResult(i));
      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: "
                        << reducedRes);
      // Then we combine it with the original (scalar) initial value unless it
      // is equal to the neutral element of the reduction. The original initial
      // value is the iter operand right after the loop's control operands.
      Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
      Value finalRes = reducedRes;
      if (!isNeutralElementConst(reductions[i].kind, origInit, state))
        finalRes =
            arith::getReductionOp(reductions[i].kind, state.builder,
                                  reducedRes.getLoc(), reducedRes, origInit);
      state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
    }
  }

  // Record which vector dimension the new loop implements so permutation maps
  // and masks can be derived for ops inside it.
  if (isLoopVecDim)
    state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;

  // Change insertion point so that upcoming vectorized instructions are
  // inserted into the vectorized loop's body.
  state.builder.setInsertionPointToStart(vecForOp.getBody());

  // If this is a reduction loop then we may need to create a mask to filter out
  // garbage in the last iteration.
  if (isLoopVecDim && forOp.getNumIterOperands() > 0)
    createMask(vecForOp, state);

  return vecForOp;
}
1386 | |||||
1387 | /// Vectorizes arbitrary operation by plain widening. We apply generic type | ||||
1388 | /// widening of all its results and retrieve the vector counterparts for all its | ||||
1389 | /// operands. | ||||
1390 | static Operation *widenOp(Operation *op, VectorizationState &state) { | ||||
1391 | SmallVector<Type, 8> vectorTypes; | ||||
1392 | for (Value result : op->getResults()) | ||||
1393 | vectorTypes.push_back( | ||||
1394 | VectorType::get(state.strategy->vectorSizes, result.getType())); | ||||
1395 | |||||
1396 | SmallVector<Value, 8> vectorOperands; | ||||
1397 | for (Value operand : op->getOperands()) { | ||||
1398 | Value vecOperand = vectorizeOperand(operand, state); | ||||
1399 | if (!vecOperand) { | ||||
1400 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n" ; } } while (false); | ||||
1401 | return nullptr; | ||||
1402 | } | ||||
1403 | vectorOperands.push_back(vecOperand); | ||||
1404 | } | ||||
1405 | |||||
1406 | // Create a clone of the op with the proper operands and return types. | ||||
1407 | // TODO: The following assumes there is always an op with a fixed | ||||
1408 | // name that works both in scalar mode and vector mode. | ||||
1409 | // TODO: Is it worth considering an Operation.clone operation which | ||||
1410 | // changes the type so we can promote an Operation with less boilerplate? | ||||
1411 | Operation *vecOp = | ||||
1412 | state.builder.create(op->getLoc(), op->getName().getIdentifier(), | ||||
1413 | vectorOperands, vectorTypes, op->getAttrs()); | ||||
1414 | state.registerOpVectorReplacement(op, vecOp); | ||||
1415 | return vecOp; | ||||
1416 | } | ||||
1417 | |||||
1418 | /// Vectorizes a yield operation by widening its types. The builder's insertion | ||||
1419 | /// point is set after the vectorized parent op to continue vectorizing the | ||||
1420 | /// operations after the parent op. When vectorizing a reduction loop a mask may | ||||
1421 | /// be used to prevent adding garbage values to the accumulator. | ||||
1422 | static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp, | ||||
1423 | VectorizationState &state) { | ||||
1424 | Operation *newYieldOp = widenOp(yieldOp, state); | ||||
1425 | Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp(); | ||||
1426 | |||||
1427 | // If there is a mask for this loop then we must prevent garbage values from | ||||
1428 | // being added to the accumulator by inserting `select` operations, for | ||||
1429 | // example: | ||||
1430 | // | ||||
1431 | // %res = arith.addf %acc, %val : vector<128xf32> | ||||
1432 | // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32> | ||||
1433 | // affine.yield %res_masked : vector<128xf32> | ||||
1434 | // | ||||
1435 | if (Value mask = state.vecLoopToMask.lookup(newParentOp)) { | ||||
1436 | state.builder.setInsertionPoint(newYieldOp); | ||||
1437 | for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) { | ||||
1438 | Value result = newYieldOp->getOperand(i); | ||||
1439 | Value iterArg = cast<AffineForOp>(newParentOp).getRegionIterArgs()[i]; | ||||
1440 | Value maskedResult = state.builder.create<arith::SelectOp>( | ||||
1441 | result.getLoc(), mask, result, iterArg); | ||||
1442 | LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ masking a yielded vector value: " << maskedResult; } } while (false) | ||||
1443 | dbgs() << "\n[early-vect]+++++ masking a yielded vector value: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ masking a yielded vector value: " << maskedResult; } } while (false) | ||||
1444 | << maskedResult)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ masking a yielded vector value: " << maskedResult; } } while (false); | ||||
1445 | newYieldOp->setOperand(i, maskedResult); | ||||
1446 | } | ||||
1447 | } | ||||
1448 | |||||
1449 | state.builder.setInsertionPointAfter(newParentOp); | ||||
1450 | return newYieldOp; | ||||
1451 | } | ||||
1452 | |||||
1453 | /// Encodes Operation-specific behavior for vectorization. In general we | ||||
1454 | /// assume that all operands of an op must be vectorized but this is not | ||||
1455 | /// always true. In the future, it would be nice to have a trait that | ||||
1456 | /// describes how a particular operation vectorizes. For now we implement the | ||||
1457 | /// case distinction here. Returns a vectorized form of an operation or | ||||
1458 | /// nullptr if vectorization fails. | ||||
1459 | // TODO: consider adding a trait to Op to describe how it gets vectorized. | ||||
1460 | // Maybe some Ops are not vectorizable or require some tricky logic, we cannot | ||||
1461 | // do one-off logic here; ideally it would be TableGen'd. | ||||
1462 | static Operation *vectorizeOneOperation(Operation *op, | ||||
1463 | VectorizationState &state) { | ||||
1464 | // Sanity checks. | ||||
1465 | assert(!isa<vector::TransferReadOp>(op) &&(static_cast <bool> (!isa<vector::TransferReadOp> (op) && "vector.transfer_read cannot be further vectorized" ) ? void (0) : __assert_fail ("!isa<vector::TransferReadOp>(op) && \"vector.transfer_read cannot be further vectorized\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1466 , __extension__ __PRETTY_FUNCTION__)) | ||||
1466 | "vector.transfer_read cannot be further vectorized")(static_cast <bool> (!isa<vector::TransferReadOp> (op) && "vector.transfer_read cannot be further vectorized" ) ? void (0) : __assert_fail ("!isa<vector::TransferReadOp>(op) && \"vector.transfer_read cannot be further vectorized\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1466 , __extension__ __PRETTY_FUNCTION__)); | ||||
1467 | assert(!isa<vector::TransferWriteOp>(op) &&(static_cast <bool> (!isa<vector::TransferWriteOp> (op) && "vector.transfer_write cannot be further vectorized" ) ? void (0) : __assert_fail ("!isa<vector::TransferWriteOp>(op) && \"vector.transfer_write cannot be further vectorized\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1468 , __extension__ __PRETTY_FUNCTION__)) | ||||
1468 | "vector.transfer_write cannot be further vectorized")(static_cast <bool> (!isa<vector::TransferWriteOp> (op) && "vector.transfer_write cannot be further vectorized" ) ? void (0) : __assert_fail ("!isa<vector::TransferWriteOp>(op) && \"vector.transfer_write cannot be further vectorized\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1468 , __extension__ __PRETTY_FUNCTION__)); | ||||
1469 | |||||
1470 | if (auto loadOp = dyn_cast<AffineLoadOp>(op)) | ||||
1471 | return vectorizeAffineLoad(loadOp, state); | ||||
1472 | if (auto storeOp = dyn_cast<AffineStoreOp>(op)) | ||||
1473 | return vectorizeAffineStore(storeOp, state); | ||||
1474 | if (auto forOp = dyn_cast<AffineForOp>(op)) | ||||
1475 | return vectorizeAffineForOp(forOp, state); | ||||
1476 | if (auto yieldOp = dyn_cast<AffineYieldOp>(op)) | ||||
1477 | return vectorizeAffineYieldOp(yieldOp, state); | ||||
1478 | if (auto constant = dyn_cast<arith::ConstantOp>(op)) | ||||
1479 | return vectorizeConstant(constant, state); | ||||
1480 | |||||
1481 | // Other ops with regions are not supported. | ||||
1482 | if (op->getNumRegions() != 0) | ||||
1483 | return nullptr; | ||||
1484 | |||||
1485 | return widenOp(op, state); | ||||
1486 | } | ||||
1487 | |||||
1488 | /// Recursive implementation to convert all the nested loops in 'match' to a 2D | ||||
1489 | /// vector container that preserves the relative nesting level of each loop with | ||||
1490 | /// respect to the others in 'match'. 'currentLevel' is the nesting level that | ||||
1491 | /// will be assigned to the loop in the current 'match'. | ||||
1492 | static void | ||||
1493 | getMatchedAffineLoopsRec(NestedMatch match, unsigned currentLevel, | ||||
1494 | std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1495 | // Add a new empty level to the output if it doesn't exist already. | ||||
1496 | assert(currentLevel <= loops.size() && "Unexpected currentLevel")(static_cast <bool> (currentLevel <= loops.size() && "Unexpected currentLevel") ? void (0) : __assert_fail ("currentLevel <= loops.size() && \"Unexpected currentLevel\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1496 , __extension__ __PRETTY_FUNCTION__)); | ||||
1497 | if (currentLevel == loops.size()) | ||||
1498 | loops.emplace_back(); | ||||
1499 | |||||
1500 | // Add current match and recursively visit its children. | ||||
1501 | loops[currentLevel].push_back(cast<AffineForOp>(match.getMatchedOperation())); | ||||
1502 | for (auto childMatch : match.getMatchedChildren()) { | ||||
1503 | getMatchedAffineLoopsRec(childMatch, currentLevel + 1, loops); | ||||
1504 | } | ||||
1505 | } | ||||
1506 | |||||
1507 | /// Converts all the nested loops in 'match' to a 2D vector container that | ||||
1508 | /// preserves the relative nesting level of each loop with respect to the others | ||||
1509 | /// in 'match'. This means that every loop in 'loops[i]' will have a parent loop | ||||
1510 | /// in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in | ||||
1511 | /// 'loops[i+1]'. | ||||
1512 | static void | ||||
1513 | getMatchedAffineLoops(NestedMatch match, | ||||
1514 | std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1515 | getMatchedAffineLoopsRec(match, /*currLoopDepth=*/0, loops); | ||||
1516 | } | ||||
1517 | |||||
1518 | /// Internal implementation to vectorize affine loops from a single loop nest | ||||
1519 | /// using an n-D vectorization strategy. | ||||
1520 | static LogicalResult | ||||
1521 | vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops, | ||||
1522 | const VectorizationStrategy &strategy) { | ||||
1523 | assert(loops[0].size() == 1 && "Expected single root loop")(static_cast <bool> (loops[0].size() == 1 && "Expected single root loop" ) ? void (0) : __assert_fail ("loops[0].size() == 1 && \"Expected single root loop\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1523 , __extension__ __PRETTY_FUNCTION__)); | ||||
1524 | AffineForOp rootLoop = loops[0][0]; | ||||
1525 | VectorizationState state(rootLoop.getContext()); | ||||
1526 | state.builder.setInsertionPointAfter(rootLoop); | ||||
1527 | state.strategy = &strategy; | ||||
1528 | |||||
1529 | // Since patterns are recursive, they can very well intersect. | ||||
1530 | // Since we do not want a fully greedy strategy in general, we decouple | ||||
1531 | // pattern matching, from profitability analysis, from application. | ||||
1532 | // As a consequence we must check that each root pattern is still | ||||
1533 | // vectorizable. If a pattern is not vectorizable anymore, we just skip it. | ||||
1534 | // TODO: implement a non-greedy profitability analysis that keeps only | ||||
1535 | // non-intersecting patterns. | ||||
1536 | if (!isVectorizableLoopBody(rootLoop, vectorTransferPattern())) { | ||||
1537 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ loop is not vectorizable" ; } } while (false); | ||||
1538 | return failure(); | ||||
1539 | } | ||||
1540 | |||||
1541 | ////////////////////////////////////////////////////////////////////////////// | ||||
1542 | // Vectorize the scalar loop nest following a topological order. A new vector | ||||
1543 | // loop nest with the vectorized operations is created along the process. If | ||||
1544 | // vectorization succeeds, the scalar loop nest is erased. If vectorization | ||||
1545 | // fails, the vector loop nest is erased and the scalar loop nest is not | ||||
1546 | // modified. | ||||
1547 | ////////////////////////////////////////////////////////////////////////////// | ||||
1548 | |||||
1549 | auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) { | ||||
1550 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op)do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ Vectorizing: " << *op; } } while (false); | ||||
| |||||
1551 | Operation *vectorOp = vectorizeOneOperation(op, state); | ||||
1552 | if (!vectorOp) { | ||||
1553 | LLVM_DEBUG(do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: " << *op << "\n"; } } while (false) | ||||
1554 | dbgs() << "[early-vect]+++++ failed vectorizing the operation: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: " << *op << "\n"; } } while (false) | ||||
1555 | << *op << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorizing the operation: " << *op << "\n"; } } while (false); | ||||
1556 | return WalkResult::interrupt(); | ||||
1557 | } | ||||
1558 | |||||
1559 | return WalkResult::advance(); | ||||
1560 | }); | ||||
1561 | |||||
1562 | if (opVecResult.wasInterrupted()) { | ||||
1563 | LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorization for: " << rootLoop << "\n"; } } while (false) | ||||
1564 | << rootLoop << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "[early-vect]+++++ failed vectorization for: " << rootLoop << "\n"; } } while (false); | ||||
1565 | // Erase vector loop nest if it was created. | ||||
1566 | auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop); | ||||
1567 | if (vecRootLoopIt != state.opVectorReplacement.end()) | ||||
1568 | eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second)); | ||||
1569 | |||||
1570 | return failure(); | ||||
1571 | } | ||||
1572 | |||||
1573 | // Replace results of reduction loops with the scalar values computed using | ||||
1574 | // `vector.reduce` or similar ops. | ||||
1575 | for (auto resPair : state.loopResultScalarReplacement) | ||||
1576 | resPair.first.replaceAllUsesWith(resPair.second); | ||||
1577 | |||||
1578 | assert(state.opVectorReplacement.count(rootLoop) == 1 &&(static_cast <bool> (state.opVectorReplacement.count(rootLoop ) == 1 && "Expected vector replacement for loop nest" ) ? void (0) : __assert_fail ("state.opVectorReplacement.count(rootLoop) == 1 && \"Expected vector replacement for loop nest\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1579 , __extension__ __PRETTY_FUNCTION__)) | ||||
1579 | "Expected vector replacement for loop nest")(static_cast <bool> (state.opVectorReplacement.count(rootLoop ) == 1 && "Expected vector replacement for loop nest" ) ? void (0) : __assert_fail ("state.opVectorReplacement.count(rootLoop) == 1 && \"Expected vector replacement for loop nest\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1579 , __extension__ __PRETTY_FUNCTION__)); | ||||
1580 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ success vectorizing pattern" ; } } while (false); | ||||
1581 | LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorization result:\n" << *state.opVectorReplacement[rootLoop]; } } while (false ) | ||||
1582 | << *state.opVectorReplacement[rootLoop])do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect]+++++ vectorization result:\n" << *state.opVectorReplacement[rootLoop]; } } while (false ); | ||||
1583 | |||||
1584 | // Finish this vectorization pattern. | ||||
1585 | state.finishVectorizationPattern(rootLoop); | ||||
1586 | return success(); | ||||
1587 | } | ||||
1588 | |||||
1589 | /// Extracts the matched loops and vectorizes them following a topological | ||||
1590 | /// order. A new vector loop nest will be created if vectorization succeeds. The | ||||
1591 | /// original loop nest won't be modified in any case. | ||||
1592 | static LogicalResult vectorizeRootMatch(NestedMatch m, | ||||
1593 | const VectorizationStrategy &strategy) { | ||||
1594 | std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize; | ||||
1595 | getMatchedAffineLoops(m, loopsToVectorize); | ||||
1596 | return vectorizeLoopNest(loopsToVectorize, strategy); | ||||
1597 | } | ||||
1598 | |||||
1599 | /// Traverses all the loop matches and classifies them into intersection | ||||
1600 | /// buckets. Two matches intersect if any of them encloses the other one. A | ||||
1601 | /// match intersects with a bucket if the match intersects with the root | ||||
1602 | /// (outermost) loop in that bucket. | ||||
1603 | static void computeIntersectionBuckets( | ||||
1604 | ArrayRef<NestedMatch> matches, | ||||
1605 | std::vector<SmallVector<NestedMatch, 8>> &intersectionBuckets) { | ||||
1606 | assert(intersectionBuckets.empty() && "Expected empty output")(static_cast <bool> (intersectionBuckets.empty() && "Expected empty output") ? void (0) : __assert_fail ("intersectionBuckets.empty() && \"Expected empty output\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1606 , __extension__ __PRETTY_FUNCTION__)); | ||||
1607 | // Keeps track of the root (outermost) loop of each bucket. | ||||
1608 | SmallVector<AffineForOp, 8> bucketRoots; | ||||
1609 | |||||
1610 | for (const NestedMatch &match : matches) { | ||||
1611 | AffineForOp matchRoot = cast<AffineForOp>(match.getMatchedOperation()); | ||||
1612 | bool intersects = false; | ||||
1613 | for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) { | ||||
1614 | AffineForOp bucketRoot = bucketRoots[i]; | ||||
1615 | // Add match to the bucket if the bucket root encloses the match root. | ||||
1616 | if (bucketRoot->isAncestor(matchRoot)) { | ||||
1617 | intersectionBuckets[i].push_back(match); | ||||
1618 | intersects = true; | ||||
1619 | break; | ||||
1620 | } | ||||
1621 | // Add match to the bucket if the match root encloses the bucket root. The | ||||
1622 | // match root becomes the new bucket root. | ||||
1623 | if (matchRoot->isAncestor(bucketRoot)) { | ||||
1624 | bucketRoots[i] = matchRoot; | ||||
1625 | intersectionBuckets[i].push_back(match); | ||||
1626 | intersects = true; | ||||
1627 | break; | ||||
1628 | } | ||||
1629 | } | ||||
1630 | |||||
1631 | // Match doesn't intersect with any existing bucket. Create a new bucket for | ||||
1632 | // it. | ||||
1633 | if (!intersects) { | ||||
1634 | bucketRoots.push_back(matchRoot); | ||||
1635 | intersectionBuckets.emplace_back(); | ||||
1636 | intersectionBuckets.back().push_back(match); | ||||
1637 | } | ||||
1638 | } | ||||
1639 | } | ||||
1640 | |||||
1641 | /// Internal implementation to vectorize affine loops in 'loops' using the n-D | ||||
1642 | /// vectorization factors in 'vectorSizes'. By default, each vectorization | ||||
1643 | /// factor is applied inner-to-outer to the loops of each loop nest. | ||||
1644 | /// 'fastestVaryingPattern' can be optionally used to provide a different loop | ||||
1645 | /// vectorization order. `reductionLoops` can be provided to specify loops which | ||||
1646 | /// can be vectorized along the reduction dimension. | ||||
1647 | static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops, | ||||
1648 | ArrayRef<int64_t> vectorSizes, | ||||
1649 | ArrayRef<int64_t> fastestVaryingPattern, | ||||
1650 | const ReductionLoopMap &reductionLoops) { | ||||
1651 | assert((reductionLoops.empty() || vectorSizes.size() == 1) &&(static_cast <bool> ((reductionLoops.empty() || vectorSizes .size() == 1) && "Vectorizing reductions is supported only for 1-D vectors" ) ? void (0) : __assert_fail ("(reductionLoops.empty() || vectorSizes.size() == 1) && \"Vectorizing reductions is supported only for 1-D vectors\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1652 , __extension__ __PRETTY_FUNCTION__)) | ||||
1652 | "Vectorizing reductions is supported only for 1-D vectors")(static_cast <bool> ((reductionLoops.empty() || vectorSizes .size() == 1) && "Vectorizing reductions is supported only for 1-D vectors" ) ? void (0) : __assert_fail ("(reductionLoops.empty() || vectorSizes.size() == 1) && \"Vectorizing reductions is supported only for 1-D vectors\"" , "mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp", 1652 , __extension__ __PRETTY_FUNCTION__)); | ||||
1653 | |||||
1654 | // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops. | ||||
1655 | Optional<NestedPattern> pattern = | ||||
1656 | makePattern(loops, vectorSizes.size(), fastestVaryingPattern); | ||||
1657 | if (!pattern.hasValue()) { | ||||
1658 | LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect] pattern couldn't be computed\n" ; } } while (false); | ||||
1659 | return; | ||||
1660 | } | ||||
1661 | |||||
1662 | LLVM_DEBUG(dbgs() << "\n******************************************")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n******************************************" ; } } while (false); | ||||
1663 | LLVM_DEBUG(dbgs() << "\n******************************************")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n******************************************" ; } } while (false); | ||||
1664 | LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n[early-vect] new pattern on parent op\n" ; } } while (false); | ||||
1665 | LLVM_DEBUG(dbgs() << *parentOp << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << *parentOp << "\n"; } } while (false); | ||||
1666 | |||||
1667 | unsigned patternDepth = pattern->getDepth(); | ||||
1668 | |||||
1669 | // Compute all the pattern matches and classify them into buckets of | ||||
1670 | // intersecting matches. | ||||
1671 | SmallVector<NestedMatch, 32> allMatches; | ||||
1672 | pattern->match(parentOp, &allMatches); | ||||
1673 | std::vector<SmallVector<NestedMatch, 8>> intersectionBuckets; | ||||
1674 | computeIntersectionBuckets(allMatches, intersectionBuckets); | ||||
1675 | |||||
1676 | // Iterate over all buckets and vectorize the matches eagerly. We can only | ||||
1677 | // vectorize one match from each bucket since all the matches within a bucket | ||||
1678 | // intersect. | ||||
1679 | for (auto &intersectingMatches : intersectionBuckets) { | ||||
1680 | for (NestedMatch &match : intersectingMatches) { | ||||
1681 | VectorizationStrategy strategy; | ||||
1682 | // TODO: depending on profitability, elect to reduce the vector size. | ||||
1683 | strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end()); | ||||
1684 | strategy.reductionLoops = reductionLoops; | ||||
1685 | if (failed(analyzeProfitability(match.getMatchedChildren(), 1, | ||||
1686 | patternDepth, &strategy))) { | ||||
1687 | continue; | ||||
1688 | } | ||||
1689 | vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth, | ||||
1690 | &strategy); | ||||
1691 | // Vectorize match. Skip the rest of intersecting matches in the bucket if | ||||
1692 | // vectorization succeeded. | ||||
1693 | // TODO: if pattern does not apply, report it; alter the cost/benefit. | ||||
1694 | // TODO: some diagnostics if failure to vectorize occurs. | ||||
1695 | if (succeeded(vectorizeRootMatch(match, strategy))) | ||||
1696 | break; | ||||
1697 | } | ||||
1698 | } | ||||
1699 | |||||
1700 | LLVM_DEBUG(dbgs() << "\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType ("early-vect")) { dbgs() << "\n"; } } while (false); | ||||
1701 | } | ||||
1702 | |||||
/// Factory for the super-vectorizer pass configured with the given virtual
/// vector size (n-D vectorization factors).
std::unique_ptr<OperationPass<func::FuncOp>>
createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
  return std::make_unique<Vectorize>(virtualVectorSize);
}
/// Factory for the super-vectorizer pass using the pass's default options.
/// NOTE(review): near-identical definitions appear again inside `namespace
/// mlir` at the bottom of this file — confirm these global-scope definitions
/// are intentional and not a duplication.
std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() {
  return std::make_unique<Vectorize>();
}
1710 | |||||
1711 | /// Applies vectorization to the current function by searching over a bunch of | ||||
1712 | /// predetermined patterns. | ||||
1713 | void Vectorize::runOnOperation() { | ||||
1714 | func::FuncOp f = getOperation(); | ||||
1715 | if (!fastestVaryingPattern.empty() && | ||||
1716 | fastestVaryingPattern.size() != vectorSizes.size()) { | ||||
1717 | f.emitRemark("Fastest varying pattern specified with different size than " | ||||
1718 | "the vector size."); | ||||
1719 | return signalPassFailure(); | ||||
1720 | } | ||||
1721 | |||||
1722 | if (vectorizeReductions && vectorSizes.size() != 1) { | ||||
1723 | f.emitError("Vectorizing reductions is supported only for 1-D vectors."); | ||||
1724 | return signalPassFailure(); | ||||
1725 | } | ||||
1726 | |||||
1727 | DenseSet<Operation *> parallelLoops; | ||||
1728 | ReductionLoopMap reductionLoops; | ||||
1729 | |||||
1730 | // If 'vectorize-reduction=true' is provided, we also populate the | ||||
1731 | // `reductionLoops` map. | ||||
1732 | if (vectorizeReductions) { | ||||
1733 | f.walk([¶llelLoops, &reductionLoops](AffineForOp loop) { | ||||
1734 | SmallVector<LoopReduction, 2> reductions; | ||||
1735 | if (isLoopParallel(loop, &reductions)) { | ||||
1736 | parallelLoops.insert(loop); | ||||
1737 | // If it's not a reduction loop, adding it to the map is not necessary. | ||||
1738 | if (!reductions.empty()) | ||||
1739 | reductionLoops[loop] = reductions; | ||||
1740 | } | ||||
1741 | }); | ||||
1742 | } else { | ||||
1743 | f.walk([¶llelLoops](AffineForOp loop) { | ||||
1744 | if (isLoopParallel(loop)) | ||||
1745 | parallelLoops.insert(loop); | ||||
1746 | }); | ||||
1747 | } | ||||
1748 | |||||
1749 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1750 | NestedPatternContext mlContext; | ||||
1751 | vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern, | ||||
1752 | reductionLoops); | ||||
1753 | } | ||||
1754 | |||||
1755 | /// Verify that affine loops in 'loops' meet the nesting criteria expected by | ||||
1756 | /// SuperVectorizer: | ||||
1757 | /// * There must be at least one loop. | ||||
1758 | /// * There must be a single root loop (nesting level 0). | ||||
1759 | /// * Each loop at a given nesting level must be nested in a loop from a | ||||
1760 | /// previous nesting level. | ||||
1761 | static LogicalResult | ||||
1762 | verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) { | ||||
1763 | // Expected at least one loop. | ||||
1764 | if (loops.empty()) | ||||
1765 | return failure(); | ||||
1766 | |||||
1767 | // Expected only one root loop. | ||||
1768 | if (loops[0].size() != 1) | ||||
1769 | return failure(); | ||||
1770 | |||||
1771 | // Traverse loops outer-to-inner to check some invariants. | ||||
1772 | for (int i = 1, end = loops.size(); i < end; ++i) { | ||||
1773 | for (AffineForOp loop : loops[i]) { | ||||
1774 | // Check that each loop at this level is nested in one of the loops from | ||||
1775 | // the previous level. | ||||
1776 | if (none_of(loops[i - 1], [&](AffineForOp maybeParent) { | ||||
1777 | return maybeParent->isProperAncestor(loop); | ||||
1778 | })) | ||||
1779 | return failure(); | ||||
1780 | |||||
1781 | // Check that each loop at this level is not nested in another loop from | ||||
1782 | // this level. | ||||
1783 | for (AffineForOp sibling : loops[i]) { | ||||
1784 | if (sibling->isProperAncestor(loop)) | ||||
1785 | return failure(); | ||||
1786 | } | ||||
1787 | } | ||||
1788 | } | ||||
1789 | |||||
1790 | return success(); | ||||
1791 | } | ||||
1792 | |||||
1793 | namespace mlir { | ||||
1794 | |||||
1795 | /// External utility to vectorize affine loops in 'loops' using the n-D | ||||
1796 | /// vectorization factors in 'vectorSizes'. By default, each vectorization | ||||
1797 | /// factor is applied inner-to-outer to the loops of each loop nest. | ||||
1798 | /// 'fastestVaryingPattern' can be optionally used to provide a different loop | ||||
1799 | /// vectorization order. | ||||
1800 | /// If `reductionLoops` is not empty, the given reduction loops may be | ||||
1801 | /// vectorized along the reduction dimension. | ||||
1802 | /// TODO: Vectorizing reductions is supported only for 1-D vectorization. | ||||
1803 | void vectorizeAffineLoops(Operation *parentOp, DenseSet<Operation *> &loops, | ||||
1804 | ArrayRef<int64_t> vectorSizes, | ||||
1805 | ArrayRef<int64_t> fastestVaryingPattern, | ||||
1806 | const ReductionLoopMap &reductionLoops) { | ||||
1807 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1808 | NestedPatternContext mlContext; | ||||
1809 | vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern, | ||||
1810 | reductionLoops); | ||||
1811 | } | ||||
1812 | |||||
1813 | /// External utility to vectorize affine loops from a single loop nest using an | ||||
1814 | /// n-D vectorization strategy (see doc in VectorizationStrategy definition). | ||||
1815 | /// Loops are provided in a 2D vector container. The first dimension represents | ||||
1816 | /// the nesting level relative to the loops to be vectorized. The second | ||||
1817 | /// dimension contains the loops. This means that: | ||||
1818 | /// a) every loop in 'loops[i]' must have a parent loop in 'loops[i-1]', | ||||
1819 | /// b) a loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'. | ||||
1820 | /// | ||||
1821 | /// For example, for the following loop nest: | ||||
1822 | /// | ||||
1823 | /// func @vec2d(%in0: memref<64x128x512xf32>, %in1: memref<64x128x128xf32>, | ||||
1824 | /// %out0: memref<64x128x512xf32>, | ||||
1825 | /// %out1: memref<64x128x128xf32>) { | ||||
1826 | /// affine.for %i0 = 0 to 64 { | ||||
1827 | /// affine.for %i1 = 0 to 128 { | ||||
1828 | /// affine.for %i2 = 0 to 512 { | ||||
1829 | /// %ld = affine.load %in0[%i0, %i1, %i2] : memref<64x128x512xf32> | ||||
1830 | /// affine.store %ld, %out0[%i0, %i1, %i2] : memref<64x128x512xf32> | ||||
1831 | /// } | ||||
1832 | /// affine.for %i3 = 0 to 128 { | ||||
1833 | /// %ld = affine.load %in1[%i0, %i1, %i3] : memref<64x128x128xf32> | ||||
1834 | /// affine.store %ld, %out1[%i0, %i1, %i3] : memref<64x128x128xf32> | ||||
1835 | /// } | ||||
1836 | /// } | ||||
1837 | /// } | ||||
1838 | /// return | ||||
1839 | /// } | ||||
1840 | /// | ||||
1841 | /// loops = {{%i0}, {%i2, %i3}}, to vectorize the outermost and the two | ||||
1842 | /// innermost loops; | ||||
1843 | /// loops = {{%i1}, {%i2, %i3}}, to vectorize the middle and the two innermost | ||||
1844 | /// loops; | ||||
1845 | /// loops = {{%i2}}, to vectorize only the first innermost loop; | ||||
1846 | /// loops = {{%i3}}, to vectorize only the second innermost loop; | ||||
1847 | /// loops = {{%i1}}, to vectorize only the middle loop. | ||||
1848 | LogicalResult | ||||
1849 | vectorizeAffineLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops, | ||||
1850 | const VectorizationStrategy &strategy) { | ||||
1851 | // Thread-safe RAII local context, BumpPtrAllocator freed on exit. | ||||
1852 | NestedPatternContext mlContext; | ||||
1853 | if (failed(verifyLoopNesting(loops))) | ||||
1854 | return failure(); | ||||
1855 | return vectorizeLoopNest(loops, strategy); | ||||
1856 | } | ||||
1857 | |||||
1858 | std::unique_ptr<OperationPass<func::FuncOp>> | ||||
1859 | createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) { | ||||
1860 | return std::make_unique<Vectorize>(virtualVectorSize); | ||||
1861 | } | ||||
1862 | std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() { | ||||
1863 | return std::make_unique<Vectorize>(); | ||||
1864 | } | ||||
1865 | |||||
1866 | } // namespace mlir |
1 | //===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===// | |||
2 | // | |||
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | |||
4 | // See https://llvm.org/LICENSE.txt for license information. | |||
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | |||
6 | // | |||
7 | //===----------------------------------------------------------------------===// | |||
8 | // | |||
9 | // This file implements helper classes for implementing the "Op" types. This | |||
10 | // includes the Op type, which is the base class for Op class definitions, | |||
11 | // as well as number of traits in the OpTrait namespace that provide a | |||
12 | // declarative way to specify properties of Ops. | |||
13 | // | |||
14 | // The purpose of these types are to allow light-weight implementation of | |||
15 | // concrete ops (like DimOp) with very little boilerplate. | |||
16 | // | |||
17 | //===----------------------------------------------------------------------===// | |||
18 | ||||
19 | #ifndef MLIR_IR_OPDEFINITION_H | |||
20 | #define MLIR_IR_OPDEFINITION_H | |||
21 | ||||
22 | #include "mlir/IR/Dialect.h" | |||
23 | #include "mlir/IR/Operation.h" | |||
24 | #include "llvm/Support/PointerLikeTypeTraits.h" | |||
25 | ||||
26 | #include <type_traits> | |||
27 | ||||
28 | namespace mlir { | |||
29 | class Builder; | |||
30 | class OpBuilder; | |||
31 | ||||
/// This class implements `Optional` functionality for ParseResult. We don't
/// directly use Optional here, because it provides an implicit conversion
/// to 'bool' which we want to avoid. This class is used to implement tri-state
/// 'parseOptional' functions that may have a failure mode when parsing that
/// shouldn't be attributed to "not present".
class OptionalParseResult {
public:
  /// Construct in the "not present" state (no contained value).
  OptionalParseResult() = default;
  /// Wrap a success/failure result as a present value.
  OptionalParseResult(LogicalResult result) : impl(result) {}
  OptionalParseResult(ParseResult result) : impl(result) {}
  /// An in-flight diagnostic always denotes a present failure.
  OptionalParseResult(const InFlightDiagnostic &)
      : OptionalParseResult(failure()) {}
  /// `llvm::None` constructs the "not present" state.
  OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}

  /// Returns true if we contain a valid ParseResult value.
  bool hasValue() const { return impl.hasValue(); }

  /// Access the internal ParseResult value.
  ParseResult getValue() const { return impl.getValue(); }
  ParseResult operator*() const { return getValue(); }

private:
  // Empty means "not present"; otherwise holds the parse success/failure.
  Optional<ParseResult> impl;
};
56 | ||||
// These functions are out-of-line utilities, which avoids them being template
// instantiated/duplicated.
namespace impl {
/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
/// region's only block if it does not have a terminator already. If the region
/// is empty, insert a new block first. `buildTerminatorOp` should return the
/// terminator operation to insert.
void ensureRegionTerminator(
    Region &region, OpBuilder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
/// Overload of the above that takes a plain `Builder` rather than an
/// `OpBuilder`.
void ensureRegionTerminator(
    Region &region, Builder &builder, Location loc,
    function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);

} // namespace impl
72 | ||||
73 | /// This is the concrete base class that holds the operation pointer and has | |||
74 | /// non-generic methods that only depend on State (to avoid having them | |||
75 | /// instantiated on template types that don't affect them. | |||
76 | /// | |||
77 | /// This also has the fallback implementations of customization hooks for when | |||
78 | /// they aren't customized. | |||
79 | class OpState { | |||
80 | public: | |||
81 | /// Ops are pointer-like, so we allow conversion to bool. | |||
82 | explicit operator bool() { return getOperation() != nullptr; } | |||
83 | ||||
84 | /// This implicitly converts to Operation*. | |||
85 | operator Operation *() const { return state; } | |||
86 | ||||
87 | /// Shortcut of `->` to access a member of Operation. | |||
88 | Operation *operator->() const { return state; } | |||
89 | ||||
90 | /// Return the operation that this refers to. | |||
91 | Operation *getOperation() { return state; } | |||
92 | ||||
93 | /// Return the context this operation belongs to. | |||
94 | MLIRContext *getContext() { return getOperation()->getContext(); } | |||
95 | ||||
96 | /// Print the operation to the given stream. | |||
97 | void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) { | |||
98 | state->print(os, flags); | |||
| ||||
99 | } | |||
100 | void print(raw_ostream &os, AsmState &asmState) { | |||
101 | state->print(os, asmState); | |||
102 | } | |||
103 | ||||
104 | /// Dump this operation. | |||
105 | void dump() { state->dump(); } | |||
106 | ||||
107 | /// The source location the operation was defined or derived from. | |||
108 | Location getLoc() { return state->getLoc(); } | |||
109 | ||||
110 | /// Return true if there are no users of any results of this operation. | |||
111 | bool use_empty() { return state->use_empty(); } | |||
112 | ||||
113 | /// Remove this operation from its parent block and delete it. | |||
114 | void erase() { state->erase(); } | |||
115 | ||||
116 | /// Emit an error with the op name prefixed, like "'dim' op " which is | |||
117 | /// convenient for verifiers. | |||
118 | InFlightDiagnostic emitOpError(const Twine &message = {}); | |||
119 | ||||
120 | /// Emit an error about fatal conditions with this operation, reporting up to | |||
121 | /// any diagnostic handlers that may be listening. | |||
122 | InFlightDiagnostic emitError(const Twine &message = {}); | |||
123 | ||||
124 | /// Emit a warning about this operation, reporting up to any diagnostic | |||
125 | /// handlers that may be listening. | |||
126 | InFlightDiagnostic emitWarning(const Twine &message = {}); | |||
127 | ||||
128 | /// Emit a remark about this operation, reporting up to any diagnostic | |||
129 | /// handlers that may be listening. | |||
130 | InFlightDiagnostic emitRemark(const Twine &message = {}); | |||
131 | ||||
132 | /// Walk the operation by calling the callback for each nested operation | |||
133 | /// (including this one), block or region, depending on the callback provided. | |||
134 | /// Regions, blocks and operations at the same nesting level are visited in | |||
135 | /// lexicographical order. The walk order for enclosing regions, blocks and | |||
136 | /// operations with respect to their nested ones is specified by 'Order' | |||
137 | /// (post-order by default). A callback on a block or operation is allowed to | |||
138 | /// erase that block or operation if either: | |||
139 | /// * the walk is in post-order, or | |||
140 | /// * the walk is in pre-order and the walk is skipped after the erasure. | |||
141 | /// See Operation::walk for more details. | |||
142 | template <WalkOrder Order = WalkOrder::PostOrder, typename FnT, | |||
143 | typename RetT = detail::walkResultType<FnT>> | |||
144 | typename std::enable_if< | |||
145 | llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type | |||
146 | walk(FnT &&callback) { | |||
147 | return state->walk<Order>(std::forward<FnT>(callback)); | |||
148 | } | |||
149 | ||||
150 | /// Generic walker with a stage aware callback. Walk the operation by calling | |||
151 | /// the callback for each nested operation (including this one) N+1 times, | |||
152 | /// where N is the number of regions attached to that operation. | |||
153 | /// | |||
154 | /// The callback method can take any of the following forms: | |||
155 | /// void(Operation *, const WalkStage &) : Walk all operation opaquely | |||
156 | /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...}); | |||
157 | /// void(OpT, const WalkStage &) : Walk all operations of the given derived | |||
158 | /// type. | |||
159 | /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...}); | |||
160 | /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations, | |||
161 | /// but allow for interruption/skipping. | |||
162 | /// * op.walk([](... op, const WalkStage &stage) { | |||
163 | /// // Skip the walk of this op based on some invariant. | |||
164 | /// if (some_invariant) | |||
165 | /// return WalkResult::skip(); | |||
166 | /// // Interrupt, i.e cancel, the walk based on some invariant. | |||
167 | /// if (another_invariant) | |||
168 | /// return WalkResult::interrupt(); | |||
169 | /// return WalkResult::advance(); | |||
170 | /// }); | |||
171 | template <typename FnT, typename RetT = detail::walkResultType<FnT>> | |||
172 | typename std::enable_if< | |||
173 | llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type | |||
174 | walk(FnT &&callback) { | |||
175 | return state->walk(std::forward<FnT>(callback)); | |||
176 | } | |||
177 | ||||
178 | // These are default implementations of customization hooks. | |||
179 | public: | |||
180 | /// This hook returns any canonicalization pattern rewrites that the operation | |||
181 | /// supports, for use by the canonicalization pass. | |||
182 | static void getCanonicalizationPatterns(RewritePatternSet &results, | |||
183 | MLIRContext *context) {} | |||
184 | ||||
185 | protected: | |||
186 | /// If the concrete type didn't implement a custom verifier hook, just fall | |||
187 | /// back to this one which accepts everything. | |||
188 | LogicalResult verify() { return success(); } | |||
189 | LogicalResult verifyRegions() { return success(); } | |||
190 | ||||
191 | /// Parse the custom form of an operation. Unless overridden, this method will | |||
192 | /// first try to get an operation parser from the op's dialect. Otherwise the | |||
193 | /// custom assembly form of an op is always rejected. Op implementations | |||
194 | /// should implement this to return failure. On success, they should fill in | |||
195 | /// result with the fields to use. | |||
196 | static ParseResult parse(OpAsmParser &parser, OperationState &result); | |||
197 | ||||
198 | /// Print the operation. Unless overridden, this method will first try to get | |||
199 | /// an operation printer from the dialect. Otherwise, it prints the operation | |||
200 | /// in generic form. | |||
201 | static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect); | |||
202 | ||||
203 | /// Print an operation name, eliding the dialect prefix if necessary. | |||
204 | static void printOpName(Operation *op, OpAsmPrinter &p, | |||
205 | StringRef defaultDialect); | |||
206 | ||||
207 | /// Mutability management is handled by the OpWrapper/OpConstWrapper classes, | |||
208 | /// so we can cast it away here. | |||
209 | explicit OpState(Operation *state) : state(state) {} | |||
210 | ||||
211 | private: | |||
212 | Operation *state; | |||
213 | ||||
214 | /// Allow access to internal hook implementation methods. | |||
215 | friend RegisteredOperationName; | |||
216 | }; | |||
217 | ||||
218 | // Allow comparing operators. | |||
219 | inline bool operator==(OpState lhs, OpState rhs) { | |||
220 | return lhs.getOperation() == rhs.getOperation(); | |||
221 | } | |||
222 | inline bool operator!=(OpState lhs, OpState rhs) { | |||
223 | return lhs.getOperation() != rhs.getOperation(); | |||
224 | } | |||
225 | ||||
226 | raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr); | |||
227 | ||||
/// This class represents a single result from folding an operation.
// It is a pointer union holding either a static result (Attribute) or a
// dynamic result (Value).
class OpFoldResult : public PointerUnion<Attribute, Value> {
  using PointerUnion<Attribute, Value>::PointerUnion;

public:
  /// Print the contained value to stderr, followed by a newline.
  void dump() { llvm::errs() << *this << "\n"; }
};
235 | ||||
236 | /// Allow printing to a stream. | |||
237 | inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) { | |||
238 | if (Value value = ofr.dyn_cast<Value>()) | |||
239 | value.print(os); | |||
240 | else | |||
241 | ofr.dyn_cast<Attribute>().print(os); | |||
242 | return os; | |||
243 | } | |||
244 | ||||
245 | /// Allow printing to a stream. | |||
246 | inline raw_ostream &operator<<(raw_ostream &os, OpState op) { | |||
247 | op.print(os, OpPrintingFlags().useLocalScope()); | |||
248 | return os; | |||
249 | } | |||
250 | ||||
251 | //===----------------------------------------------------------------------===// | |||
252 | // Operation Trait Types | |||
253 | //===----------------------------------------------------------------------===// | |||
254 | ||||
255 | namespace OpTrait { | |||
256 | ||||
// These functions are out-of-line implementations of the methods in the
// corresponding trait classes. This avoids them being template
// instantiated/duplicated. Each `verify*` hook below is invoked from the
// `verifyTrait` of the similarly named trait class later in this file.
namespace impl {
// Out-of-line fold implementations.
OpFoldResult foldIdempotent(Operation *op);
OpFoldResult foldInvolution(Operation *op);
// Out-of-line verification implementations.
LogicalResult verifyZeroOperands(Operation *op);
LogicalResult verifyOneOperand(Operation *op);
LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyIsIdempotent(Operation *op);
LogicalResult verifyIsInvolution(Operation *op);
LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
LogicalResult verifyOperandsAreFloatLike(Operation *op);
LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
LogicalResult verifySameTypeOperands(Operation *op);
LogicalResult verifyZeroRegion(Operation *op);
LogicalResult verifyOneRegion(Operation *op);
LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
LogicalResult verifyZeroResult(Operation *op);
LogicalResult verifyOneResult(Operation *op);
LogicalResult verifyNResults(Operation *op, unsigned numOperands);
LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
LogicalResult verifySameOperandsShape(Operation *op);
LogicalResult verifySameOperandsAndResultShape(Operation *op);
LogicalResult verifySameOperandsElementType(Operation *op);
LogicalResult verifySameOperandsAndResultElementType(Operation *op);
LogicalResult verifySameOperandsAndResultType(Operation *op);
LogicalResult verifyResultsAreBoolLike(Operation *op);
LogicalResult verifyResultsAreFloatLike(Operation *op);
LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
LogicalResult verifyIsTerminator(Operation *op);
LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
                                  StringRef valueGroupName,
                                  size_t expectedCount);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);
LogicalResult verifyElementwise(Operation *op);
LogicalResult verifyIsIsolatedFromAbove(Operation *op);
} // namespace impl
302 | ||||
303 | /// Helper class for implementing traits. Clients are not expected to interact | |||
304 | /// with this directly, so its members are all protected. | |||
305 | template <typename ConcreteType, template <typename> class TraitType> | |||
306 | class TraitBase { | |||
307 | protected: | |||
308 | /// Return the ultimate Operation being worked on. | |||
309 | Operation *getOperation() { | |||
310 | auto *concrete = static_cast<ConcreteType *>(this); | |||
311 | return concrete->getOperation(); | |||
312 | } | |||
313 | }; | |||
314 | ||||
315 | //===----------------------------------------------------------------------===// | |||
316 | // Operand Traits | |||
317 | ||||
318 | namespace detail { | |||
319 | /// Utility trait base that provides accessors for derived traits that have | |||
320 | /// multiple operands. | |||
321 | template <typename ConcreteType, template <typename> class TraitType> | |||
322 | struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> { | |||
323 | using operand_iterator = Operation::operand_iterator; | |||
324 | using operand_range = Operation::operand_range; | |||
325 | using operand_type_iterator = Operation::operand_type_iterator; | |||
326 | using operand_type_range = Operation::operand_type_range; | |||
327 | ||||
328 | /// Return the number of operands. | |||
329 | unsigned getNumOperands() { return this->getOperation()->getNumOperands(); } | |||
330 | ||||
331 | /// Return the operand at index 'i'. | |||
332 | Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); } | |||
333 | ||||
334 | /// Set the operand at index 'i' to 'value'. | |||
335 | void setOperand(unsigned i, Value value) { | |||
336 | this->getOperation()->setOperand(i, value); | |||
337 | } | |||
338 | ||||
339 | /// Operand iterator access. | |||
340 | operand_iterator operand_begin() { | |||
341 | return this->getOperation()->operand_begin(); | |||
342 | } | |||
343 | operand_iterator operand_end() { return this->getOperation()->operand_end(); } | |||
344 | operand_range getOperands() { return this->getOperation()->getOperands(); } | |||
345 | ||||
346 | /// Operand type access. | |||
347 | operand_type_iterator operand_type_begin() { | |||
348 | return this->getOperation()->operand_type_begin(); | |||
349 | } | |||
350 | operand_type_iterator operand_type_end() { | |||
351 | return this->getOperation()->operand_type_end(); | |||
352 | } | |||
353 | operand_type_range getOperandTypes() { | |||
354 | return this->getOperation()->getOperandTypes(); | |||
355 | } | |||
356 | }; | |||
357 | } // namespace detail | |||
358 | ||||
/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
/// It should be run after core traits and before any other user defined traits.
/// In order to run it in the correct order, wrap it with OpInvariants trait so
/// that tblgen will be able to put it in the right order.
template <typename ConcreteType>
class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
public:
  /// Dispatch to the concrete op's `verifyInvariantsImpl()`.
  static LogicalResult verifyTrait(Operation *op) {
    return cast<ConcreteType>(op).verifyInvariantsImpl();
  }
};
370 | ||||
/// This class provides the API for ops that are known to have no
/// SSA operand.
template <typename ConcreteType>
class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
public:
  /// Verify that the operation has zero operands.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroOperands(op);
  }

private:
  // Disable these: there are no operands to get or set.
  void getOperand() {}
  void setOperand() {}
};
385 | ||||
/// This class provides the API for ops that are known to have exactly one
/// SSA operand.
template <typename ConcreteType>
class OneOperand : public TraitBase<ConcreteType, OneOperand> {
public:
  /// Return the single operand.
  Value getOperand() { return this->getOperation()->getOperand(0); }

  /// Replace the single operand with 'value'.
  void setOperand(Value value) { this->getOperation()->setOperand(0, value); }

  /// Verify that the operation has exactly one operand.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneOperand(op);
  }
};
399 | ||||
/// This class provides the API for ops that are known to have a specified
/// number of operands. This is used as a trait like this:
///
///   class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
///
template <unsigned N>
class NOperands {
public:
  static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");

  template <typename ConcreteType>
  class Impl
      : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
  public:
    /// Verify that the operation has exactly N operands.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNOperands(op, N);
    }
  };
};
419 | ||||
/// This class provides the API for ops that are known to have at least a
/// specified number of operands. This is used as a trait like this:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
///
template <unsigned N>
class AtLeastNOperands {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiOperandTraitBase<ConcreteType,
                                                    AtLeastNOperands<N>::Impl> {
  public:
    /// Verify that the operation has N or more operands.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNOperands(op, N);
    }
  };
};
437 | ||||
/// This class provides the API for ops which have an unknown number of
/// SSA operands. All accessors are inherited from MultiOperandTraitBase;
/// no extra verification is performed.
template <typename ConcreteType>
class VariadicOperands
    : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
443 | ||||
444 | //===----------------------------------------------------------------------===// | |||
445 | // Region Traits | |||
446 | ||||
/// This class provides verification for ops that are known to have zero
/// regions.
template <typename ConcreteType>
class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
public:
  /// Verify that the operation has zero regions.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroRegion(op);
  }
};
456 | ||||
457 | namespace detail { | |||
458 | /// Utility trait base that provides accessors for derived traits that have | |||
459 | /// multiple regions. | |||
460 | template <typename ConcreteType, template <typename> class TraitType> | |||
461 | struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> { | |||
462 | using region_iterator = MutableArrayRef<Region>; | |||
463 | using region_range = RegionRange; | |||
464 | ||||
465 | /// Return the number of regions. | |||
466 | unsigned getNumRegions() { return this->getOperation()->getNumRegions(); } | |||
467 | ||||
468 | /// Return the region at `index`. | |||
469 | Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); } | |||
470 | ||||
471 | /// Region iterator access. | |||
472 | region_iterator region_begin() { | |||
473 | return this->getOperation()->region_begin(); | |||
474 | } | |||
475 | region_iterator region_end() { return this->getOperation()->region_end(); } | |||
476 | region_range getRegions() { return this->getOperation()->getRegions(); } | |||
477 | }; | |||
478 | } // namespace detail | |||
479 | ||||
/// This class provides APIs for ops that are known to have a single region.
template <typename ConcreteType>
class OneRegion : public TraitBase<ConcreteType, OneRegion> {
public:
  /// Return the single region attached to this operation.
  Region &getRegion() { return this->getOperation()->getRegion(0); }

  /// Returns a range of operations within the region of this operation.
  auto getOps() { return getRegion().getOps(); }
  /// Returns a range over only the operations of type `OpT` within the region.
  template <typename OpT>
  auto getOps() {
    return getRegion().template getOps<OpT>();
  }

  /// Verify that the operation has exactly one region.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneRegion(op);
  }
};
497 | ||||
/// This class provides the API for ops that are known to have a specified
/// number of regions.
template <unsigned N>
class NRegions {
public:
  static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");

  template <typename ConcreteType>
  class Impl
      : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
  public:
    /// Verify that the operation has exactly N regions.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNRegions(op, N);
    }
  };
};
514 | ||||
/// This class provides APIs for ops that are known to have at least a specified
/// number of regions.
template <unsigned N>
class AtLeastNRegions {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiRegionTraitBase<ConcreteType,
                                                   AtLeastNRegions<N>::Impl> {
  public:
    /// Verify that the operation has N or more regions.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNRegions(op, N);
    }
  };
};
529 | ||||
/// This class provides the API for ops which have an unknown number of
/// regions. All accessors are inherited from MultiRegionTraitBase; no extra
/// verification is performed.
template <typename ConcreteType>
class VariadicRegions
    : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
535 | ||||
536 | //===----------------------------------------------------------------------===// | |||
537 | // Result Traits | |||
538 | ||||
/// This class provides return value APIs for ops that are known to have
/// zero results.
template <typename ConcreteType>
class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
public:
  /// Verify that the operation has zero results.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroResult(op);
  }
};
548 | ||||
549 | namespace detail { | |||
550 | /// Utility trait base that provides accessors for derived traits that have | |||
551 | /// multiple results. | |||
552 | template <typename ConcreteType, template <typename> class TraitType> | |||
553 | struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> { | |||
554 | using result_iterator = Operation::result_iterator; | |||
555 | using result_range = Operation::result_range; | |||
556 | using result_type_iterator = Operation::result_type_iterator; | |||
557 | using result_type_range = Operation::result_type_range; | |||
558 | ||||
559 | /// Return the number of results. | |||
560 | unsigned getNumResults() { return this->getOperation()->getNumResults(); } | |||
561 | ||||
562 | /// Return the result at index 'i'. | |||
563 | Value getResult(unsigned i) { return this->getOperation()->getResult(i); } | |||
564 | ||||
565 | /// Replace all uses of results of this operation with the provided 'values'. | |||
566 | /// 'values' may correspond to an existing operation, or a range of 'Value'. | |||
567 | template <typename ValuesT> | |||
568 | void replaceAllUsesWith(ValuesT &&values) { | |||
569 | this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values)); | |||
570 | } | |||
571 | ||||
572 | /// Return the type of the `i`-th result. | |||
573 | Type getType(unsigned i) { return getResult(i).getType(); } | |||
574 | ||||
575 | /// Result iterator access. | |||
576 | result_iterator result_begin() { | |||
577 | return this->getOperation()->result_begin(); | |||
578 | } | |||
579 | result_iterator result_end() { return this->getOperation()->result_end(); } | |||
580 | result_range getResults() { return this->getOperation()->getResults(); } | |||
581 | ||||
582 | /// Result type access. | |||
583 | result_type_iterator result_type_begin() { | |||
584 | return this->getOperation()->result_type_begin(); | |||
585 | } | |||
586 | result_type_iterator result_type_end() { | |||
587 | return this->getOperation()->result_type_end(); | |||
588 | } | |||
589 | result_type_range getResultTypes() { | |||
590 | return this->getOperation()->getResultTypes(); | |||
591 | } | |||
592 | }; | |||
593 | } // namespace detail | |||
594 | ||||
/// This class provides return value APIs for ops that are known to have a
/// single result. Use OneTypedResult (before this trait in the trait list) to
/// additionally expose a typed getType().
template <typename ConcreteType>
class OneResult : public TraitBase<ConcreteType, OneResult> {
public:
  /// Return the single result of this operation.
  Value getResult() { return this->getOperation()->getResult(0); }

  /// If the operation returns a single value, then the Op can be implicitly
  /// converted to a Value. This yields the value of the only result.
  operator Value() { return getResult(); }

  /// Replace all uses of 'this' value with the new value, updating anything
  /// in the IR that uses 'this' to use the other value instead. When this
  /// returns there are zero uses of 'this'.
  void replaceAllUsesWith(Value newValue) {
    getResult().replaceAllUsesWith(newValue);
  }

  /// Replace all uses of 'this' value with the result of 'op'.
  void replaceAllUsesWith(Operation *op) {
    this->getOperation()->replaceAllUsesWith(op);
  }

  /// Verify that the operation has exactly one result.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyOneResult(op);
  }
};
622 | ||||
623 | /// This trait is used for return value APIs for ops that are known to have a | |||
624 | /// specific type other than `Type`. This allows the "getType()" member to be | |||
625 | /// more specific for an op. This should be used in conjunction with OneResult, | |||
626 | /// and occur in the trait list before OneResult. | |||
627 | template <typename ResultType> | |||
628 | class OneTypedResult { | |||
629 | public: | |||
630 | /// This class provides return value APIs for ops that are known to have a | |||
631 | /// single result. ResultType is the concrete type returned by getType(). | |||
632 | template <typename ConcreteType> | |||
633 | class Impl | |||
634 | : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> { | |||
635 | public: | |||
636 | ResultType getType() { | |||
637 | auto resultTy = this->getOperation()->getResult(0).getType(); | |||
638 | return resultTy.template cast<ResultType>(); | |||
639 | } | |||
640 | }; | |||
641 | }; | |||
642 | ||||
/// This class provides the API for ops that are known to have a specified
/// number of results. This is used as a trait like this:
///
///   class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
///
template <unsigned N>
class NResults {
public:
  static_assert(N > 1, "use ZeroResult/OneResult for N < 2");

  template <typename ConcreteType>
  class Impl
      : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
  public:
    /// Verify that the operation has exactly N results.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyNResults(op, N);
    }
  };
};
662 | ||||
/// This class provides the API for ops that are known to have at least a
/// specified number of results. This is used as a trait like this:
///
///   class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
///
template <unsigned N>
class AtLeastNResults {
public:
  template <typename ConcreteType>
  class Impl : public detail::MultiResultTraitBase<ConcreteType,
                                                   AtLeastNResults<N>::Impl> {
  public:
    /// Verify that the operation has N or more results.
    static LogicalResult verifyTrait(Operation *op) {
      return impl::verifyAtLeastNResults(op, N);
    }
  };
};
680 | ||||
/// This class provides the API for ops which have an unknown number of
/// results. All accessors are inherited from MultiResultTraitBase; no extra
/// verification is performed.
template <typename ConcreteType>
class VariadicResults
    : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
686 | ||||
687 | //===----------------------------------------------------------------------===// | |||
688 | // Terminator Traits | |||
689 | ||||
/// This class indicates that the regions associated with this op don't have
/// terminators. This is a pure marker trait: it adds no methods and no
/// verification.
template <typename ConcreteType>
class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
694 | ||||
/// This class provides the API for ops that are known to be terminators.
template <typename ConcreteType>
class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
public:
  /// Verify the structural requirements of a terminator operation.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyIsTerminator(op);
  }
};
703 | ||||
/// This class provides verification for ops that are known to have zero
/// successors.
template <typename ConcreteType>
class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
public:
  /// Verify that the operation has zero successors.
  static LogicalResult verifyTrait(Operation *op) {
    return impl::verifyZeroSuccessor(op);
  }
};
713 | ||||
namespace detail {
/// Utility trait base that provides accessors for derived traits that have
/// multiple successors. All accessors forward to the underlying Operation.
template <typename ConcreteType, template <typename> class TraitType>
struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
  using succ_iterator = Operation::succ_iterator;
  using succ_range = SuccessorRange;

  /// Return the number of successors.
  unsigned getNumSuccessors() {
    return this->getOperation()->getNumSuccessors();
  }

  /// Return the successor at `index`.
  Block *getSuccessor(unsigned i) {
    return this->getOperation()->getSuccessor(i);
  }

  /// Set the successor at `index`.
  void setSuccessor(Block *block, unsigned i) {
    return this->getOperation()->setSuccessor(block, i);
  }

  /// Successor iterator access.
  succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
  succ_iterator succ_end() { return this->getOperation()->succ_end(); }
  succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
};
} // namespace detail
743 | ||||
/// This class provides APIs for ops that are known to have a single successor.
template <typename ConcreteType>
class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
public:
  /// Return/replace the unique successor block (index 0).
  Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
  void setSuccessor(Block *succ) {
    this->getOperation()->setSuccessor(succ, 0);
  }

  static LogicalResult verifyTrait(Operation *op) {
    // Emits a diagnostic unless `op` has exactly one successor.
    return impl::verifyOneSuccessor(op);
  }
};
757 | ||||
/// This class provides the API for ops that are known to have a specified
/// number of successors.
template <unsigned N>
class NSuccessors {
public:
  static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");

  template <typename ConcreteType>
  class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
                                                      NSuccessors<N>::Impl> {
  public:
    static LogicalResult verifyTrait(Operation *op) {
      // Emits a diagnostic unless `op` has exactly N successors.
      return impl::verifyNSuccessors(op, N);
    }
  };
};
774 | ||||
/// This class provides APIs for ops that are known to have at least a specified
/// number of successors.
template <unsigned N>
class AtLeastNSuccessors {
public:
  template <typename ConcreteType>
  class Impl
      : public detail::MultiSuccessorTraitBase<ConcreteType,
                                               AtLeastNSuccessors<N>::Impl> {
  public:
    static LogicalResult verifyTrait(Operation *op) {
      // Emits a diagnostic if `op` has fewer than N successors.
      return impl::verifyAtLeastNSuccessors(op, N);
    }
  };
};
790 | ||||
/// This class provides the API for ops which have an unknown number of
/// successors. No extra verification is performed; all accessors come from
/// MultiSuccessorTraitBase.
template <typename ConcreteType>
class VariadicSuccessors
    : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
};
797 | ||||
798 | //===----------------------------------------------------------------------===// | |||
799 | // SingleBlock | |||
800 | ||||
801 | /// This class provides APIs and verifiers for ops with regions having a single | |||
802 | /// block. | |||
803 | template <typename ConcreteType> | |||
804 | struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> { | |||
805 | public: | |||
806 | static LogicalResult verifyTrait(Operation *op) { | |||
807 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { | |||
808 | Region ®ion = op->getRegion(i); | |||
809 | ||||
810 | // Empty regions are fine. | |||
811 | if (region.empty()) | |||
812 | continue; | |||
813 | ||||
814 | // Non-empty regions must contain a single basic block. | |||
815 | if (!llvm::hasSingleElement(region)) | |||
816 | return op->emitOpError("expects region #") | |||
817 | << i << " to have 0 or 1 blocks"; | |||
818 | ||||
819 | if (!ConcreteType::template hasTrait<NoTerminator>()) { | |||
820 | Block &block = region.front(); | |||
821 | if (block.empty()) | |||
822 | return op->emitOpError() << "expects a non-empty block"; | |||
823 | } | |||
824 | } | |||
825 | return success(); | |||
826 | } | |||
827 | ||||
828 | Block *getBody(unsigned idx = 0) { | |||
829 | Region ®ion = this->getOperation()->getRegion(idx); | |||
830 | assert(!region.empty() && "unexpected empty region")(static_cast <bool> (!region.empty() && "unexpected empty region" ) ? void (0) : __assert_fail ("!region.empty() && \"unexpected empty region\"" , "mlir/include/mlir/IR/OpDefinition.h", 830, __extension__ __PRETTY_FUNCTION__ )); | |||
831 | return ®ion.front(); | |||
832 | } | |||
833 | Region &getBodyRegion(unsigned idx = 0) { | |||
834 | return this->getOperation()->getRegion(idx); | |||
835 | } | |||
836 | ||||
837 | //===------------------------------------------------------------------===// | |||
838 | // Single Region Utilities | |||
839 | //===------------------------------------------------------------------===// | |||
840 | ||||
841 | /// The following are a set of methods only enabled when the parent | |||
842 | /// operation has a single region. Each of these methods take an additional | |||
843 | /// template parameter that represents the concrete operation so that we | |||
844 | /// can use SFINAE to disable the methods for non-single region operations. | |||
845 | template <typename OpT, typename T = void> | |||
846 | using enable_if_single_region = | |||
847 | typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; | |||
848 | ||||
849 | template <typename OpT = ConcreteType> | |||
850 | enable_if_single_region<OpT, Block::iterator> begin() { | |||
851 | return getBody()->begin(); | |||
852 | } | |||
853 | template <typename OpT = ConcreteType> | |||
854 | enable_if_single_region<OpT, Block::iterator> end() { | |||
855 | return getBody()->end(); | |||
856 | } | |||
857 | template <typename OpT = ConcreteType> | |||
858 | enable_if_single_region<OpT, Operation &> front() { | |||
859 | return *begin(); | |||
860 | } | |||
861 | ||||
862 | /// Insert the operation into the back of the body. | |||
863 | template <typename OpT = ConcreteType> | |||
864 | enable_if_single_region<OpT> push_back(Operation *op) { | |||
865 | insert(Block::iterator(getBody()->end()), op); | |||
866 | } | |||
867 | ||||
868 | /// Insert the operation at the given insertion point. | |||
869 | template <typename OpT = ConcreteType> | |||
870 | enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { | |||
871 | insert(Block::iterator(insertPt), op); | |||
872 | } | |||
873 | template <typename OpT = ConcreteType> | |||
874 | enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) { | |||
875 | getBody()->getOperations().insert(insertPt, op); | |||
876 | } | |||
877 | }; | |||
878 | ||||
879 | //===----------------------------------------------------------------------===// | |||
880 | // SingleBlockImplicitTerminator | |||
881 | ||||
882 | /// This class provides APIs and verifiers for ops with regions having a single | |||
883 | /// block that must terminate with `TerminatorOpType`. | |||
884 | template <typename TerminatorOpType> | |||
885 | struct SingleBlockImplicitTerminator { | |||
886 | template <typename ConcreteType> | |||
887 | class Impl : public SingleBlock<ConcreteType> { | |||
888 | private: | |||
889 | using Base = SingleBlock<ConcreteType>; | |||
890 | /// Builds a terminator operation without relying on OpBuilder APIs to avoid | |||
891 | /// cyclic header inclusion. | |||
892 | static Operation *buildTerminator(OpBuilder &builder, Location loc) { | |||
893 | OperationState state(loc, TerminatorOpType::getOperationName()); | |||
894 | TerminatorOpType::build(builder, state); | |||
895 | return Operation::create(state); | |||
896 | } | |||
897 | ||||
898 | public: | |||
899 | /// The type of the operation used as the implicit terminator type. | |||
900 | using ImplicitTerminatorOpT = TerminatorOpType; | |||
901 | ||||
902 | static LogicalResult verifyRegionTrait(Operation *op) { | |||
903 | if (failed(Base::verifyTrait(op))) | |||
904 | return failure(); | |||
905 | for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) { | |||
906 | Region ®ion = op->getRegion(i); | |||
907 | // Empty regions are fine. | |||
908 | if (region.empty()) | |||
909 | continue; | |||
910 | Operation &terminator = region.front().back(); | |||
911 | if (isa<TerminatorOpType>(terminator)) | |||
912 | continue; | |||
913 | ||||
914 | return op->emitOpError("expects regions to end with '" + | |||
915 | TerminatorOpType::getOperationName() + | |||
916 | "', found '" + | |||
917 | terminator.getName().getStringRef() + "'") | |||
918 | .attachNote() | |||
919 | << "in custom textual format, the absence of terminator implies " | |||
920 | "'" | |||
921 | << TerminatorOpType::getOperationName() << '\''; | |||
922 | } | |||
923 | ||||
924 | return success(); | |||
925 | } | |||
926 | ||||
927 | /// Ensure that the given region has the terminator required by this trait. | |||
928 | /// If OpBuilder is provided, use it to build the terminator and notify the | |||
929 | /// OpBuilder listeners accordingly. If only a Builder is provided, locally | |||
930 | /// construct an OpBuilder with no listeners; this should only be used if no | |||
931 | /// OpBuilder is available at the call site, e.g., in the parser. | |||
932 | static void ensureTerminator(Region ®ion, Builder &builder, | |||
933 | Location loc) { | |||
934 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, | |||
935 | buildTerminator); | |||
936 | } | |||
937 | static void ensureTerminator(Region ®ion, OpBuilder &builder, | |||
938 | Location loc) { | |||
939 | ::mlir::impl::ensureRegionTerminator(region, builder, loc, | |||
940 | buildTerminator); | |||
941 | } | |||
942 | ||||
943 | //===------------------------------------------------------------------===// | |||
944 | // Single Region Utilities | |||
945 | //===------------------------------------------------------------------===// | |||
946 | using Base::getBody; | |||
947 | ||||
948 | template <typename OpT, typename T = void> | |||
949 | using enable_if_single_region = | |||
950 | typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>; | |||
951 | ||||
952 | /// Insert the operation into the back of the body, before the terminator. | |||
953 | template <typename OpT = ConcreteType> | |||
954 | enable_if_single_region<OpT> push_back(Operation *op) { | |||
955 | insert(Block::iterator(getBody()->getTerminator()), op); | |||
956 | } | |||
957 | ||||
958 | /// Insert the operation at the given insertion point. Note: The operation | |||
959 | /// is never inserted after the terminator, even if the insertion point is | |||
960 | /// end(). | |||
961 | template <typename OpT = ConcreteType> | |||
962 | enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) { | |||
963 | insert(Block::iterator(insertPt), op); | |||
964 | } | |||
965 | template <typename OpT = ConcreteType> | |||
966 | enable_if_single_region<OpT> insert(Block::iterator insertPt, | |||
967 | Operation *op) { | |||
968 | auto *body = getBody(); | |||
969 | if (insertPt == body->end()) | |||
970 | insertPt = Block::iterator(body->getTerminator()); | |||
971 | body->getOperations().insert(insertPt, op); | |||
972 | } | |||
973 | }; | |||
974 | }; | |||
975 | ||||
/// Checks if an op defines the `ImplicitTerminatorOpT` member. This is
/// intended to be used with `llvm::is_detected`.
template <class T>
using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
980 | ||||
/// Support to check if an operation has the SingleBlockImplicitTerminator
/// trait. We can't just use `hasTrait` because this class is templated on a
/// specific terminator op.
template <class Op, bool hasTerminator =
                        llvm::is_detected<has_implicit_terminator_t, Op>::value>
struct hasSingleBlockImplicitTerminator {
  // True iff Op derives from the SingleBlockImplicitTerminator impl
  // instantiated with Op's own declared terminator type.
  static constexpr bool value = std::is_base_of<
      typename OpTrait::SingleBlockImplicitTerminator<
          typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
      Op>::value;
};
/// Partial specialization for ops that don't declare an implicit terminator
/// type: trivially false.
template <class Op>
struct hasSingleBlockImplicitTerminator<Op, false> {
  static constexpr bool value = false;
};
996 | ||||
997 | //===----------------------------------------------------------------------===// | |||
998 | // Misc Traits | |||
999 | ||||
/// This class provides verification for ops that are known to have the same
/// operand shape: all operands are scalars, vectors/tensors of the same
/// shape.
template <typename ConcreteType>
class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates shape comparison to the shared implementation.
    return impl::verifySameOperandsShape(op);
  }
};
1010 | ||||
/// This class provides verification for ops that are known to have the same
/// operand and result shape: both are scalars, vectors/tensors of the same
/// shape.
template <typename ConcreteType>
class SameOperandsAndResultShape
    : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates shape comparison to the shared implementation.
    return impl::verifySameOperandsAndResultShape(op);
  }
};
1022 | ||||
/// This class provides verification for ops that are known to have the same
/// operand element type (or the type itself if it is scalar).
///
template <typename ConcreteType>
class SameOperandsElementType
    : public TraitBase<ConcreteType, SameOperandsElementType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates element-type comparison to the shared implementation.
    return impl::verifySameOperandsElementType(op);
  }
};
1034 | ||||
/// This class provides verification for ops that are known to have the same
/// operand and result element type (or the type itself if it is scalar).
///
template <typename ConcreteType>
class SameOperandsAndResultElementType
    : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates element-type comparison to the shared implementation.
    return impl::verifySameOperandsAndResultElementType(op);
  }
};
1046 | ||||
/// This class provides verification for ops that are known to have the same
/// operand and result type.
///
/// Note: this trait subsumes the SameOperandsAndResultShape and
/// SameOperandsAndResultElementType traits.
template <typename ConcreteType>
class SameOperandsAndResultType
    : public TraitBase<ConcreteType, SameOperandsAndResultType> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates full-type comparison to the shared implementation.
    return impl::verifySameOperandsAndResultType(op);
  }
};
1060 | ||||
/// This class verifies that any results of the specified op have a boolean
/// type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates the bool-like check to the shared implementation.
    return impl::verifyResultsAreBoolLike(op);
  }
};
1070 | ||||
/// This class verifies that any results of the specified op have a floating
/// point type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class ResultsAreFloatLike
    : public TraitBase<ConcreteType, ResultsAreFloatLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates the float-like check to the shared implementation.
    return impl::verifyResultsAreFloatLike(op);
  }
};
1081 | ||||
/// This class verifies that any results of the specified op have a signless
/// integer or index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class ResultsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates the signless-integer-like check to the shared implementation.
    return impl::verifyResultsAreSignlessIntegerLike(op);
  }
};
1092 | ||||
/// This class adds property that the operation is commutative. Marker trait
/// only: no verification is performed here.
template <typename ConcreteType>
class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1096 | ||||
/// This class adds property that the operation is an involution.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
template <typename ConcreteType>
class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Compile-time prerequisites: involution only makes sense for a unary,
    // type-preserving op with a single result.
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>(),
                  "expected operation to take one operand");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // Involution requires the operation to be side effect free as well
    // but currently this check is under a FIXME and is not actually done.
    return impl::verifyIsInvolution(op);
  }

  /// Fold f(f(x)) to x when the operand is produced by the same op kind.
  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
    return impl::foldInvolution(op);
  }
};
1118 | ||||
/// This class adds property that the operation is idempotent.
/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
/// or a binary operation "g" that satisfies g(x, x) = x.
template <typename ConcreteType>
class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Compile-time prerequisites: a single result, one or two operands, and
    // type preservation.
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<OneOperand>() ||
                      ConcreteType::template hasTrait<NOperands<2>::Impl>(),
                  "expected operation to take one or two operands");
    static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
                  "expected operation to preserve type");
    // Idempotent requires the operation to be side effect free as well
    // but currently this check is under a FIXME and is not actually done.
    return impl::verifyIsIdempotent(op);
  }

  /// Fold f(f(x)) to f(x), or g(x, x) to x, via the shared implementation.
  static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
    return impl::foldIdempotent(op);
  }
};
1142 | ||||
/// This class verifies that all operands of the specified op have a float type,
/// a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class OperandsAreFloatLike
    : public TraitBase<ConcreteType, OperandsAreFloatLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates the float-like check to the shared implementation.
    return impl::verifyOperandsAreFloatLike(op);
  }
};
1153 | ||||
/// This class verifies that all operands of the specified op have a signless
/// integer or index type, a vector thereof, or a tensor thereof.
template <typename ConcreteType>
class OperandsAreSignlessIntegerLike
    : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates the signless-integer-like check to the shared implementation.
    return impl::verifyOperandsAreSignlessIntegerLike(op);
  }
};
1164 | ||||
/// This class verifies that all operands of the specified op have the same
/// type.
template <typename ConcreteType>
class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates operand type comparison to the shared implementation.
    return impl::verifySameTypeOperands(op);
  }
};
1174 | ||||
/// This class provides the API for a sub-set of ops that are known to be
/// constant-like. These are non-side effecting operations with one result and
/// zero operands that can always be folded to a specific attribute value.
template <typename ConcreteType>
class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // Compile-time prerequisites: exactly one result, no operands.
    static_assert(ConcreteType::template hasTrait<OneResult>(),
                  "expected operation to produce one result");
    static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
                  "expected operation to take zero operands");
    // TODO: We should verify that the operation can always be folded, but this
    // requires that the attributes of the op already be verified. We should add
    // support for verifying traits "after" the operation to enable this use
    // case.
    return success();
  }
};
1193 | ||||
/// This class provides the API for ops that are known to be isolated from
/// above.
template <typename ConcreteType>
class IsIsolatedFromAbove
    : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
public:
  // Note: a region trait verifier, run after regions have been verified.
  static LogicalResult verifyRegionTrait(Operation *op) {
    return impl::verifyIsIsolatedFromAbove(op);
  }
};
1204 | ||||
/// A trait of region holding operations that defines a new scope for polyhedral
/// optimization purposes. Any SSA values of 'index' type that either dominate
/// such an operation or are used at the top-level of such an operation
/// automatically become valid symbols for the polyhedral scope defined by that
/// operation. For more details, see `Traits.md#AffineScope`.
template <typename ConcreteType>
class AffineScope : public TraitBase<ConcreteType, AffineScope> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // A scope is meaningless without a region to scope over.
    static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
                  "expected operation to have one or more regions");
    return success();
  }
};
1219 | ||||
/// A trait of region holding operations that define a new scope for automatic
/// allocations, i.e., allocations that are freed when control is transferred
/// back from the operation's region. Any operations performing such allocations
/// (for eg. memref.alloca) will have their allocations automatically freed at
/// their closest enclosing operation with this trait.
template <typename ConcreteType>
class AutomaticAllocationScope
    : public TraitBase<ConcreteType, AutomaticAllocationScope> {
public:
  static LogicalResult verifyTrait(Operation *op) {
    // A scope is meaningless without a region to scope over.
    static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
                  "expected operation to have one or more regions");
    return success();
  }
};
1235 | ||||
/// This class provides a verifier for ops that are expecting their parent
/// to be one of the given parent ops
template <typename... ParentOpTypes>
struct HasParent {
  template <typename ConcreteType>
  class Impl : public TraitBase<ConcreteType, Impl> {
  public:
    static LogicalResult verifyTrait(Operation *op) {
      // Accept any of the listed parent op types (parent may be null at the
      // top level, hence isa_and_nonnull).
      if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
        return success();

      // Pluralize the message when more than one candidate parent was given.
      return op->emitOpError()
             << "expects parent op "
             << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
             << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
             << "'";
    }
  };
};
1255 | ||||
/// A trait for operations that have an attribute specifying operand segments.
///
/// Certain operations can have multiple variadic operands and their size
/// relationship is not always known statically. For such cases, we need
/// a per-op-instance specification to divide the operands into logical groups
/// or segments. This can be modeled by attributes. The attribute will be named
/// as `operand_segment_sizes`.
///
/// This trait verifies the attribute for specifying operand segments has
/// the correct type (1D vector) and values (non-negative), etc.
template <typename ConcreteType>
class AttrSizedOperandSegments
    : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
public:
  /// Name of the attribute holding the per-instance segment sizes.
  static StringRef getOperandSegmentSizeAttr() {
    return "operand_segment_sizes";
  }

  static LogicalResult verifyTrait(Operation *op) {
    return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
        op, getOperandSegmentSizeAttr());
  }
};
1279 | ||||
/// Similar to AttrSizedOperandSegments but used for results. The attribute is
/// named `result_segment_sizes`.
template <typename ConcreteType>
class AttrSizedResultSegments
    : public TraitBase<ConcreteType, AttrSizedResultSegments> {
public:
  /// Name of the attribute holding the per-instance segment sizes.
  static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }

  static LogicalResult verifyTrait(Operation *op) {
    return ::mlir::OpTrait::impl::verifyResultSizeAttr(
        op, getResultSegmentSizeAttr());
  }
};
1292 | ||||
1293 | /// This trait provides a verifier for ops that are expecting their regions to | |||
1294 | /// not have any arguments | |||
1295 | template <typename ConcrentType> | |||
1296 | struct NoRegionArguments : public TraitBase<ConcrentType, NoRegionArguments> { | |||
1297 | static LogicalResult verifyTrait(Operation *op) { | |||
1298 | return ::mlir::OpTrait::impl::verifyNoRegionArguments(op); | |||
1299 | } | |||
1300 | }; | |||
1301 | ||||
1302 | // This trait is used to flag operations that consume or produce | |||
1303 | // values of `MemRef` type where those references can be 'normalized'. | |||
1304 | // TODO: Right now, the operands of an operation are either all normalizable, | |||
1305 | // or not. In the future, we may want to allow some of the operands to be | |||
1306 | // normalizable. | |||
1307 | template <typename ConcrentType> | |||
1308 | struct MemRefsNormalizable | |||
1309 | : public TraitBase<ConcrentType, MemRefsNormalizable> {}; | |||
1310 | ||||
/// This trait tags element-wise ops on vectors or tensors.
///
/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
/// trait. In particular, broadcasting behavior is not allowed.
///
/// An `Elementwise` op must satisfy the following properties:
///
/// 1. If any result is a vector/tensor then at least one operand must also be a
///    vector/tensor.
/// 2. If any operand is a vector/tensor then there must be at least one result
///    and all results must be vectors/tensors.
/// 3. All operand and result vector/tensor types must be of the same shape. The
///    shape may be dynamic in which case the op's behaviour is undefined for
///    non-matching shapes.
/// 4. The operation must be elementwise on its vector/tensor operands and
///    results. When applied to single-element vectors/tensors, the result must
///    be the same per element.
///
/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
/// interface `ElementwiseTypeInterface` that describes the container types for
/// which the operation is elementwise.
///
/// Rationale:
/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
///   of 0 non-scalar operands or 0 non-scalar results, which complicate a
///   generic definition of the iteration space.
/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
///   the same pattern, as otherwise lots of special handling for type
///   mismatches would be needed.
/// - 4. guarantees that no error handling is needed. Higher-level dialects
///   should reify any needed guards or error handling code before lowering to
///   an `Elementwise` op.
template <typename ConcreteType>
struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
  static LogicalResult verifyTrait(Operation *op) {
    // Delegates structural checks (properties 1-3 above) to the shared
    // implementation.
    return ::mlir::OpTrait::impl::verifyElementwise(op);
  }
};
1349 | ||||
/// This trait tags `Elementwise` operations that can be systematically
/// scalarized. All vector/tensor operands and results are then replaced by
/// scalars of the respective element type. Semantically, this is the operation
/// on a single element of the vector/tensor.
///
/// Rationale:
/// Allow to define the vector/tensor semantics of elementwise operations based
/// on the same op's behavior on scalars. This provides a constructive procedure
/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
///
/// Example:
/// ```
/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
///                      : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                      -> tensor<?xf32>
/// ```
/// can be scalarized to
///
/// ```
/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
///                      : (i1, f32, f32) -> f32
/// ```
template <typename ConcreteType>
struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
  static LogicalResult verifyTrait(Operation *op) {
    // Compile-time prerequisite: only Elementwise ops can be scalarized.
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Scalarizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
1381 | ||||
/// This trait tags `Elementwise` operations that can be systematically
/// vectorized. All scalar operands and results are then replaced by vectors
/// with the respective element type. Semantically, this is the operation on
/// multiple elements simultaneously. See also `Tensorizable`.
///
/// Rationale:
/// Provide the reverse to `Scalarizable` which, when chained together, allows
/// reasoning about the relationship between the tensor and vector case.
/// Additionally, it permits reasoning about promoting scalars to vectors via
/// broadcasting in cases like `%select_scalar_pred` below.
template <typename ConcreteType>
struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
  static LogicalResult verifyTrait(Operation *op) {
    // Marker trait: the only verification is the compile-time prerequisite.
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Vectorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
1401 | ||||
/// This trait tags `Elementwise` operations that can be systematically
/// tensorized. All scalar operands and results are then replaced by tensors
/// with the respective element type. Semantically, this is the operation on
/// multiple elements simultaneously. See also `Vectorizable`.
///
/// Rationale:
/// Provide the reverse to `Scalarizable` which, when chained together, allows
/// reasoning about the relationship between the tensor and vector case.
/// Additionally, it permits reasoning about promoting scalars to tensors via
/// broadcasting in cases like `%select_scalar_pred` below.
///
/// Examples:
/// ```
/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
/// ```
/// can be tensorized to
/// ```
/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
///               -> tensor<?xf32>
/// ```
///
/// ```
/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
///                    : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
/// ```
/// can be tensorized to
/// ```
/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
///                    : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
///                    -> tensor<?xf32>
/// ```
template <typename ConcreteType>
struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
  static LogicalResult verifyTrait(Operation *op) {
    // Marker trait: the only verification is the compile-time prerequisite.
    static_assert(
        ConcreteType::template hasTrait<Elementwise>(),
        "`Tensorizable` trait is only applicable to `Elementwise` ops.");
    return success();
  }
};
1442 | ||||
/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
/// provide an easy way for scalar operations to conveniently generalize their
/// behavior to vectors/tensors, and systematize conversion between these forms.
/// (Declaration only; the trait-query logic is implemented out of line.)
bool hasElementwiseMappableTraits(Operation *op);
1447 | ||||
1448 | } // namespace OpTrait | |||
1449 | ||||
1450 | //===----------------------------------------------------------------------===// | |||
1451 | // Internal Trait Utilities | |||
1452 | //===----------------------------------------------------------------------===// | |||
1453 | ||||
1454 | namespace op_definition_impl { | |||
1455 | //===----------------------------------------------------------------------===// | |||
1456 | // Trait Existence | |||
1457 | ||||
1458 | /// Returns true if this given Trait ID matches the IDs of any of the provided | |||
1459 | /// trait types `Traits`. | |||
1460 | template <template <typename T> class... Traits> | |||
1461 | static bool hasTrait(TypeID traitID) { | |||
1462 | TypeID traitIDs[] = {TypeID::get<Traits>()...}; | |||
1463 | for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i) | |||
1464 | if (traitIDs[i] == traitID) | |||
1465 | return true; | |||
1466 | return false; | |||
1467 | } | |||
1468 | ||||
1469 | //===----------------------------------------------------------------------===// | |||
1470 | // Trait Folding | |||
1471 | ||||
/// Trait to check if T provides a 'foldTrait' method for single result
/// operations, i.e. `foldTrait(Operation *, ArrayRef<Attribute>)`.
template <typename T, typename... Args>
using has_single_result_fold_trait = decltype(T::foldTrait(
    std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
// Member-detection idiom: well-formed iff the expression above compiles for T.
template <typename T>
using detect_has_single_result_fold_trait =
    llvm::is_detected<has_single_result_fold_trait, T>;
/// Trait to check if T provides a general 'foldTrait' method, i.e.
/// `foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult>&)`.
template <typename T, typename... Args>
using has_fold_trait =
    decltype(T::foldTrait(std::declval<Operation *>(),
                          std::declval<ArrayRef<Attribute>>(),
                          std::declval<SmallVectorImpl<OpFoldResult> &>()));
template <typename T>
using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
/// Trait to check if T provides any `foldTrait` method (either form above).
/// NOTE: This should use std::disjunction when C++17 is available.
template <typename T>
using detect_has_any_fold_trait =
    std::conditional_t<bool(detect_has_fold_trait<T>::value),
                       detect_has_fold_trait<T>,
                       detect_has_single_result_fold_trait<T>>;
1495 | ||||
1496 | /// Returns the result of folding a trait that implements a `foldTrait` function | |||
1497 | /// that is specialized for operations that have a single result. | |||
1498 | template <typename Trait> | |||
1499 | static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value, | |||
1500 | LogicalResult> | |||
1501 | foldTrait(Operation *op, ArrayRef<Attribute> operands, | |||
1502 | SmallVectorImpl<OpFoldResult> &results) { | |||
1503 | assert(op->hasTrait<OpTrait::OneResult>() &&(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1505, __extension__ __PRETTY_FUNCTION__ )) | |||
1504 | "expected trait on non single-result operation to implement the "(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1505, __extension__ __PRETTY_FUNCTION__ )) | |||
1505 | "general `foldTrait` method")(static_cast <bool> (op->hasTrait<OpTrait::OneResult >() && "expected trait on non single-result operation to implement the " "general `foldTrait` method") ? void (0) : __assert_fail ("op->hasTrait<OpTrait::OneResult>() && \"expected trait on non single-result operation to implement the \" \"general `foldTrait` method\"" , "mlir/include/mlir/IR/OpDefinition.h", 1505, __extension__ __PRETTY_FUNCTION__ )); | |||
1506 | // If a previous trait has already been folded and replaced this operation, we | |||
1507 | // fail to fold this trait. | |||
1508 | if (!results.empty()) | |||
1509 | return failure(); | |||
1510 | ||||
1511 | if (OpFoldResult result = Trait::foldTrait(op, operands)) { | |||
1512 | if (result.template dyn_cast<Value>() != op->getResult(0)) | |||
1513 | results.push_back(result); | |||
1514 | return success(); | |||
1515 | } | |||
1516 | return failure(); | |||
1517 | } | |||
1518 | /// Returns the result of folding a trait that implements a generalized | |||
1519 | /// `foldTrait` function that is supports any operation type. | |||
1520 | template <typename Trait> | |||
1521 | static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult> | |||
1522 | foldTrait(Operation *op, ArrayRef<Attribute> operands, | |||
1523 | SmallVectorImpl<OpFoldResult> &results) { | |||
1524 | // If a previous trait has already been folded and replaced this operation, we | |||
1525 | // fail to fold this trait. | |||
1526 | return results.empty() ? Trait::foldTrait(op, operands, results) : failure(); | |||
1527 | } | |||
/// Fallback overload for traits that do not provide any `foldTrait` method:
/// there is nothing to fold, so always fail.
template <typename Trait>
static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
                               LogicalResult>
foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
  return failure();
}
1534 | ||||
/// Given a tuple type containing a set of traits, return the result of folding
/// the given operation. Succeeds if any trait in `Ts` folded.
template <typename... Ts>
static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool anyFolded = false;
  // Pre-C++17 pack-expansion trick: the initializer_list forces left-to-right
  // evaluation of `foldTrait` for every trait, accumulating whether any
  // succeeded. (With C++17 this would be a fold expression.)
  (void)std::initializer_list<int>{
      (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
  return success(anyFolded);
}
1545 | ||||
1546 | //===----------------------------------------------------------------------===// | |||
1547 | // Trait Verification | |||
1548 | ||||
/// Trait to check if T provides a `verifyTrait` method.
template <typename T, typename... Args>
using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;

/// Trait to check if T provides a `verifyRegionTrait` method.
template <typename T, typename... Args>
using has_verify_region_trait =
    decltype(T::verifyRegionTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_region_trait =
    llvm::is_detected<has_verify_region_trait, T>;
1562 | ||||
/// Verify the given trait if it provides a verifier.
template <typename T>
std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *op) {
  return T::verifyTrait(op);
}
/// Fallback for traits without a verifier: vacuously succeeds.
template <typename T>
inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
verifyTrait(Operation *) {
  return success();
}
1574 | ||||
/// Given a set of traits, return the result of verifying the given operation.
/// Traits are verified in order; the first failure short-circuits the rest.
template <typename... Ts>
LogicalResult verifyTraits(Operation *op) {
  LogicalResult result = success();
  // Pre-C++17 pack expansion; the ternary preserves the first failure.
  (void)std::initializer_list<int>{
      (result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
  return result;
}
1583 | ||||
/// Verify the given trait if it provides a region verifier.
template <typename T>
std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
verifyRegionTrait(Operation *op) {
  return T::verifyRegionTrait(op);
}
/// Fallback for traits without a region verifier: vacuously succeeds.
template <typename T>
inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
                        LogicalResult>
verifyRegionTrait(Operation *) {
  return success();
}
1596 | ||||
/// Given a set of traits, return the result of verifying the regions of the
/// given operation. Traits are verified in order; the first failure
/// short-circuits the rest.
template <typename... Ts>
LogicalResult verifyRegionTraits(Operation *op) {
  // Silence the unused-parameter warning when `Ts` is empty.
  (void)op;
  LogicalResult result = success();
  (void)std::initializer_list<int>{
      (result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
       0)...};
  return result;
}
1608 | } // namespace op_definition_impl | |||
1609 | ||||
1610 | //===----------------------------------------------------------------------===// | |||
1611 | // Operation Definition classes | |||
1612 | //===----------------------------------------------------------------------===// | |||
1613 | ||||
1614 | /// This provides public APIs that all operations should have. The template | |||
1615 | /// argument 'ConcreteType' should be the concrete type by CRTP and the others | |||
1616 | /// are base classes by the policy pattern. | |||
template <typename ConcreteType, template <typename T> class... Traits>
class Op : public OpState, public Traits<ConcreteType>... {
public:
  /// Inherit getOperation from `OpState`.
  using OpState::getOperation;
  using OpState::verify;
  using OpState::verifyRegions;

  /// Return if this operation contains the provided trait.
  template <template <typename T> class Trait>
  static constexpr bool hasTrait() {
    return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
  }

  /// Create a deep copy of this operation.
  ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }

  /// Create a partial copy of this operation without traversing into attached
  /// regions. The new operation will have the same number of regions as the
  /// original one, but they will be left empty.
  ConcreteType cloneWithoutRegions() {
    return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
  }

  /// Return true if this "op class" can match against the specified operation.
  static bool classof(Operation *op) {
    // Registered operations are matched by TypeID.
    if (auto info = op->getRegisteredInfo())
      return TypeID::get<ConcreteType>() == info->getTypeID();
#ifndef NDEBUG
    // In debug builds, catch the common mistake of casting to an op whose
    // dialect was never registered: the name matches but the info is absent.
    if (op->getName().getStringRef() == ConcreteType::getOperationName())
      llvm::report_fatal_error(
          "classof on '" + ConcreteType::getOperationName() +
          "' failed due to the operation not being registered");
#endif
    return false;
  }
  /// Provide `classof` support for other OpBase derived classes, such as
  /// Interfaces.
  template <typename T>
  static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
  classof(const T *op) {
    return classof(const_cast<T *>(op)->getOperation());
  }

  /// Expose the type we are instantiated on to template machinery that may want
  /// to introspect traits on this operation.
  using ConcreteOpType = ConcreteType;

  /// This is a public constructor. Any op can be initialized to null.
  explicit Op() : OpState(nullptr) {}
  Op(std::nullptr_t) : OpState(nullptr) {}

  /// This is a public constructor to enable access via the llvm::cast family of
  /// methods. This should not be used directly.
  explicit Op(Operation *state) : OpState(state) {}

  /// Methods for supporting PointerLikeTypeTraits.
  const void *getAsOpaquePointer() const {
    return static_cast<const void *>((Operation *)*this);
  }
  static ConcreteOpType getFromOpaquePointer(const void *pointer) {
    return ConcreteOpType(
        reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
  }

  /// Attach the given models as implementations of the corresponding interfaces
  /// for the concrete operation.
  template <typename... Models>
  static void attachInterface(MLIRContext &context) {
    Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
        ConcreteType::getOperationName(), &context);
    if (!info)
      llvm::report_fatal_error(
          "Attempting to attach an interface to an unregistered operation " +
          ConcreteType::getOperationName() + ".");
    info->attachInterface<Models...>();
  }

private:
  /// Trait to check if T provides a 'fold' method for a single result op.
  template <typename T, typename... Args>
  using has_single_result_fold =
      decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
  template <typename T>
  using detect_has_single_result_fold =
      llvm::is_detected<has_single_result_fold, T>;
  /// Trait to check if T provides a general 'fold' method.
  template <typename T, typename... Args>
  using has_fold = decltype(std::declval<T>().fold(
      std::declval<ArrayRef<Attribute>>(),
      std::declval<SmallVectorImpl<OpFoldResult> &>()));
  template <typename T>
  using detect_has_fold = llvm::is_detected<has_fold, T>;
  /// Trait to check if T provides a 'print' method.
  template <typename T, typename... Args>
  using has_print =
      decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
  template <typename T>
  using detect_has_print = llvm::is_detected<has_print, T>;

  /// Returns an interface map containing the interfaces registered to this
  /// operation.
  static detail::InterfaceMap getInterfaceMap() {
    return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
  }

  /// Return the internal implementations of each of the OperationName
  /// hooks.
  /// Implementation of `FoldHookFn` OperationName hook.
  static OperationName::FoldHookFn getFoldHookFn() {
    // Dispatch (via SFINAE on the overloads below) based on which `fold`
    // signature, if any, the concrete op defines.
    return getFoldHookFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                          Traits<ConcreteOpT>...>::value &&
                              detect_has_single_result_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldSingleResultHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation is not single result and defines a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
                                           Traits<ConcreteOpT>...>::value &&
                              detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      return foldHook<ConcreteOpT>(op, operands, results);
    };
  }
  /// The internal implementation of `getFoldHookFn` above that is invoked if
  /// the operation does not define a `fold` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
                              !detect_has_fold<ConcreteOpT>::value,
                          OperationName::FoldHookFn>
  getFoldHookFnImpl() {
    return [](Operation *op, ArrayRef<Attribute> operands,
              SmallVectorImpl<OpFoldResult> &results) {
      // In this case, we only need to fold the traits of the operation.
      return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
          op, operands, results);
    };
  }
  /// Return the result of folding a single result operation that defines a
  /// `fold` method.
  template <typename ConcreteOpT>
  static LogicalResult
  foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
                       SmallVectorImpl<OpFoldResult> &results) {
    OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
      // No trait folded either: report success only for an in-place fold
      // (non-null `result`), failure otherwise.
      return success(static_cast<bool>(result));
    }
    results.push_back(result);
    return success();
  }
  /// Return the result of folding an operation that defines a `fold` method.
  template <typename ConcreteOpT>
  static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
    LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);

    // If the fold failed or was in-place, try to fold the traits of the
    // operation.
    if (failed(result) || results.empty()) {
      if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
              op, operands, results)))
        return success();
    }
    return result;
  }

  /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
  static OperationName::GetCanonicalizationPatternsFn
  getGetCanonicalizationPatternsFn() {
    return &ConcreteType::getCanonicalizationPatterns;
  }
  /// Implementation of `GetHasTraitFn`
  static OperationName::HasTraitFn getHasTraitFn() {
    return
        [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
  }
  /// Implementation of `ParseAssemblyFn` OperationName hook.
  static OperationName::ParseAssemblyFn getParseAssemblyFn() {
    return &ConcreteType::parse;
  }
  /// Implementation of `PrintAssemblyFn` OperationName hook.
  static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
    return getPrintAssemblyFnImpl<ConcreteType>();
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation does not define a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
      return OpState::print(op, printer, defaultDialect);
    };
  }
  /// The internal implementation of `getPrintAssemblyFn` that is invoked when
  /// the concrete operation defines a `print` method.
  template <typename ConcreteOpT>
  static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
                          OperationName::PrintAssemblyFn>
  getPrintAssemblyFnImpl() {
    return &printAssembly;
  }
  /// Print the op name followed by the concrete op's custom assembly form.
  static void printAssembly(Operation *op, OpAsmPrinter &p,
                            StringRef defaultDialect) {
    OpState::printOpName(op, p, defaultDialect);
    return cast<ConcreteType>(op).print(p);
  }
  /// Implementation of `VerifyInvariantsFn` OperationName hook.
  static LogicalResult verifyInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    // Trait verifiers run before the op's own `verify`; either failing fails
    // the whole verification.
    return failure(
        failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
        failed(cast<ConcreteType>(op).verify()));
  }
  static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
  }
  /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
  static LogicalResult verifyRegionInvariants(Operation *op) {
    static_assert(hasNoDataMembers(),
                  "Op class shouldn't define new data members");
    return failure(
        failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
            op)) ||
        failed(cast<ConcreteType>(op).verifyRegions()));
  }
  static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
    return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
  }

  static constexpr bool hasNoDataMembers() {
    // Checking that the derived class does not define any member by comparing
    // its size to an ad-hoc EmptyOp.
    class EmptyOp : public Op<EmptyOp, Traits...> {};
    return sizeof(ConcreteType) == sizeof(EmptyOp);
  }

  /// Allow access to internal implementation methods.
  friend RegisteredOperationName;
};
1879 | ||||
1880 | /// This class represents the base of an operation interface. See the definition | |||
1881 | /// of `detail::Interface` for requirements on the `Traits` type. | |||
template <typename ConcreteType, typename Traits>
class OpInterface
    : public detail::Interface<ConcreteType, Operation *, Traits,
                               Op<ConcreteType>, OpTrait::TraitBase> {
public:
  using Base = OpInterface<ConcreteType, Traits>;
  using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
                                          Op<ConcreteType>, OpTrait::TraitBase>;

  /// Inherit the base class constructor.
  using InterfaceBase::InterfaceBase;

protected:
  /// Returns the impl interface instance for the given operation, or nullptr
  /// if the operation does not implement this interface.
  static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
    OperationName name = op->getName();

    // Access the raw interface from the operation info.
    if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
      if (auto *opIface = rInfo->getInterface<ConcreteType>())
        return opIface;
      // Fallback to the dialect to provide it with a chance to implement this
      // interface for this operation.
      return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
          op->getName());
    }
    // Fallback to the dialect to provide it with a chance to implement this
    // interface for this operation.
    if (Dialect *dialect = name.getDialect())
      return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
    return nullptr;
  }

  /// Allow access to `getInterfaceFor`.
  friend InterfaceBase;
};
1918 | ||||
1919 | //===----------------------------------------------------------------------===// | |||
1920 | // CastOpInterface utilities | |||
1921 | //===----------------------------------------------------------------------===// | |||
1922 | ||||
// These functions are out-of-line implementations of the methods in
// CastOpInterface, which avoids them being template instantiated/duplicated.
namespace impl {
/// Attempt to fold the given cast operation.
/// NOTE(review): behavior is defined out of line; presumably folded values are
/// appended to `foldResults` on success — confirm against the implementation.
LogicalResult foldCastInterfaceOp(Operation *op,
                                  ArrayRef<Attribute> attrOperands,
                                  SmallVectorImpl<OpFoldResult> &foldResults);
/// Attempt to verify the given cast operation using the provided type
/// compatibility predicate.
LogicalResult verifyCastInterfaceOp(
    Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
} // namespace impl
1934 | } // namespace mlir | |||
1935 | ||||
1936 | namespace llvm { | |||
1937 | ||||
1938 | template <typename T> | |||
1939 | struct DenseMapInfo< | |||
1940 | T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> { | |||
1941 | static inline T getEmptyKey() { | |||
1942 | auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey(); | |||
1943 | return T::getFromOpaquePointer(pointer); | |||
1944 | } | |||
1945 | static inline T getTombstoneKey() { | |||
1946 | auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey(); | |||
1947 | return T::getFromOpaquePointer(pointer); | |||
1948 | } | |||
1949 | static unsigned getHashValue(T val) { | |||
1950 | return hash_value(val.getAsOpaquePointer()); | |||
1951 | } | |||
1952 | static bool isEqual(T lhs, T rhs) { return lhs == rhs; } | |||
1953 | }; | |||
1954 | ||||
1955 | } // namespace llvm | |||
1956 | ||||
1957 | #endif |