Bug Summary

File: build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/include/mlir/IR/OpDefinition.h
Warning: line 98, column 5
Called C++ object pointer is null

Annotated Source Code


clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SuperVectorize.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-15/lib/clang/15.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Affine/Transforms -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/Dialect/Affine/Transforms -I include -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/llvm/include -I /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-15/lib/clang/15.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 
-fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-04-20-140412-16051-1 -x c++ /build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

1//===- SuperVectorize.cpp - Vectorize Pass Impl ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements vectorization of loops, operations and data types to
10// a target-independent, n-D super-vector abstraction.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/Affine/Utils.h"
20#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
21#include "mlir/Dialect/Vector/IR/VectorOps.h"
22#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
23#include "mlir/IR/BlockAndValueMapping.h"
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/Support/Debug.h"
27
28using namespace mlir;
29using namespace vector;
30
31///
32/// Implements a high-level vectorization strategy on a Function.
33/// The abstraction used is that of super-vectors, which encode in the vector
34/// types a single, compact representation of information that is expected
35/// to reduce the impact of the phase ordering problem.
36///
37/// Vector granularity:
38/// ===================
39/// This pass is designed to perform vectorization at a super-vector
40/// granularity. A super-vector is loosely defined as a vector type that is a
41/// multiple of a "good" vector size so the HW can efficiently implement a set
42/// of high-level primitives. Multiple is understood along any dimension; e.g.
43/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a
44/// vector<8xf32> HW vector. Note that a "good vector size so the HW can
45/// efficiently implement a set of high-level primitives" is not necessarily an
46/// integer multiple of actual hardware registers. We leave details of this
47/// distinction unspecified for now.
48///
49/// Some may prefer the terminology a "tile of HW vectors". In this case, one
50/// should note that super-vectors implement an "always full tile" abstraction.
51/// They guarantee no partial-tile separation is necessary by relying on a
52/// high-level copy-reshape abstraction that we call vector.transfer. This
53/// copy-reshape operation is also responsible for performing layout
54/// transposition if necessary. In the general case this will require a scoped
55/// allocation in some notional local memory.
56///
57/// Whatever the mental model one prefers to use for this abstraction, the key
58/// point is that we burn into a single, compact, representation in the vector
59/// types, information that is expected to reduce the impact of the phase
60/// ordering problem. Indeed, a vector type conveys information that:
61/// 1. the associated loops have dependency semantics that do not prevent
62/// vectorization;
63/// 2. the associated loops have been sliced in chunks of static sizes that are
64/// compatible with vector sizes (i.e. similar to unroll-and-jam);
65/// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by
66/// the
67/// vector type and no vectorization hampering transformations can be
68/// applied to them anymore;
69/// 4. the underlying memrefs are accessed in some notional contiguous way
70/// that allows loading into vectors with some amount of spatial locality;
71/// In other words, super-vectorization provides a level of separation of
72/// concern by way of opacity to subsequent passes. This has the effect of
73/// encapsulating and propagating vectorization constraints down the list of
74/// passes until we are ready to lower further.
75///
76/// For a particular target, a notion of minimal n-d vector size will be
77/// specified and vectorization targets a multiple of those. In the following
78/// paragraph, let "k ." represent "a multiple of", to be understood as a
79/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
80/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
81///
82/// Some non-exhaustive notable super-vector sizes of interest include:
83/// - CPU: vector<k . HW_vector_size>,
84/// vector<k' . core_count x k . HW_vector_size>,
85/// vector<socket_count x k' . core_count x k . HW_vector_size>;
86/// - GPU: vector<k . warp_size>,
87/// vector<k . warp_size x float2>,
88/// vector<k . warp_size x float4>,
89/// vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
90///
91/// Loops and operations are emitted that operate on those super-vector shapes.
92/// Subsequent lowering passes will materialize to actual HW vector sizes. These
93/// passes are expected to be (gradually) more target-specific.
94///
95/// At a high level, a vectorized load in a loop will resemble:
96/// ```mlir
97/// affine.for %i = ? to ? step ? {
98/// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
99/// }
100/// ```
101/// It is the responsibility of the implementation of vector.transfer_read to
102/// materialize vector registers from the original scalar memrefs. A later (more
103/// target-dependent) lowering pass will materialize to actual HW vector sizes.
104/// This lowering may occur at different times:
105/// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
106/// DmaWaitOp + vectorized operations for data transformations and shuffle;
107/// thus opening opportunities for unrolling and pipelining. This is an
108/// instance of library call "whiteboxing"; or
109/// 2. later in a target-specific lowering pass or hand-written library
110/// call; achieving full separation of concerns. This is an instance of
111/// library call; or
112/// 3. a mix of both, e.g. based on a model.
113/// In the future, these operations will expose a contract to constrain the
114/// search on vectorization patterns and sizes.
115///
116/// Occurrence of super-vectorization in the compiler flow:
117/// =======================================================
118/// This is an active area of investigation. We start with 2 remarks to position
119/// super-vectorization in the context of existing ongoing work: LLVM VPLAN
120/// and LLVM SLP Vectorizer.
121///
122/// LLVM VPLAN:
123/// -----------
124/// The astute reader may have noticed that in the limit, super-vectorization
125/// can be applied at a similar time and with similar objectives as VPLAN.
126/// For instance, it could be applied in a traditional polyhedral compilation
127/// flow (e.g. the PPCG project uses ISL to provide dependence analysis,
128/// multi-level scheduling + tiling, lifting of footprint to fast memory,
129/// communication synthesis, mapping, register optimizations), just before
130/// unrolling. When vectorization is applied at this *late* level in a typical
131/// polyhedral flow, and is instantiated with actual hardware vector sizes,
132/// super-vectorization is expected to match (or subsume) the type of patterns
133/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
134/// is higher level and our implementation should be significantly simpler. Also
135/// note that in this mode, recursive patterns are probably a bit of an overkill
136/// although it is reasonable to expect that mixing a bit of outer loop and
137/// inner loop vectorization + unrolling will provide interesting choices to
138/// MLIR.
139///
140/// LLVM SLP Vectorizer:
141/// --------------------
142/// Super-vectorization, however, is not meant to be used in a similar fashion
143/// to the SLP vectorizer. The main difference lies in the information that
144/// both vectorizers use: super-vectorization examines contiguity of memory
145/// references along fastest varying dimensions and loops with recursive nested
146/// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on
147/// the other hand, performs flat pattern matching inside a single unrolled loop
148/// body and stitches together pieces of load and store operations into full
149/// 1-D vectors. We envision that the SLP vectorizer is a good way to capture
150/// innermost loop, control-flow dependent patterns that super-vectorization may
151/// not be able to capture easily. In other words, super-vectorization does not
152/// aim at replacing the SLP vectorizer and the two solutions are complementary.
153///
154/// Ongoing investigations:
155/// -----------------------
156/// We discuss the following *early* places where super-vectorization is
157/// applicable and touch on the expected benefits and risks. We list the
158/// opportunities in the context of the traditional polyhedral compiler flow
159/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
160/// we expect to experiment with super-vectorization:
161/// 1. Right after language lowering to MLIR: this is the earliest time where
162/// super-vectorization is expected to be applied. At this level, all the
163/// language/user/library-level annotations are available and can be fully
164/// exploited. Examples include loop-type annotations (such as parallel,
165/// reduction, scan, dependence distance vector, vectorizable) as well as
166/// memory access annotations (such as guaranteed non-aliasing writes, or
167/// indirect accesses that are permutations by construction), or the fact
168/// that a particular operation is prescribed atomic by the user. At this
169/// level, anything that enriches what dependence analysis can do should be
170/// aggressively exploited. At this level we are close to having explicit
171/// vector types in the language, except we do not impose that burden on the
172/// programmer/library: we derive information from scalar code + annotations.
173/// 2. After dependence analysis and before polyhedral scheduling: the
174/// information that supports vectorization does not need to be supplied by a
175/// higher level of abstraction. Traditional dependence analysis is available
176/// in MLIR and will be used to drive vectorization and cost models.
177///
178/// Let's pause here and remark that applying super-vectorization as described
179/// in 1. and 2. presents clear opportunities and risks:
180/// - the opportunity is that vectorization is burned into the type system and
181/// is protected from the adverse effects of loop scheduling, tiling, loop
182/// interchange and all passes downstream. Provided that subsequent passes are
183/// able to operate on vector types, the vector shapes, associated loop
184/// iterator properties, alignment, and contiguity of fastest varying
185/// dimensions are preserved until we lower the super-vector types. We expect
186/// this to significantly rein in the adverse effects of phase ordering.
187/// - the risks are that a. all passes after super-vectorization have to work
188/// on elemental vector types (though this is true wherever
189/// vectorization is applied) and b. that imposing vectorization constraints
190/// too early may be overall detrimental to loop fusion, tiling and other
191/// transformations because the dependence distances are coarsened when
192/// operating on elemental vector types. For this reason, the pattern
193/// profitability analysis should include a component that also captures the
194/// maximal amount of fusion available under a particular pattern. This is
195/// still at the stage of rough ideas but in this context, search is our
196/// friend as the Tensor Comprehensions and auto-TVM contributions
197/// demonstrated previously.
198/// The bottom line is that we do not yet have good answers to the above, but
199/// we aim at making it easy to answer such questions.
200///
201/// Back to our listing, the last places where early super-vectorization makes
202/// sense are:
203/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known
204/// to improve locality, parallelism and be configurable (e.g. max-fuse,
205/// smart-fuse etc). They can also have adverse effects on contiguity
206/// properties that are required for vectorization but the vector.transfer
207/// copy-reshape-pad-transpose abstraction is expected to help recapture
208/// these properties.
209/// 4. right after polyhedral-style scheduling+tiling;
210/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
211/// probably the most promising places because applying tiling achieves a
212/// separation of concerns that allows rescheduling to worry less about
213/// locality and more about parallelism and distribution (e.g. min-fuse).
214///
215/// At these levels the risk-reward looks different: on one hand we probably
216/// lost a good deal of language/user/library-level annotation; on the other
217/// hand we gained parallelism and locality through scheduling and tiling.
218/// However we probably want to ensure tiling is compatible with the
219/// full-tile-only abstraction used in super-vectorization or suffer the
220/// consequences. It is too early to place bets on what will win but we expect
221/// super-vectorization to be the right abstraction to allow exploring at all
222/// these levels. And again, search is our friend.
223///
224/// Lastly, we mention it again here:
225/// 6. as a MLIR-based alternative to VPLAN.
226///
227/// Lowering, unrolling, pipelining:
228/// ================================
229/// TODO: point to the proper places.
230///
231/// Algorithm:
232/// ==========
233/// The algorithm proceeds in a few steps:
234/// 1. defining super-vectorization patterns and matching them on the tree of
235/// AffineForOp. A super-vectorization pattern is defined as a recursive
236/// data structure that matches and captures nested, imperfectly-nested
237/// loops that have a. conformable loop annotations attached (e.g. parallel,
238/// reduction, vectorizable, ...) as well as b. all contiguous load/store
239/// operations along a specified minor dimension (not necessarily the
240/// fastest varying);
241/// 2. analyzing those patterns for profitability (TODO: and
242/// interference);
243/// 3. then, for each pattern in order:
244/// a. applying iterative rewriting of the loops and all their nested
245/// operations in topological order. Rewriting is implemented by
246/// coarsening the loops and converting operations and operands to their
247/// vector forms. Processing operations in topological order is relatively
248/// simple due to the structured nature of the control-flow
249/// representation. This order ensures that all the operands of a given
250/// operation have been vectorized before the operation itself in a single
251/// traversal, except for operands defined outside of the loop nest. The
252/// algorithm can convert the following operations to their vector form:
253/// * Affine load and store operations are converted to opaque vector
254/// transfer read and write operations.
255/// * Scalar constant operations/operands are converted to vector
256/// constant operations (splat).
257/// * Uniform operands (only induction variables of loops not mapped to
258/// a vector dimension, or operands defined outside of the loop nest
259/// for now) are broadcasted to a vector.
260/// TODO: Support more uniform cases.
261/// * Affine for operations with 'iter_args' are vectorized by
262/// vectorizing their 'iter_args' operands and results.
263/// TODO: Support more complex loops with divergent lbs and/or ubs.
264/// * The remaining operations in the loop nest are vectorized by
265/// widening their scalar types to vector types.
266/// b. if everything under the root AffineForOp in the current pattern
267/// is vectorized properly, we commit that loop to the IR and remove the
268/// scalar loop. Otherwise, we discard the vectorized loop and keep the
269/// original scalar loop.
270/// c. vectorization is applied on the next pattern in the list. Because
271/// pattern interference avoidance is not yet implemented and we do
272/// not support further vectorizing an already vectorized load, we need to
273/// re-verify that the pattern is still vectorizable. This is expected to
274/// make cost models more difficult to write and is subject to improvement
275/// in the future.
276///
277/// Choice of loop transformation to support the algorithm:
278/// =======================================================
279/// The choice of loop transformation to apply for coarsening vectorized loops
280/// is still subject to exploratory tradeoffs. In particular, say we want to
281/// vectorize by a factor of 128; we want to transform the following input:
282/// ```mlir
283/// affine.for %i = %M to %N {
284/// %a = affine.load %A[%i] : memref<?xf32>
285/// }
286/// ```
287///
288/// Traditionally, one would vectorize late (after scheduling, tiling,
289/// memory promotion, etc.), say after stripmining (and potentially unrolling in
290/// the case of LLVM's SLP vectorizer):
291/// ```mlir
292/// affine.for %i = floor(%M, 128) to ceil(%N, 128) {
293/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
294/// %a = affine.load %A[%ii] : memref<?xf32>
295/// }
296/// }
297/// ```
298///
299/// Instead, we seek to vectorize early and freeze vector types before
300/// scheduling, so we want to generate a pattern that resembles:
301/// ```mlir
302/// affine.for %i = ? to ? step ? {
303/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
304/// }
305/// ```
306///
307/// i. simply dividing the lower / upper bounds by 128 creates issues
308/// when representing expressions such as ii + 1 because now we only
309/// have access to original values that have been divided. Additional
310/// information is needed to specify accesses at below-128 granularity;
311/// ii. another alternative is to coarsen the loop step but this may have
312/// consequences on dependence analysis and fusability of loops: fusable
313/// loops probably need to have the same step (because we don't want to
314/// stripmine/unroll to enable fusion).
315/// As a consequence, we choose to represent the coarsening using the loop
316/// step for now and reevaluate in the future. Note that we can renormalize
317/// loop steps later if/when we have evidence that they are problematic.
318///
319/// For the simple strawman example above, vectorizing for a 1-D vector
320/// abstraction of size 128 returns code similar to:
321/// ```mlir
322/// affine.for %i = %M to %N step 128 {
323/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
324/// }
325/// ```
326///
327/// Unsupported cases, extensions, and work in progress (help welcome :-) ):
328/// ========================================================================
329/// 1. lowering to concrete vector types for various HW;
330/// 2. reduction support for n-D vectorization and non-unit steps;
331/// 3. non-effecting padding during vector.transfer_read and filter during
332/// vector.transfer_write;
333/// 4. misalignment support for vector.transfer_read / vector.transfer_write
334/// (hopefully without read-modify-writes);
335/// 5. control-flow support;
336/// 6. cost-models, heuristics and search;
337/// 7. Op implementation, extensions, and implications for memref views;
338/// 8. many TODOs left around.
339///
340/// Examples:
341/// =========
342/// Consider the following Function:
343/// ```mlir
344/// func @vector_add_2d(%M : index, %N : index) -> f32 {
345/// %A = alloc (%M, %N) : memref<?x?xf32, 0>
346/// %B = alloc (%M, %N) : memref<?x?xf32, 0>
347/// %C = alloc (%M, %N) : memref<?x?xf32, 0>
348/// %f1 = arith.constant 1.0 : f32
349/// %f2 = arith.constant 2.0 : f32
350/// affine.for %i0 = 0 to %M {
351/// affine.for %i1 = 0 to %N {
352/// // non-scoped %f1
353/// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
354/// }
355/// }
356/// affine.for %i2 = 0 to %M {
357/// affine.for %i3 = 0 to %N {
358/// // non-scoped %f2
359/// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
360/// }
361/// }
362/// affine.for %i4 = 0 to %M {
363/// affine.for %i5 = 0 to %N {
364/// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0>
365/// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0>
366/// %s5 = arith.addf %a5, %b5 : f32
367/// // non-scoped %f1
368/// %s6 = arith.addf %s5, %f1 : f32
369/// // non-scoped %f2
370/// %s7 = arith.addf %s5, %f2 : f32
371/// // diamond dependency.
372/// %s8 = arith.addf %s7, %s6 : f32
373/// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
374/// }
375/// }
376/// %c7 = arith.constant 7 : index
377/// %c42 = arith.constant 42 : index
378/// %res = load %C[%c7, %c42] : memref<?x?xf32, 0>
379/// return %res : f32
380/// }
381/// ```
382///
383/// The -affine-super-vectorize pass with the following arguments:
384/// ```
385/// -affine-super-vectorize="virtual-vector-size=256 test-fastest-varying=0"
386/// ```
387///
388/// produces this standard innermost-loop vectorized code:
389/// ```mlir
390/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
391/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
392/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
393/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
394/// %cst = arith.constant 1.0 : f32
395/// %cst_0 = arith.constant 2.0 : f32
396/// affine.for %i0 = 0 to %arg0 {
397/// affine.for %i1 = 0 to %arg1 step 256 {
398/// %cst_1 = arith.constant dense<vector<256xf32>, 1.0> :
399/// vector<256xf32>
400/// vector.transfer_write %cst_1, %0[%i0, %i1] :
401/// vector<256xf32>, memref<?x?xf32>
402/// }
403/// }
404/// affine.for %i2 = 0 to %arg0 {
405/// affine.for %i3 = 0 to %arg1 step 256 {
406/// %cst_2 = arith.constant dense<vector<256xf32>, 2.0> :
407/// vector<256xf32>
408/// vector.transfer_write %cst_2, %1[%i2, %i3] :
409/// vector<256xf32>, memref<?x?xf32>
410/// }
411/// }
412/// affine.for %i4 = 0 to %arg0 {
413/// affine.for %i5 = 0 to %arg1 step 256 {
414/// %3 = vector.transfer_read %0[%i4, %i5] :
415/// memref<?x?xf32>, vector<256xf32>
416/// %4 = vector.transfer_read %1[%i4, %i5] :
417/// memref<?x?xf32>, vector<256xf32>
418/// %5 = arith.addf %3, %4 : vector<256xf32>
419/// %cst_3 = arith.constant dense<vector<256xf32>, 1.0> :
420/// vector<256xf32>
421/// %6 = arith.addf %5, %cst_3 : vector<256xf32>
422/// %cst_4 = arith.constant dense<vector<256xf32>, 2.0> :
423/// vector<256xf32>
424/// %7 = arith.addf %5, %cst_4 : vector<256xf32>
425/// %8 = arith.addf %7, %6 : vector<256xf32>
426/// vector.transfer_write %8, %2[%i4, %i5] :
427/// vector<256xf32>, memref<?x?xf32>
428/// }
429/// }
430/// %c7 = arith.constant 7 : index
431/// %c42 = arith.constant 42 : index
432/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
433/// return %9 : f32
434/// }
435/// ```
436///
437/// The -affine-super-vectorize pass with the following arguments:
438/// ```
439/// -affine-super-vectorize="virtual-vector-size=32,256 \
440/// test-fastest-varying=1,0"
441/// ```
442///
443/// produces this more interesting mixed outer-innermost-loop vectorized code:
444/// ```mlir
445/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
446/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
447/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
448/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
449/// %cst = arith.constant 1.0 : f32
450/// %cst_0 = arith.constant 2.0 : f32
451/// affine.for %i0 = 0 to %arg0 step 32 {
452/// affine.for %i1 = 0 to %arg1 step 256 {
453/// %cst_1 = arith.constant dense<vector<32x256xf32>, 1.0> :
454/// vector<32x256xf32>
455/// vector.transfer_write %cst_1, %0[%i0, %i1] :
456/// vector<32x256xf32>, memref<?x?xf32>
457/// }
458/// }
459/// affine.for %i2 = 0 to %arg0 step 32 {
460/// affine.for %i3 = 0 to %arg1 step 256 {
461/// %cst_2 = arith.constant dense<vector<32x256xf32>, 2.0> :
462/// vector<32x256xf32>
463/// vector.transfer_write %cst_2, %1[%i2, %i3] :
464/// vector<32x256xf32>, memref<?x?xf32>
465/// }
466/// }
467/// affine.for %i4 = 0 to %arg0 step 32 {
468/// affine.for %i5 = 0 to %arg1 step 256 {
469/// %3 = vector.transfer_read %0[%i4, %i5] :
470/// memref<?x?xf32>, vector<32x256xf32>
471/// %4 = vector.transfer_read %1[%i4, %i5] :
472/// memref<?x?xf32>, vector<32x256xf32>
473/// %5 = arith.addf %3, %4 : vector<32x256xf32>
474/// %cst_3 = arith.constant dense<vector<32x256xf32>, 1.0> :
475/// vector<32x256xf32>
476/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32>
477/// %cst_4 = arith.constant dense<vector<32x256xf32>, 2.0> :
478/// vector<32x256xf32>
479/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32>
480/// %8 = arith.addf %7, %6 : vector<32x256xf32>
481/// vector.transfer_write %8, %2[%i4, %i5] :
482/// vector<32x256xf32>, memref<?x?xf32>
483/// }
484/// }
485/// %c7 = arith.constant 7 : index
486/// %c42 = arith.constant 42 : index
487/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
488/// return %9 : f32
489/// }
490/// ```
491///
492/// Of course, much more intricate n-D imperfectly-nested patterns can be
493/// vectorized too and specified in a fully declarative fashion.
494///
495/// Reduction:
496/// ==========
497/// Vectorizing reduction loops along the reduction dimension is supported if:
498/// - the reduction kind is supported,
499/// - the vectorization is 1-D, and
500/// - the step size of the loop equals one.
501///
502/// Compared to the non-vector-dimension case, two additional things are done
503/// during vectorization of such loops:
504/// - The resulting vector returned from the loop is reduced to a scalar using
505/// `vector.reduction`.
506/// - In some cases a mask is applied to the vector yielded at the end of the
507/// loop to prevent garbage values from being written to the accumulator.
508///
509/// Reduction vectorization is switched off by default; it can be enabled by
510/// passing a map from loops to reductions to utility functions, or by passing
511/// `vectorize-reductions=true` to the vectorization pass.
512///
513/// Consider the following example:
514/// ```mlir
515/// func @vecred(%in: memref<512xf32>) -> f32 {
516/// %cst = arith.constant 0.000000e+00 : f32
517/// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) {
518/// %ld = affine.load %in[%i] : memref<512xf32>
519/// %cos = math.cos %ld : f32
520/// %add = arith.addf %part_sum, %cos : f32
521/// affine.yield %add : f32
522/// }
523/// return %sum : f32
524/// }
525/// ```
526///
527/// The -affine-super-vectorize pass with the following arguments:
528/// ```
529/// -affine-super-vectorize="virtual-vector-size=128 test-fastest-varying=0 \
530/// vectorize-reductions=true"
531/// ```
532/// produces the following output:
533/// ```mlir
534/// #map = affine_map<(d0) -> (-d0 + 500)>
535/// func @vecred(%arg0: memref<512xf32>) -> f32 {
536/// %cst = arith.constant 0.000000e+00 : f32
537/// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32>
538/// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0)
539/// -> (vector<128xf32>) {
540/// // %2 is the number of iterations left in the original loop.
541/// %2 = affine.apply #map(%arg1)
542/// %3 = vector.create_mask %2 : vector<128xi1>
543/// %cst_1 = arith.constant 0.000000e+00 : f32
544/// %4 = vector.transfer_read %arg0[%arg1], %cst_1 :
545/// memref<512xf32>, vector<128xf32>
546/// %5 = math.cos %4 : vector<128xf32>
547/// %6 = arith.addf %arg2, %5 : vector<128xf32>
548/// // We filter out the effect of the last 12 elements using the mask.
549/// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32>
550/// affine.yield %7 : vector<128xf32>
551/// }
552/// %1 = vector.reduction <add>, %0 : vector<128xf32> into f32
553/// return %1 : f32
554/// }
555/// ```
556///
557/// Note that because of loop misalignment we needed to apply a mask to prevent
558/// the last 12 elements from affecting the final result. Since 500 = 3 * 128 +
559/// 116, the mask is full of ones in every iteration except for the last one,
560/// in which it has the form `11...100...0` with 116 ones and 12 zeros.
561
562#define DEBUG_TYPE "early-vect"
563
564using llvm::dbgs;
565
566/// Forward declaration.
567static FilterFunctionType
568isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
569 int fastestVaryingMemRefDimension);
570
571/// Creates a vectorization pattern from the command line arguments.
572/// Up to 3-D patterns are supported.
573/// If the command line argument requests a pattern of higher order, returns
574/// llvm::None, which conservatively results in no vectorization.
575static Optional<NestedPattern>
576makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
577 ArrayRef<int64_t> fastestVaryingPattern) {
578 using matcher::For;
579 int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
580 int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
581 int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
582 switch (vectorRank) {
583 case 1:
584 return For(isVectorizableLoopPtrFactory(parallelLoops, d0));
585 case 2:
586 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
587 For(isVectorizableLoopPtrFactory(parallelLoops, d1)));
588 case 3:
589 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
590 For(isVectorizableLoopPtrFactory(parallelLoops, d1),
591 For(isVectorizableLoopPtrFactory(parallelLoops, d2))));
592 default: {
593 return llvm::None;
594 }
595 }
596}
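As a rough usage sketch (editor's illustration, not part of the analyzed file): a caller holding the set of parallel loops could request a 2-D pattern mirroring `test-fastest-varying=1,0` and collect the matches. `matchPatternExample` and `rootOp` are hypothetical names.

```c++
// Hypothetical sketch: build a 2-D pattern and gather candidate loop nests.
static void matchPatternExample(Operation *rootOp,
                                const DenseSet<Operation *> &parallelLoops) {
  int64_t fastestVarying[] = {1, 0};
  Optional<NestedPattern> pattern =
      makePattern(parallelLoops, /*vectorRank=*/2, fastestVarying);
  if (!pattern)
    return; // Ranks above 3 are conservatively rejected.
  SmallVector<NestedMatch, 8> matches;
  pattern->match(rootOp, &matches);
  // Each match root is an AffineForOp whose nesting conforms to the
  // requested fastest-varying memref dimensions.
}
```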
597
598static NestedPattern &vectorTransferPattern() {
599 static auto pattern = matcher::Op([](Operation &op) {
600 return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
601 });
602 return pattern;
603}
604
605namespace {
606
607/// Base state for the vectorize pass.
608/// Command line arguments are preempted by non-empty pass arguments.
609struct Vectorize : public AffineVectorizeBase<Vectorize> {
610 Vectorize() = default;
611 Vectorize(ArrayRef<int64_t> virtualVectorSize);
612 void runOnOperation() override;
613};
614
615} // namespace
616
617Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) {
618 vectorSizes = virtualVectorSize;
619}
620
621static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
622 unsigned patternDepth,
623 VectorizationStrategy *strategy) {
624 assert(patternDepth > depthInPattern &&
625 "patternDepth is greater than depthInPattern");
626 if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
627 // Don't vectorize this loop
628 return;
629 }
630 strategy->loopToVectorDim[loop] =
631 strategy->vectorSizes.size() - (patternDepth - depthInPattern);
632}
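To make the index arithmetic concrete, here is a hedged walkthrough (editor's illustration; the values are not from the source):

```c++
// Illustration: with strategy->vectorSizes = {32, 256} (size() == 2) and
// patternDepth == 2:
//   outer loop, depthInPattern == 0 -> dim 2 - (2 - 0) == 0 (vector size 32)
//   inner loop, depthInPattern == 1 -> dim 2 - (2 - 1) == 1 (vector size 256)
// A loop with patternDepth - depthInPattern > 2 is left scalar by the early
// return above.
```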
633
634/// Implements a simple strawman strategy for vectorization.
635/// Given a matched pattern `matches` of depth `patternDepth`, this strategy
636/// greedily assigns the fastest varying dimension ** of the vector ** to the
637/// innermost loop in the pattern.
638/// When coupled with a pattern that looks for the fastest varying dimension in
639/// load/store MemRefs, this creates a generic vectorization strategy that works
640/// for any loop in a hierarchy (outermost, innermost or intermediate).
641///
642/// TODO: In the future we should additionally increase the power of the
643/// profitability analysis along 3 directions:
644/// 1. account for loop extents (both static and parametric + annotations);
645/// 2. account for data layout permutations;
646/// 3. account for impact of vectorization on maximal loop fusion.
647/// Then we can quantify the above to build a cost model and search over
648/// strategies.
649static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
650 unsigned depthInPattern,
651 unsigned patternDepth,
652 VectorizationStrategy *strategy) {
653 for (auto m : matches) {
654 if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
655 patternDepth, strategy))) {
656 return failure();
657 }
658 vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern,
659 patternDepth, strategy);
660 }
661 return success();
662}
663
664///// end TODO: Hoist to a VectorizationStrategy.cpp when appropriate /////
665
666namespace {
667
668struct VectorizationState {
669
670 VectorizationState(MLIRContext *context) : builder(context) {}
671
672 /// Registers the vector replacement of a scalar operation and its result
673 /// values. Both operations must have the same number of results.
674 ///
675 /// This utility is used to register the replacement for the vast majority of
676 /// the vectorized operations.
677 ///
678 /// Example:
679 /// * 'replaced': %0 = arith.addf %1, %2 : f32
680 /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
681 void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
682
683 /// Registers the vector replacement of a scalar value. The replacement
684 /// operation should have a single result, which replaces the scalar value.
685 ///
686 /// This utility is used to register the vector replacement of block arguments
687 /// and operation results which are not directly vectorized (i.e., their
688 /// scalar version still exists after vectorization), like uniforms.
689 ///
690 /// Example:
691 /// * 'replaced': block argument or operation outside of the vectorized
692 /// loop.
693 /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
694 void registerValueVectorReplacement(Value replaced, Operation *replacement);
695
696 /// Registers the vector replacement of a block argument (e.g., iter_args).
697 ///
698 /// Example:
699 /// * 'replaced': 'iter_arg' block argument.
700 /// * 'replacement': vectorized 'iter_arg' block argument.
701 void registerBlockArgVectorReplacement(BlockArgument replaced,
702 BlockArgument replacement);
703
704 /// Registers the scalar replacement of a scalar value. 'replacement' must be
705 /// scalar. Both values must be block arguments. Operation results should be
706/// replaced using the 'registerOp*' utilities.
707 ///
708 /// This utility is used to register the replacement of block arguments
709 /// that are within the loop to be vectorized and will continue being scalar
710 /// within the vector loop.
711 ///
712 /// Example:
713 /// * 'replaced': induction variable of a loop to be vectorized.
714 /// * 'replacement': new induction variable in the new vector loop.
715 void registerValueScalarReplacement(BlockArgument replaced,
716 BlockArgument replacement);
717
718 /// Registers the scalar replacement of a scalar result returned from a
719 /// reduction loop. 'replacement' must be scalar.
720 ///
721 /// This utility is used to register the replacement for scalar results of
722 /// vectorized reduction loops with iter_args.
723 ///
724/// Example:
725 /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
726 /// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into
727 /// f32
728 void registerLoopResultScalarReplacement(Value replaced, Value replacement);
729
730 /// Returns in 'replacedVals' the scalar replacement for values in
731 /// 'inputVals'.
732 void getScalarValueReplacementsFor(ValueRange inputVals,
733 SmallVectorImpl<Value> &replacedVals);
734
735 /// Erases the scalar loop nest after its successful vectorization.
736 void finishVectorizationPattern(AffineForOp rootLoop);
737
738 // Used to build and insert all the new operations created. The insertion
739 // point is preserved and updated along the vectorization process.
740 OpBuilder builder;
741
742 // Maps input scalar operations to their vector counterparts.
743 DenseMap<Operation *, Operation *> opVectorReplacement;
744 // Maps input scalar values to their vector counterparts.
745 BlockAndValueMapping valueVectorReplacement;
746 // Maps input scalar values to their new scalar counterparts in the vector
747 // loop nest.
748 BlockAndValueMapping valueScalarReplacement;
749 // Maps results of reduction loops to their new scalar counterparts.
750 DenseMap<Value, Value> loopResultScalarReplacement;
751
752 // Maps the newly created vector loops to their vector dimension.
753 DenseMap<Operation *, unsigned> vecLoopToVecDim;
754
755 // Maps the new vectorized loops to the corresponding vector masks, if a
756 // mask is required.
757 DenseMap<Operation *, Value> vecLoopToMask;
758
759 // The strategy drives which loop to vectorize by which amount.
760 const VectorizationStrategy *strategy = nullptr;
761
762private:
763 /// Internal implementation to map input scalar values to new vector or scalar
764 /// values.
765 void registerValueVectorReplacementImpl(Value replaced, Value replacement);
766 void registerValueScalarReplacementImpl(Value replaced, Value replacement);
767};
768
769} // namespace
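For context on how these registration hooks are typically exercised together, here is a minimal sketch of widening a single scalar operation, loosely modeled on the vectorization logic later in this file; `widenOpSketch` is a hypothetical name and error handling is elided:

```c++
// Hypothetical sketch: rebuild a scalar op with vector operands/results and
// register the replacement in the vectorization state.
static Operation *widenOpSketch(Operation *scalarOp,
                                VectorizationState &state) {
  // Vector result types derived from the strategy's vector sizes.
  SmallVector<Type, 4> vecTypes;
  for (Type resTy : scalarOp->getResultTypes())
    vecTypes.push_back(VectorType::get(state.strategy->vectorSizes, resTy));
  // Operands are expected to have vector replacements registered already.
  SmallVector<Value, 4> vecOperands;
  for (Value operand : scalarOp->getOperands())
    vecOperands.push_back(
        state.valueVectorReplacement.lookupOrDefault(operand));
  // Recreate the op with the same name and attributes but vector types.
  OperationState vecOpState(scalarOp->getLoc(),
                            scalarOp->getName().getStringRef(), vecOperands,
                            vecTypes, scalarOp->getAttrs());
  Operation *vecOp = state.builder.create(vecOpState);
  state.registerOpVectorReplacement(scalarOp, vecOp);
  return vecOp;
}
```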
770
771/// Registers the vector replacement of a scalar operation and its result
772/// values. Both operations must have the same number of results.
773///
774/// This utility is used to register the replacement for the vast majority of
775/// the vectorized operations.
776///
777/// Example:
778/// * 'replaced': %0 = arith.addf %1, %2 : f32
779/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
780void VectorizationState::registerOpVectorReplacement(Operation *replaced,
781 Operation *replacement) {
782 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
783 LLVM_DEBUG(dbgs() << *replaced << "\n");
784 LLVM_DEBUG(dbgs() << "into\n");
785 LLVM_DEBUG(dbgs() << *replacement << "\n");
786
787 assert(replaced->getNumResults() == replacement->getNumResults() &&
788 "Unexpected replaced and replacement results");
789 assert(opVectorReplacement.count(replaced) == 0 && "already registered");
790 opVectorReplacement[replaced] = replacement;
791
792 for (auto resultTuple :
793 llvm::zip(replaced->getResults(), replacement->getResults()))
794 registerValueVectorReplacementImpl(std::get<0>(resultTuple),
795 std::get<1>(resultTuple));
796}
797
798/// Registers the vector replacement of a scalar value. The replacement
799/// operation should have a single result, which replaces the scalar value.
800///
801/// This utility is used to register the vector replacement of block arguments
802/// and operation results which are not directly vectorized (i.e., their
803/// scalar version still exists after vectorization), like uniforms.
804///
805/// Example:
806/// * 'replaced': block argument or operation outside of the vectorized loop.
807/// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
808void VectorizationState::registerValueVectorReplacement(
809 Value replaced, Operation *replacement) {
810 assert(replacement->getNumResults() == 1 &&
811 "Expected single-result replacement");
812 if (Operation *defOp = replaced.getDefiningOp())
813 registerOpVectorReplacement(defOp, replacement);
814 else
815 registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
816}
817
818/// Registers the vector replacement of a block argument (e.g., iter_args).
819///
820/// Example:
821/// * 'replaced': 'iter_arg' block argument.
822/// * 'replacement': vectorized 'iter_arg' block argument.
823void VectorizationState::registerBlockArgVectorReplacement(
824 BlockArgument replaced, BlockArgument replacement) {
825 registerValueVectorReplacementImpl(replaced, replacement);
826}
827
828void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
829 Value replacement) {
830 assert(!valueVectorReplacement.contains(replaced) &&
831 "Vector replacement already registered");
832 assert(replacement.getType().isa<VectorType>() &&
833 "Expected vector type in vector replacement");
834 valueVectorReplacement.map(replaced, replacement);
835}
836
837/// Registers the scalar replacement of a scalar value. 'replacement' must be
838/// scalar. Both values must be block arguments. Operation results should be
839/// replaced using the 'registerOp*' utilities.
840///
841/// This utility is used to register the replacement of block arguments
842/// that are within the loop to be vectorized and will continue being scalar
843/// within the vector loop.
844///
845/// Example:
846/// * 'replaced': induction variable of a loop to be vectorized.
847/// * 'replacement': new induction variable in the new vector loop.
848void VectorizationState::registerValueScalarReplacement(
849 BlockArgument replaced, BlockArgument replacement) {
850 registerValueScalarReplacementImpl(replaced, replacement);
851}
852
853/// Registers the scalar replacement of a scalar result returned from a
854/// reduction loop. 'replacement' must be scalar.
855///
856/// This utility is used to register the replacement for scalar results of
857/// vectorized reduction loops with iter_args.
858///
859/// Example:
860/// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
861/// * 'replacement': %1 = vector.reduction <add>, %0 : vector<4xf32> into f32
862void VectorizationState::registerLoopResultScalarReplacement(
863 Value replaced, Value replacement) {
864 assert(isa<AffineForOp>(replaced.getDefiningOp()));
865 assert(loopResultScalarReplacement.count(replaced) == 0 &&
866 "already registered");
867 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
868 "with scalar: "
869 << replacement);
870 loopResultScalarReplacement[replaced] = replacement;
871}
872
873void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
874 Value replacement) {
875 assert(!valueScalarReplacement.contains(replaced) &&
876 "Scalar value replacement already registered");
877 assert(!replacement.getType().isa<VectorType>() &&
878 "Expected scalar type in scalar replacement");
879 valueScalarReplacement.map(replaced, replacement);
880}
881
882/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
883void VectorizationState::getScalarValueReplacementsFor(
884 ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
885 for (Value inputVal : inputVals)
886 replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
887}
888
889/// Erases a loop nest, including all its nested operations.
890static void eraseLoopNest(AffineForOp forOp) {
891 LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n");
892 forOp.erase();
893}
894
895/// Erases the scalar loop nest after its successful vectorization.
896void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
897 LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n");
898 eraseLoopNest(rootLoop);
899}
900
901// Applies 'map' to 'mapOperands', returning the resulting values in 'results'.
902static void computeMemoryOpIndices(Operation *op, AffineMap map,
903 ValueRange mapOperands,
904 VectorizationState &state,
905 SmallVectorImpl<Value> &results) {
906 for (auto resultExpr : map.getResults()) {
907 auto singleResMap =
908 AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
909 auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
910 mapOperands);
911 results.push_back(afOp);
912 }
913}
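As an editor's illustration of the splitting performed above (operand names are placeholders, not from the source):

```c++
// Illustration only: for map = (d0, d1) -> (d0 + 1, d1) and
// mapOperands = {%i, %j}, the loop emits one affine.apply per map result:
//   %0 = affine.apply affine_map<(d0, d1) -> (d0 + 1)>(%i, %j)
//   %1 = affine.apply affine_map<(d0, d1) -> (d1)>(%i, %j)
// and `results` ends up holding {%0, %1}.
```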
914
915/// Returns a FilterFunctionType that can be used in NestedPattern to match a
916/// loop whose underlying load/store accesses are either invariant or all
917/// varying along the `fastestVaryingMemRefDimension`.
918static FilterFunctionType
919isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
920 int fastestVaryingMemRefDimension) {
921 return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
922 auto loop = cast<AffineForOp>(forOp);
923 auto parallelIt = parallelLoops.find(loop);
924 if (parallelIt == parallelLoops.end())
925 return false;
926 int memRefDim = -1;
927 auto vectorizableBody =
928 isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern());
929 if (!vectorizableBody)
930 return false;
931 return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
932 memRefDim == fastestVaryingMemRefDimension;
933 };
934}
935
936/// Returns the vector type resulting from applying the provided vectorization
937/// strategy on the scalar type.
938static VectorType getVectorType(Type scalarTy,
939 const VectorizationStrategy *strategy) {
940 assert(!scalarTy.isa<VectorType>() && "Expected scalar type");
941 return VectorType::get(strategy->vectorSizes, scalarTy);
942}
943
944/// Tries to transform a scalar constant into a vector constant. Returns the
945/// vector constant if the scalar type is a valid vector element type. Returns
946/// nullptr otherwise.
947static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
948 VectorizationState &state) {
949 Type scalarTy = constOp.getType();
950 if (!VectorType::isValidElementType(scalarTy))
951 return nullptr;
952
953 auto vecTy = getVectorType(scalarTy, state.strategy);
954 auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
955
956 OpBuilder::InsertionGuard guard(state.builder);
957 Operation *parentOp = state.builder.getInsertionBlock()->getParentOp();
958 // Find the innermost vectorized ancestor loop to insert the vector constant.
959 while (parentOp && !state.vecLoopToVecDim.count(parentOp))
960 parentOp = parentOp->getParentOp();
961 assert(parentOp && state.vecLoopToVecDim.count(parentOp) &&
962 isa<AffineForOp>(parentOp) && "Expected a vectorized for op");
963 auto vecForOp = cast<AffineForOp>(parentOp);
964 state.builder.setInsertionPointToStart(vecForOp.getBody());
965 auto newConstOp =
966 state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr);
967
968 // Register vector replacement for future uses in the scope.
969 state.registerOpVectorReplacement(constOp, newConstOp);
970 return newConstOp;
971}
972
973/// Creates a constant vector filled with the neutral elements of the given
974/// reduction. The scalar type of vector elements will be taken from
975/// `oldOperand`.
976static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
977 Value oldOperand,
978 VectorizationState &state) {
979 Type scalarTy = oldOperand.getType();
980 if (!VectorType::isValidElementType(scalarTy))
981 return nullptr;
982
983 Attribute valueAttr = getIdentityValueAttr(
984 reductionKind, scalarTy, state.builder, oldOperand.getLoc());
985 auto vecTy = getVectorType(scalarTy, state.strategy);
986 auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr);
987 auto newConstOp =
988 state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr);
989
990 return newConstOp;
991}
992
993/// Creates a mask used to filter out garbage elements in the last iteration
994/// of unaligned loops. If a mask is not required then `nullptr` is returned.
995/// The mask will be a vector of booleans representing meaningful vector
996/// elements in the current iteration. It is filled with ones for each iteration
997/// except for the last one, where it has the form `11...100...0` with the
998/// number of ones equal to the number of meaningful elements (i.e. the number
999/// of iterations that would be left in the original loop).
1000static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
1001  assert(state.strategy->vectorSizes.size() == 1 &&
1002         "Creating a mask for non-1-D vectors is not supported.");
1003  assert(vecForOp.getStep() == state.strategy->vectorSizes[0] &&
1004         "Creating a mask for loops with non-unit original step size is not "
1005         "supported.");
1006
1007 // Check if we have already created the mask.
1008 if (Value mask = state.vecLoopToMask.lookup(vecForOp))
1009 return mask;
1010
1011  // If the loop has constant bounds and the original number of iterations is
1012  // divisible by the vector size then we don't need a mask.
1013 if (vecForOp.hasConstantBounds()) {
1014 int64_t originalTripCount =
1015 vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound();
1016 if (originalTripCount % vecForOp.getStep() == 0)
1017 return nullptr;
1018 }
1019
1020 OpBuilder::InsertionGuard guard(state.builder);
1021 state.builder.setInsertionPointToStart(vecForOp.getBody());
1022
1023 // We generate the mask using the `vector.create_mask` operation which accepts
1024 // the number of meaningful elements (i.e. the length of the prefix of 1s).
1025 // To compute the number of meaningful elements we subtract the current value
1026 // of the iteration variable from the upper bound of the loop. Example:
1027 //
1028 // // 500 is the upper bound of the loop
1029 // #map = affine_map<(d0) -> (500 - d0)>
1030 // %elems_left = affine.apply #map(%iv)
1031 // %mask = vector.create_mask %elems_left : vector<128xi1>
1032
1033 Location loc = vecForOp.getLoc();
1034
1035 // First we get the upper bound of the loop using `affine.apply` or
1036 // `affine.min`.
1037 AffineMap ubMap = vecForOp.getUpperBoundMap();
1038 Value ub;
1039 if (ubMap.getNumResults() == 1)
1040 ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
1041 vecForOp.getUpperBoundOperands());
1042 else
1043 ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
1044 vecForOp.getUpperBoundOperands());
1045 // Then we compute the number of (original) iterations left in the loop.
1046 AffineExpr subExpr =
1047 state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
1048 Value itersLeft =
1049 makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr),
1050 {ub, vecForOp.getInductionVar()});
1051 // If the affine maps were successfully composed then `ub` is unneeded.
1052 if (ub.use_empty())
1053 ub.getDefiningOp()->erase();
1054 // Finally we create the mask.
1055 Type maskTy = VectorType::get(state.strategy->vectorSizes,
1056 state.builder.getIntegerType(1));
1057 Value mask =
1058 state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);
1059
1060  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"
1061                    << itersLeft << "\n"
1062                    << mask << "\n");
1063
1064 state.vecLoopToMask[vecForOp] = mask;
1065 return mask;
1066}
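
Note: to make the mask semantics concrete, with an upper bound of 500 and vector size 128 (illustrative values only), the last vector iteration starts at %iv = 384 and only 116 lanes carry meaningful data:

  %elems_left = affine.apply affine_map<(d0) -> (500 - d0)>(%iv)  // 116 when %iv = 384
  %mask = vector.create_mask %elems_left : vector<128xi1>         // 116 ones, then 12 zeros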
1067
1068/// Returns true if the provided value is uniform across the vector lanes given
1069/// the vectorization strategy.
1070// TODO: For now, only values that are induction variables of loops not in
1071// `loopToVectorDim` or invariants to all the loops in the vectorization
1072// strategy are considered vector uniforms.
1073static bool isUniformDefinition(Value value,
1074 const VectorizationStrategy *strategy) {
1075 AffineForOp forOp = getForInductionVarOwner(value);
1076 if (forOp && strategy->loopToVectorDim.count(forOp) == 0)
1077 return true;
1078
1079 for (auto loopToDim : strategy->loopToVectorDim) {
1080 auto loop = cast<AffineForOp>(loopToDim.first);
1081 if (!loop.isDefinedOutsideOfLoop(value))
1082 return false;
1083 }
1084 return true;
1085}
1086
1087/// Generates a broadcast op for the provided uniform value using the
1088/// vectorization strategy in 'state'.
1089static Operation *vectorizeUniform(Value uniformVal,
1090 VectorizationState &state) {
1091 OpBuilder::InsertionGuard guard(state.builder);
1092 Value uniformScalarRepl =
1093 state.valueScalarReplacement.lookupOrDefault(uniformVal);
1094 state.builder.setInsertionPointAfterValue(uniformScalarRepl);
1095
1096 auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
1097 auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
1098 vectorTy, uniformScalarRepl);
1099 state.registerValueVectorReplacement(uniformVal, bcastOp);
1100 return bcastOp;
1101}
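
Note: the broadcast splats the scalar replacement across all lanes; for a 1-D size-128 strategy the generated op looks roughly like this (names hypothetical):

  %vec = vector.broadcast %uniform_scalar : f32 to vector<128xf32>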
1102
1103/// Tries to vectorize a given `operand` by applying the following logic:
1104/// 1. if the defining operation has been already vectorized, `operand` is
1105/// already in the proper vector form;
1106/// 2. if the `operand` is a constant, returns the vectorized form of the
1107/// constant;
1108/// 3. if the `operand` is uniform, returns a vector broadcast of the `op`;
1109/// 4. otherwise, the vectorization of `operand` is not supported.
1110/// Newly created vector operations are registered in `state` as replacement
1111/// for their scalar counterparts.
1112/// In particular this logic captures some of the use cases where definitions
1113/// that are not scoped under the current pattern are needed to vectorize.
1114/// One such example is top level function constants that need to be splatted.
1115///
1116/// Returns an operand that has been vectorized to match `state`'s strategy if
1117/// vectorization is possible with the above logic. Returns nullptr otherwise.
1118///
1119/// TODO: handle more complex cases.
1120static Value vectorizeOperand(Value operand, VectorizationState &state) {
1121  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand);
15
Assuming 'DebugFlag' is false
16
Loop condition is false. Exiting loop
1122 // If this value is already vectorized, we are done.
1123 if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
1124    LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl);
1125 return vecRepl;
1126 }
1127
1128  // A vector operand that is not in the replacement map should never reach
1129  // this point. Reaching this point could mean that the code was already
1130  // vectorized and we shouldn't try to vectorize already vectorized code.
1131  assert(!operand.getType().isa<VectorType>() &&
17
Taking false branch
18
Assuming the condition is true
19
'?' condition is true
1132         "Vector op not found in replacement map");
1133
1134 // Vectorize constant.
1135 if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
20
Taking true branch
1136 auto vecConstant = vectorizeConstant(constOp, state);
1137    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
21
Assuming 'DebugFlag' is true
22
Assuming the condition is true
23
Taking true branch
24
Null pointer value stored to 'op.state'
25
Calling 'operator<<'
1138 return vecConstant.getResult();
1139 }
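
Note: the path above (events 21-25) is the defect being reported. When the scalar type is not a valid vector element type, vectorizeConstant returns a null arith::ConstantOp; streaming that null op into dbgs() converts it to Operation* via OpState and ultimately dereferences the null 'state' pointer (OpDefinition.h, line 98). A minimal hardening sketch, not necessarily the upstream fix:

  // Vectorize constant (sketch): bail out before using a possibly-null op.
  if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
    arith::ConstantOp vecConstant = vectorizeConstant(constOp, state);
    if (!vecConstant) // vectorizeConstant may return nullptr (see line 947)
      return nullptr;
    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
    return vecConstant.getResult();
  }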
1140
1141 // Vectorize uniform values.
1142 if (isUniformDefinition(operand, state.strategy)) {
1143 Operation *vecUniform = vectorizeUniform(operand, state);
1144    LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform);
1145 return vecUniform->getResult(0);
1146 }
1147
1148 // Check for unsupported block argument scenarios. A supported block argument
1149 // should have been vectorized already.
1150 if (!operand.getDefiningOp())
1151    LLVM_DEBUG(dbgs() << "-> unsupported block argument\n");
1152 else
1153 // Generic unsupported case.
1154    LLVM_DEBUG(dbgs() << "-> non-vectorizable\n");
1155
1156 return nullptr;
1157}
1158
1159/// Vectorizes an affine load with the vectorization strategy in 'state' by
1160/// generating a 'vector.transfer_read' op with the proper permutation map
1161/// inferred from the indices of the load. The new 'vector.transfer_read' is
1162/// registered as replacement of the scalar load. Returns the newly created
1163/// 'vector.transfer_read' if vectorization was successful. Returns nullptr
1164/// otherwise.
1165static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
1166 VectorizationState &state) {
1167 MemRefType memRefType = loadOp.getMemRefType();
1168 Type elementType = memRefType.getElementType();
1169 auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
1170
1171 // Replace map operands with operands from the vector loop nest.
1172 SmallVector<Value, 8> mapOperands;
1173 state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
1174
1175 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1176 SmallVector<Value, 8> indices;
1177 indices.reserve(memRefType.getRank());
1178 if (loadOp.getAffineMap() !=
1179 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1180 computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
1181 indices);
1182 else
1183 indices.append(mapOperands.begin(), mapOperands.end());
1184
1185 // Compute permutation map using the information of new vector loops.
1186 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1187 indices, state.vecLoopToVecDim);
1188 if (!permutationMap) {
1189    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n");
1190 return nullptr;
1191 }
1192  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
1193  LLVM_DEBUG(permutationMap.print(dbgs()));
1194
1195 auto transfer = state.builder.create<vector::TransferReadOp>(
1196 loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
1197
1198 // Register replacement for future uses in the scope.
1199 state.registerOpVectorReplacement(loadOp, transfer);
1200 return transfer;
1201}
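
Note: an illustrative before/after for a unit-stride load under a 1-D size-128 strategy (shapes, names, and the printed padding operand are hypothetical):

  // Scalar:  %v = affine.load %A[%i] : memref<512xf32>
  // Vector:  %v = vector.transfer_read %A[%i_vec], %pad
  //               {permutation_map = affine_map<(d0) -> (d0)>}
  //               : memref<512xf32>, vector<128xf32>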
1202
1203/// Vectorizes an affine store with the vectorization strategy in 'state' by
1204/// generating a 'vector.transfer_write' op with the proper permutation map
1205/// inferred from the indices of the store. The new 'vector.transfer_write' is
1206/// registered as replacement of the scalar store. Returns the newly created
1207/// 'vector.transfer_write' if vectorization was successful. Returns nullptr
1208/// otherwise.
1209static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
1210 VectorizationState &state) {
1211 MemRefType memRefType = storeOp.getMemRefType();
1212 Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
1213 if (!vectorValue)
1214 return nullptr;
1215
1216 // Replace map operands with operands from the vector loop nest.
1217 SmallVector<Value, 8> mapOperands;
1218 state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
1219
1220 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1221 SmallVector<Value, 8> indices;
1222 indices.reserve(memRefType.getRank());
1223 if (storeOp.getAffineMap() !=
1224 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1225 computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
1226 indices);
1227 else
1228 indices.append(mapOperands.begin(), mapOperands.end());
1229
1230 // Compute permutation map using the information of new vector loops.
1231 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1232 indices, state.vecLoopToVecDim);
1233 if (!permutationMap)
1234 return nullptr;
1235  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
1236  LLVM_DEBUG(permutationMap.print(dbgs()));
1237
1238 auto transfer = state.builder.create<vector::TransferWriteOp>(
1239 storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
1240 permutationMap);
1241  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
1242
1243 // Register replacement for future uses in the scope.
1244 state.registerOpVectorReplacement(storeOp, transfer);
1245 return transfer;
1246}
1247
1248/// Returns true if `value` is a constant equal to the neutral element of the
1249/// given vectorizable reduction.
1250static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
1251 Value value, VectorizationState &state) {
1252 Type scalarTy = value.getType();
1253 if (!VectorType::isValidElementType(scalarTy))
1254 return false;
1255 Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
1256 state.builder, value.getLoc());
1257 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
1258 return constOp.getValue() == valueAttr;
1259 return false;
1260}
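
Note: for example, the identity value of an addf reduction is 0.0 (1.0 for mulf), so an accumulator initialized as below is recognized as neutral and needs no extra combining step (illustrative):

  %init = arith.constant 0.000000e+00 : f32  // neutral element for AtomicRMWKind::addf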
1261
1262/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
1263/// created and registered as replacement for the scalar loop. The builder's
1264/// insertion point is set to the new loop's body so that subsequent vectorized
1265/// operations are inserted into the new loop. If the loop is a vector
1266/// dimension, the step of the newly created loop will reflect the vectorization
1267/// factor used to vectorized that dimension.
1268static Operation *vectorizeAffineForOp(AffineForOp forOp,
1269 VectorizationState &state) {
1270 const VectorizationStrategy &strategy = *state.strategy;
1271 auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
1272 bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
1273
1274 // TODO: Vectorization of reduction loops is not supported for non-unit steps.
1275  if (isLoopVecDim && forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) {
11.1
'isLoopVecDim' is false
1276    LLVM_DEBUG(
1277        dbgs()
1278        << "\n[early-vect]+++++ unsupported step size for reduction loop: "
1279        << forOp.getStep() << "\n");
1280 return nullptr;
1281 }
1282
1283 // If we are vectorizing a vector dimension, compute a new step for the new
1284 // vectorized loop using the vectorization factor for the vector dimension.
1285 // Otherwise, propagate the step of the scalar loop.
1286 unsigned newStep;
1287  if (isLoopVecDim) {
11.2
'isLoopVecDim' is false
12
Taking false branch
1288 unsigned vectorDim = loopToVecDimIt->second;
1289    assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
1290 int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
1291 newStep = forOp.getStep() * forOpVecFactor;
1292 } else {
1293 newStep = forOp.getStep();
1294 }
1295
1296 // Get information about reduction kinds.
1297 ArrayRef<LoopReduction> reductions;
1298  if (isLoopVecDim && forOp.getNumIterOperands() > 0) {
12.1
'isLoopVecDim' is false
1299 auto it = strategy.reductionLoops.find(forOp);
1300    assert(it != strategy.reductionLoops.end() &&
1301           "Reduction descriptors not found when vectorizing a reduction loop");
1302 reductions = it->second;
1303    assert(reductions.size() == forOp.getNumIterOperands() &&
1304           "The size of reductions array must match the number of iter_args");
1305 }
1306
1307 // Vectorize 'iter_args'.
1308 SmallVector<Value, 8> vecIterOperands;
1309  if (!isLoopVecDim) {
12.2
'isLoopVecDim' is false
13
Taking true branch
1310 for (auto operand : forOp.getIterOperands())
1311 vecIterOperands.push_back(vectorizeOperand(operand, state));
14
Calling 'vectorizeOperand'
1312 } else {
1313 // For reduction loops we need to pass a vector of neutral elements as an
1314 // initial value of the accumulator. We will add the original initial value
1315 // later.
1316 for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) {
1317 vecIterOperands.push_back(createInitialVector(
1318 std::get<0>(redAndOperand).kind, std::get<1>(redAndOperand), state));
1319 }
1320 }
1321
1322 auto vecForOp = state.builder.create<AffineForOp>(
1323 forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
1324 forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
1325 vecIterOperands,
1326 /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
1327 // Make sure we don't create a default terminator in the loop body as
1328 // the proper terminator will be added during vectorization.
1329 });
1330
1331 // Register loop-related replacements:
1332 // 1) The new vectorized loop is registered as vector replacement of the
1333 // scalar loop.
1334 // 2) The new iv of the vectorized loop is registered as scalar replacement
1335 // since a scalar copy of the iv will prevail in the vectorized loop.
1336 // TODO: A vector replacement will also be added in the future when
1337 // vectorization of linear ops is supported.
1338 // 3) The new 'iter_args' region arguments are registered as vector
1339 // replacements since they have been vectorized.
1340 // 4) If the loop performs a reduction along the vector dimension, a
1341 // `vector.reduction` or similar op is inserted for each resulting value
1342 // of the loop and its scalar value replaces the corresponding scalar
1343 // result of the loop.
1344 state.registerOpVectorReplacement(forOp, vecForOp);
1345 state.registerValueScalarReplacement(forOp.getInductionVar(),
1346 vecForOp.getInductionVar());
1347  for (auto iterTuple :
1348       llvm::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs()))
1349 state.registerBlockArgVectorReplacement(std::get<0>(iterTuple),
1350 std::get<1>(iterTuple));
1351
1352 if (isLoopVecDim) {
1353 for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) {
1354 // First, we reduce the vector returned from the loop into a scalar.
1355 Value reducedRes =
1356 getVectorReductionOp(reductions[i].kind, state.builder,
1357 vecForOp.getLoc(), vecForOp.getResult(i));
1358      LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: "
1359                        << reducedRes);
1360 // Then we combine it with the original (scalar) initial value unless it
1361 // is equal to the neutral element of the reduction.
1362 Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
1363 Value finalRes = reducedRes;
1364 if (!isNeutralElementConst(reductions[i].kind, origInit, state))
1365 finalRes =
1366 arith::getReductionOp(reductions[i].kind, state.builder,
1367 reducedRes.getLoc(), reducedRes, origInit);
1368 state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
1369 }
1370 }
1371
1372 if (isLoopVecDim)
1373 state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
1374
1375 // Change insertion point so that upcoming vectorized instructions are
1376 // inserted into the vectorized loop's body.
1377 state.builder.setInsertionPointToStart(vecForOp.getBody());
1378
1379 // If this is a reduction loop then we may need to create a mask to filter out
1380 // garbage in the last iteration.
1381 if (isLoopVecDim && forOp.getNumIterOperands() > 0)
1382 createMask(vecForOp, state);
1383
1384 return vecForOp;
1385}
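
Note: taken together, a 1-D size-128 reduction loop roughly lowers to the following shape (sketch only; exact vector.reduction syntax varies across MLIR snapshots, names hypothetical):

  %vinit = arith.constant dense<0.000000e+00> : vector<128xf32>  // neutral elements
  %vacc = affine.for %i = 0 to 512 step 128
            iter_args(%acc = %vinit) -> (vector<128xf32>) { ... }
  %red = vector.reduction <add>, %vacc : vector<128xf32> into f32
  %final = arith.addf %red, %orig_init : f32  // combine with the original init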
1386
1387/// Vectorizes arbitrary operation by plain widening. We apply generic type
1388/// widening of all its results and retrieve the vector counterparts for all its
1389/// operands.
1390static Operation *widenOp(Operation *op, VectorizationState &state) {
1391 SmallVector<Type, 8> vectorTypes;
1392 for (Value result : op->getResults())
1393 vectorTypes.push_back(
1394 VectorType::get(state.strategy->vectorSizes, result.getType()));
1395
1396 SmallVector<Value, 8> vectorOperands;
1397 for (Value operand : op->getOperands()) {
1398 Value vecOperand = vectorizeOperand(operand, state);
1399 if (!vecOperand) {
1400 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n")do { if (::llvm::DebugFlag && ::llvm::isCurrentDebugType
("early-vect")) { dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n"
; } } while (false)
;
1401 return nullptr;
1402 }
1403 vectorOperands.push_back(vecOperand);
1404 }
1405
1406 // Create a clone of the op with the proper operands and return types.
1407 // TODO: The following assumes there is always an op with a fixed
1408 // name that works both in scalar mode and vector mode.
1409 // TODO: Is it worth considering an Operation.clone operation which
1410 // changes the type so we can promote an Operation with less boilerplate?
1411 Operation *vecOp =
1412 state.builder.create(op->getLoc(), op->getName().getIdentifier(),
1413 vectorOperands, vectorTypes, op->getAttrs());
1414 state.registerOpVectorReplacement(op, vecOp);
1415 return vecOp;
1416}
1417
1418/// Vectorizes a yield operation by widening its types. The builder's insertion
1419/// point is set after the vectorized parent op to continue vectorizing the
1420/// operations after the parent op. When vectorizing a reduction loop a mask may
1421/// be used to prevent adding garbage values to the accumulator.
1422static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
1423 VectorizationState &state) {
1424 Operation *newYieldOp = widenOp(yieldOp, state);
1425 Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
1426
1427 // If there is a mask for this loop then we must prevent garbage values from
1428 // being added to the accumulator by inserting `select` operations, for
1429 // example:
1430 //
1431 // %res = arith.addf %acc, %val : vector<128xf32>
1432 // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32>
1433 // affine.yield %res_masked : vector<128xf32>
1434 //
1435 if (Value mask = state.vecLoopToMask.lookup(newParentOp)) {
1436 state.builder.setInsertionPoint(newYieldOp);
1437 for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) {
1438 Value result = newYieldOp->getOperand(i);
1439 Value iterArg = cast<AffineForOp>(newParentOp).getRegionIterArgs()[i];
1440 Value maskedResult = state.builder.create<arith::SelectOp>(
1441 result.getLoc(), mask, result, iterArg);
1442      LLVM_DEBUG(
1443          dbgs() << "\n[early-vect]+++++ masking a yielded vector value: "
1444                 << maskedResult);
1445 newYieldOp->setOperand(i, maskedResult);
1446 }
1447 }
1448
1449 state.builder.setInsertionPointAfter(newParentOp);
1450 return newYieldOp;
1451}
1452
1453/// Encodes Operation-specific behavior for vectorization. In general we
1454/// assume that all operands of an op must be vectorized but this is not
1455/// always true. In the future, it would be nice to have a trait that
1456/// describes how a particular operation vectorizes. For now we implement the
1457/// case distinction here. Returns a vectorized form of an operation or
1458/// nullptr if vectorization fails.
1459// TODO: consider adding a trait to Op to describe how it gets vectorized.
1460// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
1461// do one-off logic here; ideally it would be TableGen'd.
1462static Operation *vectorizeOneOperation(Operation *op,
1463 VectorizationState &state) {
1464 // Sanity checks.
1465  assert(!isa<vector::TransferReadOp>(op) &&
4
Assuming 'op' is not a 'TransferReadOp'
5
'?' condition is true
1466         "vector.transfer_read cannot be further vectorized");
1467  assert(!isa<vector::TransferWriteOp>(op) &&
6
Assuming 'op' is not a 'TransferWriteOp'
7
'?' condition is true
1468         "vector.transfer_write cannot be further vectorized");
1469
1470 if (auto loadOp = dyn_cast<AffineLoadOp>(op))
8
Taking false branch
1471 return vectorizeAffineLoad(loadOp, state);
1472 if (auto storeOp = dyn_cast<AffineStoreOp>(op))
9
Taking false branch
1473 return vectorizeAffineStore(storeOp, state);
1474 if (auto forOp = dyn_cast<AffineForOp>(op))
10
Taking true branch
1475 return vectorizeAffineForOp(forOp, state);
11
Calling 'vectorizeAffineForOp'
1476 if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
1477 return vectorizeAffineYieldOp(yieldOp, state);
1478 if (auto constant = dyn_cast<arith::ConstantOp>(op))
1479 return vectorizeConstant(constant, state);
1480
1481 // Other ops with regions are not supported.
1482 if (op->getNumRegions() != 0)
1483 return nullptr;
1484
1485 return widenOp(op, state);
1486}
1487
1488/// Recursive implementation to convert all the nested loops in 'match' to a 2D
1489/// vector container that preserves the relative nesting level of each loop with
1490/// respect to the others in 'match'. 'currentLevel' is the nesting level that
1491/// will be assigned to the loop in the current 'match'.
1492static void
1493getMatchedAffineLoopsRec(NestedMatch match, unsigned currentLevel,
1494 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1495 // Add a new empty level to the output if it doesn't exist already.
1496  assert(currentLevel <= loops.size() && "Unexpected currentLevel");
1497 if (currentLevel == loops.size())
1498 loops.emplace_back();
1499
1500 // Add current match and recursively visit its children.
1501 loops[currentLevel].push_back(cast<AffineForOp>(match.getMatchedOperation()));
1502 for (auto childMatch : match.getMatchedChildren()) {
1503 getMatchedAffineLoopsRec(childMatch, currentLevel + 1, loops);
1504 }
1505}
1506
1507/// Converts all the nested loops in 'match' to a 2D vector container that
1508/// preserves the relative nesting level of each loop with respect to the others
1509/// in 'match'. This means that every loop in 'loops[i]' will have a parent loop
1510/// in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in
1511/// 'loops[i+1]'.
1512static void
1513getMatchedAffineLoops(NestedMatch match,
1514 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1515 getMatchedAffineLoopsRec(match, /*currLoopDepth=*/0, loops);
1516}
1517
1518/// Internal implementation to vectorize affine loops from a single loop nest
1519/// using an n-D vectorization strategy.
1520static LogicalResult
1521vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1522 const VectorizationStrategy &strategy) {
1523  assert(loops[0].size() == 1 && "Expected single root loop");
1524 AffineForOp rootLoop = loops[0][0];
1525 VectorizationState state(rootLoop.getContext());
1526 state.builder.setInsertionPointAfter(rootLoop);
1527 state.strategy = &strategy;
1528
1529 // Since patterns are recursive, they can very well intersect.
1530 // Since we do not want a fully greedy strategy in general, we decouple
1531 // pattern matching, from profitability analysis, from application.
1532 // As a consequence we must check that each root pattern is still
1533 // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
1534 // TODO: implement a non-greedy profitability analysis that keeps only
1535 // non-intersecting patterns.
1536 if (!isVectorizableLoopBody(rootLoop, vectorTransferPattern())) {
1537    LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
1538 return failure();
1539 }
1540
1541 //////////////////////////////////////////////////////////////////////////////
1542 // Vectorize the scalar loop nest following a topological order. A new vector
1543 // loop nest with the vectorized operations is created along the process. If
1544 // vectorization succeeds, the scalar loop nest is erased. If vectorization
1545 // fails, the vector loop nest is erased and the scalar loop nest is not
1546 // modified.
1547 //////////////////////////////////////////////////////////////////////////////
1548
1549 auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
1550    LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op);
1
Assuming 'DebugFlag' is false
2
Loop condition is false. Exiting loop
1551 Operation *vectorOp = vectorizeOneOperation(op, state);
3
Calling 'vectorizeOneOperation'
1552 if (!vectorOp) {
1553      LLVM_DEBUG(
1554          dbgs() << "[early-vect]+++++ failed vectorizing the operation: "
1555                 << *op << "\n");
1556 return WalkResult::interrupt();
1557 }
1558
1559 return WalkResult::advance();
1560 });
1561
1562 if (opVecResult.wasInterrupted()) {
1563    LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "
1564                      << rootLoop << "\n");
1565 // Erase vector loop nest if it was created.
1566 auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
1567 if (vecRootLoopIt != state.opVectorReplacement.end())
1568 eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
1569
1570 return failure();
1571 }
1572
1573 // Replace results of reduction loops with the scalar values computed using
1574 // `vector.reduce` or similar ops.
1575 for (auto resPair : state.loopResultScalarReplacement)
1576 resPair.first.replaceAllUsesWith(resPair.second);
1577
1578  assert(state.opVectorReplacement.count(rootLoop) == 1 &&
1579         "Expected vector replacement for loop nest");
1580  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
1581  LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"
1582                    << *state.opVectorReplacement[rootLoop]);
1583
1584 // Finish this vectorization pattern.
1585 state.finishVectorizationPattern(rootLoop);
1586 return success();
1587}
1588
1589/// Extracts the matched loops and vectorizes them following a topological
1590/// order. A new vector loop nest will be created if vectorization succeeds. The
1591/// original loop nest won't be modified in any case.
1592static LogicalResult vectorizeRootMatch(NestedMatch m,
1593 const VectorizationStrategy &strategy) {
1594 std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
1595 getMatchedAffineLoops(m, loopsToVectorize);
1596 return vectorizeLoopNest(loopsToVectorize, strategy);
1597}
1598
1599/// Traverses all the loop matches and classifies them into intersection
1600/// buckets. Two matches intersect if any of them encloses the other one. A
1601/// match intersects with a bucket if the match intersects with the root
1602/// (outermost) loop in that bucket.
1603static void computeIntersectionBuckets(
1604 ArrayRef<NestedMatch> matches,
1605 std::vector<SmallVector<NestedMatch, 8>> &intersectionBuckets) {
1606  assert(intersectionBuckets.empty() && "Expected empty output");
1607 // Keeps track of the root (outermost) loop of each bucket.
1608 SmallVector<AffineForOp, 8> bucketRoots;
1609
1610 for (const NestedMatch &match : matches) {
1611 AffineForOp matchRoot = cast<AffineForOp>(match.getMatchedOperation());
1612 bool intersects = false;
1613 for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) {
1614 AffineForOp bucketRoot = bucketRoots[i];
1615 // Add match to the bucket if the bucket root encloses the match root.
1616 if (bucketRoot->isAncestor(matchRoot)) {
1617 intersectionBuckets[i].push_back(match);
1618 intersects = true;
1619 break;
1620 }
1621 // Add match to the bucket if the match root encloses the bucket root. The
1622 // match root becomes the new bucket root.
1623 if (matchRoot->isAncestor(bucketRoot)) {
1624 bucketRoots[i] = matchRoot;
1625 intersectionBuckets[i].push_back(match);
1626 intersects = true;
1627 break;
1628 }
1629 }
1630
1631 // Match doesn't intersect with any existing bucket. Create a new bucket for
1632 // it.
1633 if (!intersects) {
1634 bucketRoots.push_back(matchRoot);
1635 intersectionBuckets.emplace_back();
1636 intersectionBuckets.back().push_back(match);
1637 }
1638 }
1639}
1640
1641/// Internal implementation to vectorize affine loops in 'loops' using the n-D
1642/// vectorization factors in 'vectorSizes'. By default, each vectorization
1643/// factor is applied inner-to-outer to the loops of each loop nest.
1644/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1645/// vectorization order. `reductionLoops` can be provided to specify loops which
1646/// can be vectorized along the reduction dimension.
1647static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1648 ArrayRef<int64_t> vectorSizes,
1649 ArrayRef<int64_t> fastestVaryingPattern,
1650 const ReductionLoopMap &reductionLoops) {
1651  assert((reductionLoops.empty() || vectorSizes.size() == 1) &&
1652         "Vectorizing reductions is supported only for 1-D vectors");
1653
1654 // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops.
1655 Optional<NestedPattern> pattern =
1656 makePattern(loops, vectorSizes.size(), fastestVaryingPattern);
1657 if (!pattern.hasValue()) {
1658    LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n");
1659 return;
1660 }
1661
1662  LLVM_DEBUG(dbgs() << "\n******************************************");
1663  LLVM_DEBUG(dbgs() << "\n******************************************");
1664  LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n");
1665  LLVM_DEBUG(dbgs() << *parentOp << "\n");
1666
1667 unsigned patternDepth = pattern->getDepth();
1668
1669 // Compute all the pattern matches and classify them into buckets of
1670 // intersecting matches.
1671 SmallVector<NestedMatch, 32> allMatches;
1672 pattern->match(parentOp, &allMatches);
1673 std::vector<SmallVector<NestedMatch, 8>> intersectionBuckets;
1674 computeIntersectionBuckets(allMatches, intersectionBuckets);
1675
1676 // Iterate over all buckets and vectorize the matches eagerly. We can only
1677 // vectorize one match from each bucket since all the matches within a bucket
1678 // intersect.
1679 for (auto &intersectingMatches : intersectionBuckets) {
1680 for (NestedMatch &match : intersectingMatches) {
1681 VectorizationStrategy strategy;
1682 // TODO: depending on profitability, elect to reduce the vector size.
1683 strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end());
1684 strategy.reductionLoops = reductionLoops;
1685 if (failed(analyzeProfitability(match.getMatchedChildren(), 1,
1686 patternDepth, &strategy))) {
1687 continue;
1688 }
1689 vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth,
1690 &strategy);
1691 // Vectorize match. Skip the rest of intersecting matches in the bucket if
1692 // vectorization succeeded.
1693 // TODO: if pattern does not apply, report it; alter the cost/benefit.
1694 // TODO: some diagnostics if failure to vectorize occurs.
1695 if (succeeded(vectorizeRootMatch(match, strategy)))
1696 break;
1697 }
1698 }
1699
1700  LLVM_DEBUG(dbgs() << "\n");
1701}
1702
1703std::unique_ptr<OperationPass<func::FuncOp>>
1704createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
1705 return std::make_unique<Vectorize>(virtualVectorSize);
1706}
1707std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() {
1708 return std::make_unique<Vectorize>();
1709}
1710
1711/// Applies vectorization to the current function by searching over a bunch of
1712/// predetermined patterns.
1713void Vectorize::runOnOperation() {
1714 func::FuncOp f = getOperation();
1715 if (!fastestVaryingPattern.empty() &&
1716 fastestVaryingPattern.size() != vectorSizes.size()) {
1717 f.emitRemark("Fastest varying pattern specified with different size than "
1718 "the vector size.");
1719 return signalPassFailure();
1720 }
1721
1722 if (vectorizeReductions && vectorSizes.size() != 1) {
1723 f.emitError("Vectorizing reductions is supported only for 1-D vectors.");
1724 return signalPassFailure();
1725 }
1726
1727 DenseSet<Operation *> parallelLoops;
1728 ReductionLoopMap reductionLoops;
1729
1730 // If 'vectorize-reduction=true' is provided, we also populate the
1731 // `reductionLoops` map.
1732 if (vectorizeReductions) {
1733 f.walk([&parallelLoops, &reductionLoops](AffineForOp loop) {
1734 SmallVector<LoopReduction, 2> reductions;
1735 if (isLoopParallel(loop, &reductions)) {
1736 parallelLoops.insert(loop);
1737 // If it's not a reduction loop, adding it to the map is not necessary.
1738 if (!reductions.empty())
1739 reductionLoops[loop] = reductions;
1740 }
1741 });
1742 } else {
1743 f.walk([&parallelLoops](AffineForOp loop) {
1744 if (isLoopParallel(loop))
1745 parallelLoops.insert(loop);
1746 });
1747 }
1748
1749 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1750 NestedPatternContext mlContext;
1751 vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern,
1752 reductionLoops);
1753}
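
Note: for context, the pass is normally driven through mlir-opt; an invocation sketch (option spellings assumed from the pass's usual registration):

  mlir-opt input.mlir \
    -affine-super-vectorize="virtual-vector-size=128 vectorize-reductions=true"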
1754
1755/// Verify that affine loops in 'loops' meet the nesting criteria expected by
1756/// SuperVectorizer:
1757/// * There must be at least one loop.
1758/// * There must be a single root loop (nesting level 0).
1759/// * Each loop at a given nesting level must be nested in a loop from a
1760/// previous nesting level.
1761static LogicalResult
1762verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
1763 // Expected at least one loop.
1764 if (loops.empty())
1765 return failure();
1766
1767 // Expected only one root loop.
1768 if (loops[0].size() != 1)
1769 return failure();
1770
1771 // Traverse loops outer-to-inner to check some invariants.
1772 for (int i = 1, end = loops.size(); i < end; ++i) {
1773 for (AffineForOp loop : loops[i]) {
1774 // Check that each loop at this level is nested in one of the loops from
1775 // the previous level.
1776 if (none_of(loops[i - 1], [&](AffineForOp maybeParent) {
1777 return maybeParent->isProperAncestor(loop);
1778 }))
1779 return failure();
1780
1781 // Check that each loop at this level is not nested in another loop from
1782 // this level.
1783 for (AffineForOp sibling : loops[i]) {
1784 if (sibling->isProperAncestor(loop))
1785 return failure();
1786 }
1787 }
1788 }
1789
1790 return success();
1791}
1792
1793namespace mlir {
1794
1795/// External utility to vectorize affine loops in 'loops' using the n-D
1796/// vectorization factors in 'vectorSizes'. By default, each vectorization
1797/// factor is applied inner-to-outer to the loops of each loop nest.
1798/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1799/// vectorization order.
1800/// If `reductionLoops` is not empty, the given reduction loops may be
1801/// vectorized along the reduction dimension.
1802/// TODO: Vectorizing reductions is supported only for 1-D vectorization.
1803void vectorizeAffineLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1804 ArrayRef<int64_t> vectorSizes,
1805 ArrayRef<int64_t> fastestVaryingPattern,
1806 const ReductionLoopMap &reductionLoops) {
1807 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1808 NestedPatternContext mlContext;
1809 vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern,
1810 reductionLoops);
1811}
1812
1813/// External utility to vectorize affine loops from a single loop nest using an
1814/// n-D vectorization strategy (see doc in VectorizationStrategy definition).
1815/// Loops are provided in a 2D vector container. The first dimension represents
1816/// the nesting level relative to the loops to be vectorized. The second
1817/// dimension contains the loops. This means that:
1818/// a) every loop in 'loops[i]' must have a parent loop in 'loops[i-1]',
1819/// b) a loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'.
1820///
1821/// For example, for the following loop nest:
1822///
1823/// func @vec2d(%in0: memref<64x128x512xf32>, %in1: memref<64x128x128xf32>,
1824/// %out0: memref<64x128x512xf32>,
1825/// %out1: memref<64x128x128xf32>) {
1826/// affine.for %i0 = 0 to 64 {
1827/// affine.for %i1 = 0 to 128 {
1828/// affine.for %i2 = 0 to 512 {
1829/// %ld = affine.load %in0[%i0, %i1, %i2] : memref<64x128x512xf32>
1830/// affine.store %ld, %out0[%i0, %i1, %i2] : memref<64x128x512xf32>
1831/// }
1832/// affine.for %i3 = 0 to 128 {
1833/// %ld = affine.load %in1[%i0, %i1, %i3] : memref<64x128x128xf32>
1834/// affine.store %ld, %out1[%i0, %i1, %i3] : memref<64x128x128xf32>
1835/// }
1836/// }
1837/// }
1838/// return
1839/// }
1840///
1841/// loops = {{%i0}, {%i2, %i3}}, to vectorize the outermost and the two
1842/// innermost loops;
1843/// loops = {{%i1}, {%i2, %i3}}, to vectorize the middle and the two innermost
1844/// loops;
1845/// loops = {{%i2}}, to vectorize only the first innermost loop;
1846/// loops = {{%i3}}, to vectorize only the second innermost loop;
1847/// loops = {{%i1}}, to vectorize only the middle loop.
1848LogicalResult
1849vectorizeAffineLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1850 const VectorizationStrategy &strategy) {
1851 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1852 NestedPatternContext mlContext;
1853 if (failed(verifyLoopNesting(loops)))
1854 return failure();
1855 return vectorizeLoopNest(loops, strategy);
1856}
1857
1858std::unique_ptr<OperationPass<func::FuncOp>>
1859createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
1860 return std::make_unique<Vectorize>(virtualVectorSize);
1861}
1862std::unique_ptr<OperationPass<func::FuncOp>> createSuperVectorizePass() {
1863 return std::make_unique<Vectorize>();
1864}
1865
1866} // namespace mlir

/build/llvm-toolchain-snapshot-15~++20220420111733+e13d2efed663/mlir/include/mlir/IR/OpDefinition.h

1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types is to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24#include "llvm/Support/PointerLikeTypeTraits.h"
25
26#include <type_traits>
27
28namespace mlir {
29class Builder;
30class OpBuilder;
31
32/// This class implements `Optional` functionality for ParseResult. We don't
33/// directly use Optional here, because it provides an implicit conversion
34/// to 'bool' which we want to avoid. This class is used to implement tri-state
35/// 'parseOptional' functions that may have a failure mode when parsing that
36/// shouldn't be attributed to "not present".
37class OptionalParseResult {
38public:
39 OptionalParseResult() = default;
40 OptionalParseResult(LogicalResult result) : impl(result) {}
41 OptionalParseResult(ParseResult result) : impl(result) {}
42 OptionalParseResult(const InFlightDiagnostic &)
43 : OptionalParseResult(failure()) {}
44 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
45
46 /// Returns true if we contain a valid ParseResult value.
47 bool hasValue() const { return impl.hasValue(); }
48
49 /// Access the internal ParseResult value.
50 ParseResult getValue() const { return impl.getValue(); }
51 ParseResult operator*() const { return getValue(); }
52
53private:
54 Optional<ParseResult> impl;
55};
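// A minimal sketch of the intended tri-state use, assuming a hypothetical
// `parseOptionalFoo` parser hook: "absent" and "present but malformed" are
// distinguished, and only the latter is treated as an error.
//
//   OptionalParseResult parsed = parseOptionalFoo(parser, result);
//   if (!parsed.hasValue()) {
//     // Not present: fall through and try the next alternative.
//   } else if (failed(*parsed)) {
//     return failure(); // Present but malformed: propagate the error.
//   }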
56
57// These functions are out-of-line utilities, which avoids them being template
58// instantiated/duplicated.
59namespace impl {
60/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
61/// region's only block if it does not have a terminator already. If the region
62/// is empty, insert a new block first. `buildTerminatorOp` should return the
63/// terminator operation to insert.
64void ensureRegionTerminator(
65 Region &region, OpBuilder &builder, Location loc,
66 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
67void ensureRegionTerminator(
68 Region &region, Builder &builder, Location loc,
69 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
70
71} // namespace impl
72
73/// This is the concrete base class that holds the operation pointer and has
74/// non-generic methods that only depend on State (to avoid having them
75/// instantiated on template types that don't affect them).
76///
77/// This also has the fallback implementations of customization hooks for when
78/// they aren't customized.
79class OpState {
80public:
81 /// Ops are pointer-like, so we allow conversion to bool.
82 explicit operator bool() { return getOperation() != nullptr; }
83
84 /// This implicitly converts to Operation*.
85 operator Operation *() const { return state; }
86
87 /// Shortcut of `->` to access a member of Operation.
88 Operation *operator->() const { return state; }
89
90 /// Return the operation that this refers to.
91 Operation *getOperation() { return state; }
92
93 /// Return the context this operation belongs to.
94 MLIRContext *getContext() { return getOperation()->getContext(); }
95
96 /// Print the operation to the given stream.
97 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
98 state->print(os, flags);
Analyzer step 27: Called C++ object pointer is null
99 }
100 void print(raw_ostream &os, AsmState &asmState) {
101 state->print(os, asmState);
102 }
103
104 /// Dump this operation.
105 void dump() { state->dump(); }
106
107 /// The source location the operation was defined or derived from.
108 Location getLoc() { return state->getLoc(); }
109
110 /// Return true if there are no users of any results of this operation.
111 bool use_empty() { return state->use_empty(); }
112
113 /// Remove this operation from its parent block and delete it.
114 void erase() { state->erase(); }
115
116 /// Emit an error with the op name prefixed, like "'dim' op " which is
117 /// convenient for verifiers.
118 InFlightDiagnostic emitOpError(const Twine &message = {});
119
120 /// Emit an error about fatal conditions with this operation, reporting up to
121 /// any diagnostic handlers that may be listening.
122 InFlightDiagnostic emitError(const Twine &message = {});
123
124 /// Emit a warning about this operation, reporting up to any diagnostic
125 /// handlers that may be listening.
126 InFlightDiagnostic emitWarning(const Twine &message = {});
127
128 /// Emit a remark about this operation, reporting up to any diagnostic
129 /// handlers that may be listening.
130 InFlightDiagnostic emitRemark(const Twine &message = {});
131
132 /// Walk the operation by calling the callback for each nested operation
133 /// (including this one), block or region, depending on the callback provided.
134 /// Regions, blocks and operations at the same nesting level are visited in
135 /// lexicographical order. The walk order for enclosing regions, blocks and
136 /// operations with respect to their nested ones is specified by 'Order'
137 /// (post-order by default). A callback on a block or operation is allowed to
138 /// erase that block or operation if either:
139 /// * the walk is in post-order, or
140 /// * the walk is in pre-order and the walk is skipped after the erasure.
141 /// See Operation::walk for more details.
142 template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
143 typename RetT = detail::walkResultType<FnT>>
144 typename std::enable_if<
145 llvm::function_traits<std::decay_t<FnT>>::num_args == 1, RetT>::type
146 walk(FnT &&callback) {
147 return state->walk<Order>(std::forward<FnT>(callback));
148 }
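// A minimal sketch of the one-argument form (AffineForOp is used purely as an
// illustration): count nested loops and interrupt the walk once a threshold
// is crossed.
//
//   unsigned numLoops = 0;
//   WalkResult result = op.walk([&](AffineForOp forOp) {
//     return ++numLoops < 8 ? WalkResult::advance()
//                           : WalkResult::interrupt();
//   });
//   if (result.wasInterrupted())
//     ...; // at least 8 nested loops were found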
149
150 /// Generic walker with a stage aware callback. Walk the operation by calling
151 /// the callback for each nested operation (including this one) N+1 times,
152 /// where N is the number of regions attached to that operation.
153 ///
154 /// The callback method can take any of the following forms:
155/// void(Operation *, const WalkStage &) : Walk all operations opaquely
156 /// * op.walk([](Operation *nestedOp, const WalkStage &stage) { ...});
157 /// void(OpT, const WalkStage &) : Walk all operations of the given derived
158 /// type.
159 /// * op.walk([](ReturnOp returnOp, const WalkStage &stage) { ...});
160 /// WalkResult(Operation*|OpT, const WalkStage &stage) : Walk operations,
161 /// but allow for interruption/skipping.
162 /// * op.walk([](... op, const WalkStage &stage) {
163 /// // Skip the walk of this op based on some invariant.
164 /// if (some_invariant)
165 /// return WalkResult::skip();
166/// // Interrupt, i.e., cancel, the walk based on some invariant.
167 /// if (another_invariant)
168 /// return WalkResult::interrupt();
169 /// return WalkResult::advance();
170 /// });
171 template <typename FnT, typename RetT = detail::walkResultType<FnT>>
172 typename std::enable_if<
173 llvm::function_traits<std::decay_t<FnT>>::num_args == 2, RetT>::type
174 walk(FnT &&callback) {
175 return state->walk(std::forward<FnT>(callback));
176 }
177
178 // These are default implementations of customization hooks.
179public:
180 /// This hook returns any canonicalization pattern rewrites that the operation
181 /// supports, for use by the canonicalization pass.
182 static void getCanonicalizationPatterns(RewritePatternSet &results,
183 MLIRContext *context) {}
184
185protected:
186 /// If the concrete type didn't implement a custom verifier hook, just fall
187 /// back to this one which accepts everything.
188 LogicalResult verify() { return success(); }
189 LogicalResult verifyRegions() { return success(); }
190
191 /// Parse the custom form of an operation. Unless overridden, this method will
192 /// first try to get an operation parser from the op's dialect. Otherwise the
193 /// custom assembly form of an op is always rejected. Op implementations
194 /// should implement this to return failure. On success, they should fill in
195 /// result with the fields to use.
196 static ParseResult parse(OpAsmParser &parser, OperationState &result);
197
198 /// Print the operation. Unless overridden, this method will first try to get
199 /// an operation printer from the dialect. Otherwise, it prints the operation
200 /// in generic form.
201 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
202
203 /// Print an operation name, eliding the dialect prefix if necessary.
204 static void printOpName(Operation *op, OpAsmPrinter &p,
205 StringRef defaultDialect);
206
207 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
208 /// so we can cast it away here.
209 explicit OpState(Operation *state) : state(state) {}
210
211private:
212 Operation *state;
213
214 /// Allow access to internal hook implementation methods.
215 friend RegisteredOperationName;
216};
217
218// Allow comparing operations.
219inline bool operator==(OpState lhs, OpState rhs) {
220 return lhs.getOperation() == rhs.getOperation();
221}
222inline bool operator!=(OpState lhs, OpState rhs) {
223 return lhs.getOperation() != rhs.getOperation();
224}
225
226raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
227
228/// This class represents a single result from folding an operation.
229class OpFoldResult : public PointerUnion<Attribute, Value> {
230 using PointerUnion<Attribute, Value>::PointerUnion;
231
232public:
233 void dump() { llvm::errs() << *this << "\n"; }
234};
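// A minimal sketch of how a fold hook produces an OpFoldResult (FooOp is a
// hypothetical identity-like single-result op): the PointerUnion carries
// either a constant Attribute or an existing SSA Value.
//
//   OpFoldResult FooOp::fold(ArrayRef<Attribute> operands) {
//     if (operands[0])      // The operand is a known constant:
//       return operands[0]; // fold to an Attribute.
//     return getOperand();  // Otherwise forward the SSA Value.
//   }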
235
236/// Allow printing to a stream.
237inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
238 if (Value value = ofr.dyn_cast<Value>())
239 value.print(os);
240 else
241 ofr.dyn_cast<Attribute>().print(os);
242 return os;
243}
244
245/// Allow printing to a stream.
246inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
247 op.print(os, OpPrintingFlags().useLocalScope());
Analyzer step 26: Calling 'OpState::print'
248 return os;
249}
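// A minimal sketch of how the flagged path can be reached: op handles are
// pointer-like and may wrap a null Operation* (e.g. after a failed dyn_cast
// of a hypothetical `someOp`). Streaming such a handle enters the operator<<
// above and calls OpState::print with `state == nullptr`.
//
//   AffineForOp maybeLoop = dyn_cast<AffineForOp>(someOp); // may be null
//   llvm::errs() << maybeLoop; // a null `state` reaches state->print(...)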
250
251//===----------------------------------------------------------------------===//
252// Operation Trait Types
253//===----------------------------------------------------------------------===//
254
255namespace OpTrait {
256
257// These functions are out-of-line implementations of the methods in the
258// corresponding trait classes. This avoids them being template
259// instantiated/duplicated.
260namespace impl {
261OpFoldResult foldIdempotent(Operation *op);
262OpFoldResult foldInvolution(Operation *op);
263LogicalResult verifyZeroOperands(Operation *op);
264LogicalResult verifyOneOperand(Operation *op);
265LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
266LogicalResult verifyIsIdempotent(Operation *op);
267LogicalResult verifyIsInvolution(Operation *op);
268LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
269LogicalResult verifyOperandsAreFloatLike(Operation *op);
270LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
271LogicalResult verifySameTypeOperands(Operation *op);
272LogicalResult verifyZeroRegion(Operation *op);
273LogicalResult verifyOneRegion(Operation *op);
274LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
275LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
276LogicalResult verifyZeroResult(Operation *op);
277LogicalResult verifyOneResult(Operation *op);
278LogicalResult verifyNResults(Operation *op, unsigned numOperands);
279LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
280LogicalResult verifySameOperandsShape(Operation *op);
281LogicalResult verifySameOperandsAndResultShape(Operation *op);
282LogicalResult verifySameOperandsElementType(Operation *op);
283LogicalResult verifySameOperandsAndResultElementType(Operation *op);
284LogicalResult verifySameOperandsAndResultType(Operation *op);
285LogicalResult verifyResultsAreBoolLike(Operation *op);
286LogicalResult verifyResultsAreFloatLike(Operation *op);
287LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
288LogicalResult verifyIsTerminator(Operation *op);
289LogicalResult verifyZeroSuccessor(Operation *op);
290LogicalResult verifyOneSuccessor(Operation *op);
291LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
292LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
293LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
294 StringRef valueGroupName,
295 size_t expectedCount);
296LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
297LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
298LogicalResult verifyNoRegionArguments(Operation *op);
299LogicalResult verifyElementwise(Operation *op);
300LogicalResult verifyIsIsolatedFromAbove(Operation *op);
301} // namespace impl
302
303/// Helper class for implementing traits. Clients are not expected to interact
304/// with this directly, so its members are all protected.
305template <typename ConcreteType, template <typename> class TraitType>
306class TraitBase {
307protected:
308 /// Return the ultimate Operation being worked on.
309 Operation *getOperation() {
310 auto *concrete = static_cast<ConcreteType *>(this);
311 return concrete->getOperation();
312 }
313};
314
315//===----------------------------------------------------------------------===//
316// Operand Traits
317
318namespace detail {
319/// Utility trait base that provides accessors for derived traits that have
320/// multiple operands.
321template <typename ConcreteType, template <typename> class TraitType>
322struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
323 using operand_iterator = Operation::operand_iterator;
324 using operand_range = Operation::operand_range;
325 using operand_type_iterator = Operation::operand_type_iterator;
326 using operand_type_range = Operation::operand_type_range;
327
328 /// Return the number of operands.
329 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
330
331 /// Return the operand at index 'i'.
332 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
333
334 /// Set the operand at index 'i' to 'value'.
335 void setOperand(unsigned i, Value value) {
336 this->getOperation()->setOperand(i, value);
337 }
338
339 /// Operand iterator access.
340 operand_iterator operand_begin() {
341 return this->getOperation()->operand_begin();
342 }
343 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
344 operand_range getOperands() { return this->getOperation()->getOperands(); }
345
346 /// Operand type access.
347 operand_type_iterator operand_type_begin() {
348 return this->getOperation()->operand_type_begin();
349 }
350 operand_type_iterator operand_type_end() {
351 return this->getOperation()->operand_type_end();
352 }
353 operand_type_range getOperandTypes() {
354 return this->getOperation()->getOperandTypes();
355 }
356};
357} // namespace detail
358
359/// `verifyInvariantsImpl` verifies invariants such as the types, attrs, etc.
360/// It should be run after core traits and before any other user-defined
361/// traits. Wrapping it in the OpInvariants trait lets tblgen place it at the
362/// correct point in that order.
363template <typename ConcreteType>
364class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
365public:
366 static LogicalResult verifyTrait(Operation *op) {
367 return cast<ConcreteType>(op).verifyInvariantsImpl();
368 }
369};
370
371/// This class provides the API for ops that are known to have no
372/// SSA operand.
373template <typename ConcreteType>
374class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
375public:
376 static LogicalResult verifyTrait(Operation *op) {
377 return impl::verifyZeroOperands(op);
378 }
379
380private:
381 // Disable these.
382 void getOperand() {}
383 void setOperand() {}
384};
385
386/// This class provides the API for ops that are known to have exactly one
387/// SSA operand.
388template <typename ConcreteType>
389class OneOperand : public TraitBase<ConcreteType, OneOperand> {
390public:
391 Value getOperand() { return this->getOperation()->getOperand(0); }
392
393 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
394
395 static LogicalResult verifyTrait(Operation *op) {
396 return impl::verifyOneOperand(op);
397 }
398};
399
400/// This class provides the API for ops that are known to have a specified
401/// number of operands. This is used as a trait like this:
402///
403/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
404///
405template <unsigned N>
406class NOperands {
407public:
408 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
409
410 template <typename ConcreteType>
411 class Impl
412 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
413 public:
414 static LogicalResult verifyTrait(Operation *op) {
415 return impl::verifyNOperands(op, N);
416 }
417 };
418};
419
420/// This class provides the API for ops that are known to have at least a
421/// specified number of operands. This is used as a trait like this:
422///
423/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
424///
425template <unsigned N>
426class AtLeastNOperands {
427public:
428 template <typename ConcreteType>
429 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
430 AtLeastNOperands<N>::Impl> {
431 public:
432 static LogicalResult verifyTrait(Operation *op) {
433 return impl::verifyAtLeastNOperands(op, N);
434 }
435 };
436};
437
438/// This class provides the API for ops which have an unknown number of
439/// SSA operands.
440template <typename ConcreteType>
441class VariadicOperands
442 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
443
444//===----------------------------------------------------------------------===//
445// Region Traits
446
447/// This class provides verification for ops that are known to have zero
448/// regions.
449template <typename ConcreteType>
450class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
451public:
452 static LogicalResult verifyTrait(Operation *op) {
453 return impl::verifyZeroRegion(op);
454 }
455};
456
457namespace detail {
458/// Utility trait base that provides accessors for derived traits that have
459/// multiple regions.
460template <typename ConcreteType, template <typename> class TraitType>
461struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
462 using region_iterator = MutableArrayRef<Region>;
463 using region_range = RegionRange;
464
465 /// Return the number of regions.
466 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
467
468 /// Return the region at `index`.
469 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
470
471 /// Region iterator access.
472 region_iterator region_begin() {
473 return this->getOperation()->region_begin();
474 }
475 region_iterator region_end() { return this->getOperation()->region_end(); }
476 region_range getRegions() { return this->getOperation()->getRegions(); }
477};
478} // namespace detail
479
480/// This class provides APIs for ops that are known to have a single region.
481template <typename ConcreteType>
482class OneRegion : public TraitBase<ConcreteType, OneRegion> {
483public:
484 Region &getRegion() { return this->getOperation()->getRegion(0); }
485
486 /// Returns a range of operations within the region of this operation.
487 auto getOps() { return getRegion().getOps(); }
488 template <typename OpT>
489 auto getOps() {
490 return getRegion().template getOps<OpT>();
491 }
492
493 static LogicalResult verifyTrait(Operation *op) {
494 return impl::verifyOneRegion(op);
495 }
496};
497
498/// This class provides the API for ops that are known to have a specified
499/// number of regions.
500template <unsigned N>
501class NRegions {
502public:
503 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
504
505 template <typename ConcreteType>
506 class Impl
507 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
508 public:
509 static LogicalResult verifyTrait(Operation *op) {
510 return impl::verifyNRegions(op, N);
511 }
512 };
513};
514
515/// This class provides APIs for ops that are known to have at least a specified
516/// number of regions.
517template <unsigned N>
518class AtLeastNRegions {
519public:
520 template <typename ConcreteType>
521 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
522 AtLeastNRegions<N>::Impl> {
523 public:
524 static LogicalResult verifyTrait(Operation *op) {
525 return impl::verifyAtLeastNRegions(op, N);
526 }
527 };
528};
529
530/// This class provides the API for ops which have an unknown number of
531/// regions.
532template <typename ConcreteType>
533class VariadicRegions
534 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
535
536//===----------------------------------------------------------------------===//
537// Result Traits
538
539/// This class provides return value APIs for ops that are known to have
540/// zero results.
541template <typename ConcreteType>
542class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
543public:
544 static LogicalResult verifyTrait(Operation *op) {
545 return impl::verifyZeroResult(op);
546 }
547};
548
549namespace detail {
550/// Utility trait base that provides accessors for derived traits that have
551/// multiple results.
552template <typename ConcreteType, template <typename> class TraitType>
553struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
554 using result_iterator = Operation::result_iterator;
555 using result_range = Operation::result_range;
556 using result_type_iterator = Operation::result_type_iterator;
557 using result_type_range = Operation::result_type_range;
558
559 /// Return the number of results.
560 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
561
562 /// Return the result at index 'i'.
563 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
564
565 /// Replace all uses of results of this operation with the provided 'values'.
566 /// 'values' may correspond to an existing operation, or a range of 'Value'.
567 template <typename ValuesT>
568 void replaceAllUsesWith(ValuesT &&values) {
569 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
570 }
571
572 /// Return the type of the `i`-th result.
573 Type getType(unsigned i) { return getResult(i).getType(); }
574
575 /// Result iterator access.
576 result_iterator result_begin() {
577 return this->getOperation()->result_begin();
578 }
579 result_iterator result_end() { return this->getOperation()->result_end(); }
580 result_range getResults() { return this->getOperation()->getResults(); }
581
582 /// Result type access.
583 result_type_iterator result_type_begin() {
584 return this->getOperation()->result_type_begin();
585 }
586 result_type_iterator result_type_end() {
587 return this->getOperation()->result_type_end();
588 }
589 result_type_range getResultTypes() {
590 return this->getOperation()->getResultTypes();
591 }
592};
593} // namespace detail
594
595/// This class provides return value APIs for ops that are known to have a
596/// single result. ResultType is the concrete type returned by getType().
597template <typename ConcreteType>
598class OneResult : public TraitBase<ConcreteType, OneResult> {
599public:
600 Value getResult() { return this->getOperation()->getResult(0); }
601
602 /// If the operation returns a single value, then the Op can be implicitly
603/// converted to a Value. This yields the value of the only result.
604 operator Value() { return getResult(); }
605
606 /// Replace all uses of 'this' value with the new value, updating anything
607 /// in the IR that uses 'this' to use the other value instead. When this
608 /// returns there are zero uses of 'this'.
609 void replaceAllUsesWith(Value newValue) {
610 getResult().replaceAllUsesWith(newValue);
611 }
612
613 /// Replace all uses of 'this' value with the result of 'op'.
614 void replaceAllUsesWith(Operation *op) {
615 this->getOperation()->replaceAllUsesWith(op);
616 }
617
618 static LogicalResult verifyTrait(Operation *op) {
619 return impl::verifyOneResult(op);
620 }
621};
622
623/// This trait is used for return value APIs for ops that are known to have a
624/// specific type other than `Type`. This allows the "getType()" member to be
625/// more specific for an op. This should be used in conjunction with OneResult,
626/// and occur in the trait list before OneResult.
627template <typename ResultType>
628class OneTypedResult {
629public:
630 /// This class provides return value APIs for ops that are known to have a
631 /// single result. ResultType is the concrete type returned by getType().
632 template <typename ConcreteType>
633 class Impl
634 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
635 public:
636 ResultType getType() {
637 auto resultTy = this->getOperation()->getResult(0).getType();
638 return resultTy.template cast<ResultType>();
639 }
640 };
641};
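// A minimal sketch of the intended trait ordering (FooOp is hypothetical):
// listing OneTypedResult before OneResult lets getType() return the narrowed
// IntegerType instead of a generic Type.
//
//   class FooOp : public Op<FooOp,
//                           OpTrait::OneTypedResult<IntegerType>::Impl,
//                           OpTrait::OneResult> {
//     ...
//   };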
642
643/// This class provides the API for ops that are known to have a specified
644/// number of results. This is used as a trait like this:
645///
646/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
647///
648template <unsigned N>
649class NResults {
650public:
651 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
652
653 template <typename ConcreteType>
654 class Impl
655 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
656 public:
657 static LogicalResult verifyTrait(Operation *op) {
658 return impl::verifyNResults(op, N);
659 }
660 };
661};
662
663/// This class provides the API for ops that are known to have at least a
664/// specified number of results. This is used as a trait like this:
665///
666/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
667///
668template <unsigned N>
669class AtLeastNResults {
670public:
671 template <typename ConcreteType>
672 class Impl : public detail::MultiResultTraitBase<ConcreteType,
673 AtLeastNResults<N>::Impl> {
674 public:
675 static LogicalResult verifyTrait(Operation *op) {
676 return impl::verifyAtLeastNResults(op, N);
677 }
678 };
679};
680
681/// This class provides the API for ops which have an unknown number of
682/// results.
683template <typename ConcreteType>
684class VariadicResults
685 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
686
687//===----------------------------------------------------------------------===//
688// Terminator Traits
689
690/// This class indicates that the regions associated with this op don't have
691/// terminators.
692template <typename ConcreteType>
693class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
694
695/// This class provides the API for ops that are known to be terminators.
696template <typename ConcreteType>
697class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
698public:
699 static LogicalResult verifyTrait(Operation *op) {
700 return impl::verifyIsTerminator(op);
701 }
702};
703
704/// This class provides verification for ops that are known to have zero
705/// successors.
706template <typename ConcreteType>
707class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
708public:
709 static LogicalResult verifyTrait(Operation *op) {
710 return impl::verifyZeroSuccessor(op);
711 }
712};
713
714namespace detail {
715/// Utility trait base that provides accessors for derived traits that have
716/// multiple successors.
717template <typename ConcreteType, template <typename> class TraitType>
718struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
719 using succ_iterator = Operation::succ_iterator;
720 using succ_range = SuccessorRange;
721
722 /// Return the number of successors.
723 unsigned getNumSuccessors() {
724 return this->getOperation()->getNumSuccessors();
725 }
726
727 /// Return the successor at `index`.
728 Block *getSuccessor(unsigned i) {
729 return this->getOperation()->getSuccessor(i);
730 }
731
732 /// Set the successor at `index`.
733 void setSuccessor(Block *block, unsigned i) {
734 return this->getOperation()->setSuccessor(block, i);
735 }
736
737 /// Successor iterator access.
738 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
739 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
740 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
741};
742} // namespace detail
743
744/// This class provides APIs for ops that are known to have a single successor.
745template <typename ConcreteType>
746class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
747public:
748 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
749 void setSuccessor(Block *succ) {
750 this->getOperation()->setSuccessor(succ, 0);
751 }
752
753 static LogicalResult verifyTrait(Operation *op) {
754 return impl::verifyOneSuccessor(op);
755 }
756};
757
758/// This class provides the API for ops that are known to have a specified
759/// number of successors.
760template <unsigned N>
761class NSuccessors {
762public:
763 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
764
765 template <typename ConcreteType>
766 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
767 NSuccessors<N>::Impl> {
768 public:
769 static LogicalResult verifyTrait(Operation *op) {
770 return impl::verifyNSuccessors(op, N);
771 }
772 };
773};
774
775/// This class provides APIs for ops that are known to have at least a specified
776/// number of successors.
777template <unsigned N>
778class AtLeastNSuccessors {
779public:
780 template <typename ConcreteType>
781 class Impl
782 : public detail::MultiSuccessorTraitBase<ConcreteType,
783 AtLeastNSuccessors<N>::Impl> {
784 public:
785 static LogicalResult verifyTrait(Operation *op) {
786 return impl::verifyAtLeastNSuccessors(op, N);
787 }
788 };
789};
790
791/// This class provides the API for ops which have an unknown number of
792/// successors.
793template <typename ConcreteType>
794class VariadicSuccessors
795 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
796};
797
798//===----------------------------------------------------------------------===//
799// SingleBlock
800
801/// This class provides APIs and verifiers for ops with regions having a single
802/// block.
803template <typename ConcreteType>
804struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
805public:
806 static LogicalResult verifyTrait(Operation *op) {
807 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
808 Region &region = op->getRegion(i);
809
810 // Empty regions are fine.
811 if (region.empty())
812 continue;
813
814 // Non-empty regions must contain a single basic block.
815 if (!llvm::hasSingleElement(region))
816 return op->emitOpError("expects region #")
817 << i << " to have 0 or 1 blocks";
818
819 if (!ConcreteType::template hasTrait<NoTerminator>()) {
820 Block &block = region.front();
821 if (block.empty())
822 return op->emitOpError() << "expects a non-empty block";
823 }
824 }
825 return success();
826 }
827
828 Block *getBody(unsigned idx = 0) {
829 Region &region = this->getOperation()->getRegion(idx);
830 assert(!region.empty() && "unexpected empty region");
831 return &region.front();
832 }
833 Region &getBodyRegion(unsigned idx = 0) {
834 return this->getOperation()->getRegion(idx);
835 }
836
837 //===------------------------------------------------------------------===//
838 // Single Region Utilities
839 //===------------------------------------------------------------------===//
840
841 /// The following are a set of methods only enabled when the parent
842 /// operation has a single region. Each of these methods takes an additional
843 /// template parameter that represents the concrete operation so that we
844 /// can use SFINAE to disable the methods for non-single region operations.
845 template <typename OpT, typename T = void>
846 using enable_if_single_region =
847 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
848
849 template <typename OpT = ConcreteType>
850 enable_if_single_region<OpT, Block::iterator> begin() {
851 return getBody()->begin();
852 }
853 template <typename OpT = ConcreteType>
854 enable_if_single_region<OpT, Block::iterator> end() {
855 return getBody()->end();
856 }
857 template <typename OpT = ConcreteType>
858 enable_if_single_region<OpT, Operation &> front() {
859 return *begin();
860 }
861
862 /// Insert the operation into the back of the body.
863 template <typename OpT = ConcreteType>
864 enable_if_single_region<OpT> push_back(Operation *op) {
865 insert(Block::iterator(getBody()->end()), op);
866 }
867
868 /// Insert the operation at the given insertion point.
869 template <typename OpT = ConcreteType>
870 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
871 insert(Block::iterator(insertPt), op);
872 }
873 template <typename OpT = ConcreteType>
874 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
875 getBody()->getOperations().insert(insertPt, op);
876 }
877};
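// A minimal usage sketch, assuming `containerOp` is a hypothetical op with
// both the OneRegion and SingleBlock traits and `newOp` is an Operation*:
// the single-region helpers let the op be used much like a block.
//
//   containerOp.push_back(newOp);         // appends at getBody()->end()
//   for (Operation &nested : containerOp) // begin()/end() from this trait
//     nested.dump();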
878
879//===----------------------------------------------------------------------===//
880// SingleBlockImplicitTerminator
881
882/// This class provides APIs and verifiers for ops with regions having a single
883/// block that must terminate with `TerminatorOpType`.
884template <typename TerminatorOpType>
885struct SingleBlockImplicitTerminator {
886 template <typename ConcreteType>
887 class Impl : public SingleBlock<ConcreteType> {
888 private:
889 using Base = SingleBlock<ConcreteType>;
890 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
891 /// cyclic header inclusion.
892 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
893 OperationState state(loc, TerminatorOpType::getOperationName());
894 TerminatorOpType::build(builder, state);
895 return Operation::create(state);
896 }
897
898 public:
899 /// The type of the operation used as the implicit terminator type.
900 using ImplicitTerminatorOpT = TerminatorOpType;
901
902 static LogicalResult verifyRegionTrait(Operation *op) {
903 if (failed(Base::verifyTrait(op)))
904 return failure();
905 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
906 Region &region = op->getRegion(i);
907 // Empty regions are fine.
908 if (region.empty())
909 continue;
910 Operation &terminator = region.front().back();
911 if (isa<TerminatorOpType>(terminator))
912 continue;
913
914 return op->emitOpError("expects regions to end with '" +
915 TerminatorOpType::getOperationName() +
916 "', found '" +
917 terminator.getName().getStringRef() + "'")
918 .attachNote()
919 << "in custom textual format, the absence of terminator implies "
920 "'"
921 << TerminatorOpType::getOperationName() << '\'';
922 }
923
924 return success();
925 }
926
927 /// Ensure that the given region has the terminator required by this trait.
928 /// If OpBuilder is provided, use it to build the terminator and notify the
929 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
930 /// construct an OpBuilder with no listeners; this should only be used if no
931 /// OpBuilder is available at the call site, e.g., in the parser.
932 static void ensureTerminator(Region &region, Builder &builder,
933 Location loc) {
934 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
935 buildTerminator);
936 }
937 static void ensureTerminator(Region &region, OpBuilder &builder,
938 Location loc) {
939 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
940 buildTerminator);
941 }
942
943 //===------------------------------------------------------------------===//
944 // Single Region Utilities
945 //===------------------------------------------------------------------===//
946 using Base::getBody;
947
948 template <typename OpT, typename T = void>
949 using enable_if_single_region =
950 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
951
952 /// Insert the operation into the back of the body, before the terminator.
953 template <typename OpT = ConcreteType>
954 enable_if_single_region<OpT> push_back(Operation *op) {
955 insert(Block::iterator(getBody()->getTerminator()), op);
956 }
957
958 /// Insert the operation at the given insertion point. Note: The operation
959 /// is never inserted after the terminator, even if the insertion point is
960 /// end().
961 template <typename OpT = ConcreteType>
962 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
963 insert(Block::iterator(insertPt), op);
964 }
965 template <typename OpT = ConcreteType>
966 enable_if_single_region<OpT> insert(Block::iterator insertPt,
967 Operation *op) {
968 auto *body = getBody();
969 if (insertPt == body->end())
970 insertPt = Block::iterator(body->getTerminator());
971 body->getOperations().insert(insertPt, op);
972 }
973 };
974};
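// A minimal sketch (MyRegionOp and MyYieldOp are hypothetical): the trait
// both verifies the terminator and can materialize it when a region is built
// by hand, e.g. in a parser.
//
//   class MyRegionOp
//       : public Op<MyRegionOp, OpTrait::OneRegion,
//                   OpTrait::SingleBlockImplicitTerminator<MyYieldOp>::Impl> {
//     ...
//   };
//
//   // After populating the region without an explicit terminator:
//   MyRegionOp::ensureTerminator(region, builder, loc);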
975
976/// Check if an op defines the `ImplicitTerminatorOpT` member. This is intended
977/// to be used with `llvm::is_detected`.
978template <class T>
979using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
980
981/// Support to check if an operation has the SingleBlockImplicitTerminator
982/// trait. We can't just use `hasTrait` because this class is templated on a
983/// specific terminator op.
984template <class Op, bool hasTerminator =
985 llvm::is_detected<has_implicit_terminator_t, Op>::value>
986struct hasSingleBlockImplicitTerminator {
987 static constexpr bool value = std::is_base_of<
988 typename OpTrait::SingleBlockImplicitTerminator<
989 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
990 Op>::value;
991};
992template <class Op>
993struct hasSingleBlockImplicitTerminator<Op, false> {
994 static constexpr bool value = false;
995};
996
997//===----------------------------------------------------------------------===//
998// Misc Traits
999
1000/// This class provides verification for ops that are known to have the same
1001/// operand shape: all operands are scalars, vectors/tensors of the same
1002/// shape.
1003template <typename ConcreteType>
1004class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
1005public:
1006 static LogicalResult verifyTrait(Operation *op) {
1007 return impl::verifySameOperandsShape(op);
1008 }
1009};
1010
1011/// This class provides verification for ops that are known to have the same
1012/// operand and result shape: both are scalars, vectors/tensors of the same
1013/// shape.
1014template <typename ConcreteType>
1015class SameOperandsAndResultShape
1016 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
1017public:
1018 static LogicalResult verifyTrait(Operation *op) {
1019 return impl::verifySameOperandsAndResultShape(op);
1020 }
1021};
1022
1023/// This class provides verification for ops that are known to have the same
1024/// operand element type (or the type itself if it is scalar).
1025///
1026template <typename ConcreteType>
1027class SameOperandsElementType
1028 : public TraitBase<ConcreteType, SameOperandsElementType> {
1029public:
1030 static LogicalResult verifyTrait(Operation *op) {
1031 return impl::verifySameOperandsElementType(op);
1032 }
1033};
1034
1035/// This class provides verification for ops that are known to have the same
1036/// operand and result element type (or the type itself if it is scalar).
1037///
1038template <typename ConcreteType>
1039class SameOperandsAndResultElementType
1040 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1041public:
1042 static LogicalResult verifyTrait(Operation *op) {
1043 return impl::verifySameOperandsAndResultElementType(op);
1044 }
1045};
1046
1047/// This class provides verification for ops that are known to have the same
1048/// operand and result type.
1049///
1050/// Note: this trait subsumes the SameOperandsAndResultShape and
1051/// SameOperandsAndResultElementType traits.
1052template <typename ConcreteType>
1053class SameOperandsAndResultType
1054 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1055public:
1056 static LogicalResult verifyTrait(Operation *op) {
1057 return impl::verifySameOperandsAndResultType(op);
1058 }
1059};
1060
1061/// This class verifies that any results of the specified op have a boolean
1062/// type, a vector thereof, or a tensor thereof.
1063template <typename ConcreteType>
1064class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1065public:
1066 static LogicalResult verifyTrait(Operation *op) {
1067 return impl::verifyResultsAreBoolLike(op);
1068 }
1069};
1070
1071/// This class verifies that any results of the specified op have a floating
1072/// point type, a vector thereof, or a tensor thereof.
1073template <typename ConcreteType>
1074class ResultsAreFloatLike
1075 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1076public:
1077 static LogicalResult verifyTrait(Operation *op) {
1078 return impl::verifyResultsAreFloatLike(op);
1079 }
1080};
1081
1082/// This class verifies that any results of the specified op have a signless
1083/// integer or index type, a vector thereof, or a tensor thereof.
1084template <typename ConcreteType>
1085class ResultsAreSignlessIntegerLike
1086 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1087public:
1088 static LogicalResult verifyTrait(Operation *op) {
1089 return impl::verifyResultsAreSignlessIntegerLike(op);
1090 }
1091};
1092
1093/// This class adds property that the operation is commutative.
1094template <typename ConcreteType>
1095class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1096
1097/// This class adds property that the operation is an involution.
1098/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1099template <typename ConcreteType>
1100class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1101public:
1102 static LogicalResult verifyTrait(Operation *op) {
1103 static_assert(ConcreteType::template hasTrait<OneResult>(),
1104 "expected operation to produce one result");
1105 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1106 "expected operation to take one operand");
1107 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1108 "expected operation to preserve type");
1109 // Involution requires the operation to be side effect free as well
1110 // but currently this check is under a FIXME and is not actually done.
1111 return impl::verifyIsInvolution(op);
1112 }
1113
1114 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1115 return impl::foldInvolution(op);
1116 }
1117};
1118
1119/// This class adds property that the operation is idempotent.
1120/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1121/// or a binary operation "g" that satisfies g(x, x) = x.
1122template <typename ConcreteType>
1123class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1124public:
1125 static LogicalResult verifyTrait(Operation *op) {
1126 static_assert(ConcreteType::template hasTrait<OneResult>(),
1127 "expected operation to produce one result");
1128 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1129 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1130 "expected operation to take one or two operands");
1131 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1132 "expected operation to preserve type");
1133 // Idempotent requires the operation to be side effect free as well
1134 // but currently this check is under a FIXME and is not actually done.
1135 return impl::verifyIsIdempotent(op);
1136 }
1137
1138 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1139 return impl::foldIdempotent(op);
1140 }
1141};
1142
1143/// This class verifies that all operands of the specified op have a float type,
1144/// a vector thereof, or a tensor thereof.
1145template <typename ConcreteType>
1146class OperandsAreFloatLike
1147 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1148public:
1149 static LogicalResult verifyTrait(Operation *op) {
1150 return impl::verifyOperandsAreFloatLike(op);
1151 }
1152};
1153
1154/// This class verifies that all operands of the specified op have a signless
1155/// integer or index type, a vector thereof, or a tensor thereof.
1156template <typename ConcreteType>
1157class OperandsAreSignlessIntegerLike
1158 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1159public:
1160 static LogicalResult verifyTrait(Operation *op) {
1161 return impl::verifyOperandsAreSignlessIntegerLike(op);
1162 }
1163};
1164
1165/// This class verifies that all operands of the specified op have the same
1166/// type.
1167template <typename ConcreteType>
1168class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1169public:
1170 static LogicalResult verifyTrait(Operation *op) {
1171 return impl::verifySameTypeOperands(op);
1172 }
1173};
1174
1175/// This class provides the API for a sub-set of ops that are known to be
1176/// constant-like. These are non-side effecting operations with one result and
1177/// zero operands that can always be folded to a specific attribute value.
1178template <typename ConcreteType>
1179class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1180public:
1181 static LogicalResult verifyTrait(Operation *op) {
1182 static_assert(ConcreteType::template hasTrait<OneResult>(),
1183 "expected operation to produce one result");
1184 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1185 "expected operation to take zero operands");
1186 // TODO: We should verify that the operation can always be folded, but this
1187 // requires that the attributes of the op already be verified. We should add
1188 // support for verifying traits "after" the operation to enable this use
1189 // case.
1190 return success();
1191 }
1192};
1193
1194/// This class provides the API for ops that are known to be isolated from
1195/// above.
1196template <typename ConcreteType>
1197class IsIsolatedFromAbove
1198 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1199public:
1200 static LogicalResult verifyRegionTrait(Operation *op) {
1201 return impl::verifyIsIsolatedFromAbove(op);
1202 }
1203};
1204
1205/// A trait of region holding operations that defines a new scope for polyhedral
1206/// optimization purposes. Any SSA values of 'index' type that either dominate
1207/// such an operation or are used at the top-level of such an operation
1208/// automatically become valid symbols for the polyhedral scope defined by that
1209/// operation. For more details, see `Traits.md#AffineScope`.
1210template <typename ConcreteType>
1211class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1212public:
1213 static LogicalResult verifyTrait(Operation *op) {
1214 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1215 "expected operation to have one or more regions");
1216 return success();
1217 }
1218};
1219
1220/// A trait of region holding operations that define a new scope for automatic
1221/// allocations, i.e., allocations that are freed when control is transferred
1222/// back from the operation's region. Any operations performing such allocations
1223/// (e.g., memref.alloca) will have their allocations automatically freed at
1224/// their closest enclosing operation with this trait.
1225template <typename ConcreteType>
1226class AutomaticAllocationScope
1227 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1228public:
1229 static LogicalResult verifyTrait(Operation *op) {
1230 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1231 "expected operation to have one or more regions");
1232 return success();
1233 }
1234};
1235
1236/// This class provides a verifier for ops that expect their parent to be one
1237/// of the given parent ops.
1238template <typename... ParentOpTypes>
1239struct HasParent {
1240 template <typename ConcreteType>
1241 class Impl : public TraitBase<ConcreteType, Impl> {
1242 public:
1243 static LogicalResult verifyTrait(Operation *op) {
1244 if (llvm::isa_and_nonnull<ParentOpTypes...>(op->getParentOp()))
1245 return success();
1246
1247 return op->emitOpError()
1248 << "expects parent op "
1249 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1250 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1251 << "'";
1252 }
1253 };
1254};
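// A minimal sketch (YieldOp and LoopOp are hypothetical): constrain a
// terminator so its immediate parent must be a LoopOp; any other parent is
// rejected by the verifier above.
//
//   class YieldOp : public Op<YieldOp, OpTrait::IsTerminator,
//                             OpTrait::HasParent<LoopOp>::Impl> {
//     ...
//   };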
1255
1256/// A trait for operations that have an attribute specifying operand segments.
1257///
1258/// Certain operations can have multiple variadic operands and their size
1259/// relationship is not always known statically. For such cases, we need
1260/// a per-op-instance specification to divide the operands into logical groups
1261/// or segments. This can be modeled by attributes. The attribute will be named
1262/// as `operand_segment_sizes`.
1263///
1264/// This trait verifies the attribute for specifying operand segments has
1265/// the correct type (1D vector) and values (non-negative), etc.
1266template <typename ConcreteType>
1267class AttrSizedOperandSegments
1268 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1269public:
1270 static StringRef getOperandSegmentSizeAttr() {
1271 return "operand_segment_sizes";
1272 }
1273
1274 static LogicalResult verifyTrait(Operation *op) {
1275 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1276 op, getOperandSegmentSizeAttr());
1277 }
1278};
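// A minimal sketch of the attribute this trait verifies ("foo.branch" is a
// hypothetical op with two variadic operand groups): the i32 vector records
// how many of the three operands belong to each group.
//
//   "foo.branch"(%a, %b, %c)
//       {operand_segment_sizes = dense<[1, 2]> : vector<2xi32>}
//       : (i32, i32, i32) -> ()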
1279
1280/// Similar to AttrSizedOperandSegments but used for results.
1281template <typename ConcreteType>
1282class AttrSizedResultSegments
1283 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1284public:
1285 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1286
1287 static LogicalResult verifyTrait(Operation *op) {
1288 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1289 op, getResultSegmentSizeAttr());
1290 }
1291};
1292
1293/// This trait provides a verifier for ops that expect their regions to not
1294/// have any arguments.
1295template <typename ConcreteType>
1296struct NoRegionArguments : public TraitBase<ConcreteType, NoRegionArguments> {
1297 static LogicalResult verifyTrait(Operation *op) {
1298 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1299 }
1300};
1301
1302// This trait is used to flag operations that consume or produce
1303// values of `MemRef` type where those references can be 'normalized'.
1304// TODO: Right now, the operands of an operation are either all normalizable,
1305// or not. In the future, we may want to allow some of the operands to be
1306// normalizable.
1307template <typename ConcreteType>
1308struct MemRefsNormalizable
1309 : public TraitBase<ConcreteType, MemRefsNormalizable> {};
1310
1311/// This trait tags element-wise ops on vectors or tensors.
1312///
1313/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1314/// trait. In particular, broadcasting behavior is not allowed.
1315///
1316/// An `Elementwise` op must satisfy the following properties:
1317///
1318/// 1. If any result is a vector/tensor then at least one operand must also be a
1319/// vector/tensor.
1320/// 2. If any operand is a vector/tensor then there must be at least one result
1321/// and all results must be vectors/tensors.
1322/// 3. All operand and result vector/tensor types must be of the same shape. The
1323/// shape may be dynamic in which case the op's behaviour is undefined for
1324/// non-matching shapes.
1325/// 4. The operation must be elementwise on its vector/tensor operands and
1326/// results. When applied to single-element vectors/tensors, the result must
1327/// be the same per element.
1328///
1329/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1330/// interface `ElementwiseTypeInterface` that describes the container types for
1331/// which the operation is elementwise.
1332///
1333/// Rationale:
1334/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1335/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1336/// generic definition of the iteration space.
1337/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1338/// the same pattern, as otherwise lots of special handling for type
1339/// mismatches would be needed.
1340/// - 4. guarantees that no error handling is needed. Higher-level dialects
1341/// should reify any needed guards or error handling code before lowering to
1342/// an `Elementwise` op.
1343template <typename ConcreteType>
1344struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1345 static LogicalResult verifyTrait(Operation *op) {
1346 return ::mlir::OpTrait::impl::verifyElementwise(op);
1347 }
1348};
1349
1350/// This trait tags `Elementwise` operations that can be systematically
1351/// scalarized. All vector/tensor operands and results are then replaced by
1352/// scalars of the respective element type. Semantically, this is the operation
1353/// on a single element of the vector/tensor.
1354///
1355/// Rationale:
1356/// Allow to define the vector/tensor semantics of elementwise operations based
1357/// on the same op's behavior on scalars. This provides a constructive procedure
1358/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1359///
1360/// Example:
1361/// ```
1362/// %tensor_select = "arith.select"(%pred_tensor, %true_val, %false_val)
1363/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1364/// -> tensor<?xf32>
1365/// ```
1366/// can be scalarized to
1367///
1368/// ```
1369/// %scalar_select = "arith.select"(%pred, %true_val_scalar, %false_val_scalar)
1370/// : (i1, f32, f32) -> f32
1371/// ```
1372template <typename ConcreteType>
1373struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1374 static LogicalResult verifyTrait(Operation *op) {
1375 static_assert(
1376 ConcreteType::template hasTrait<Elementwise>(),
1377 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1378 return success();
1379 }
1380};
1381
1382/// This trait tags `Elementwise` operations that can be systematically
1383/// vectorized. All scalar operands and results are then replaced by vectors
1384/// with the respective element type. Semantically, this is the operation on
1385/// multiple elements simultaneously. See also `Tensorizable`.
1386///
1387/// Rationale:
1388/// Provide the reverse to `Scalarizable` which, when chained together, allows
1389/// reasoning about the relationship between the tensor and vector case.
1390/// Additionally, it permits reasoning about promoting scalars to vectors via
1391/// broadcasting in cases like `%select_scalar_pred` below.
1392template <typename ConcreteType>
1393struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1394 static LogicalResult verifyTrait(Operation *op) {
1395 static_assert(
1396 ConcreteType::template hasTrait<Elementwise>(),
1397 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1398 return success();
1399 }
1400};
1401
1402/// This trait tags `Elementwise` operations that can be systematically
1403/// tensorized. All scalar operands and results are then replaced by tensors
1404/// with the respective element type. Semantically, this is the operation on
1405/// multiple elements simultaneously. See also `Vectorizable`.
1406///
1407/// Rationale:
1408/// Provide the reverse to `Scalarizable` which, when chained together, allows
1409/// reasoning about the relationship between the tensor and vector case.
1410/// Additionally, it permits reasoning about promoting scalars to tensors via
1411/// broadcasting in cases like `%select_scalar_pred` below.
1412///
1413/// Examples:
1414/// ```
1415/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1416/// ```
1417/// can be tensorized to
1418/// ```
1419/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1420/// -> tensor<?xf32>
1421/// ```
1422///
1423/// ```
1424/// %scalar_pred = "arith.select"(%pred, %true_val, %false_val)
1425/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1426/// ```
1427/// can be tensorized to
1428/// ```
1429/// %tensor_pred = "arith.select"(%pred, %true_val, %false_val)
1430/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1431/// -> tensor<?xf32>
1432/// ```
1433template <typename ConcreteType>
1434struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1435 static LogicalResult verifyTrait(Operation *op) {
1436 static_assert(
1437 ConcreteType::template hasTrait<Elementwise>(),
1438 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1439 return success();
1440 }
1441};
1442
1443/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1444/// provide a convenient way for scalar operations to generalize their behavior
1445/// to vectors/tensors, and systematize conversions between these forms.
1446bool hasElementwiseMappableTraits(Operation *op);
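// Editor's sketch (not part of the original header): how a dialect op might
// opt into the full trait stack above. `MyAddFOp` and its mnemonic are
// hypothetical, and the exact set of required static hooks varies across
// MLIR versions; this is illustrative only.
//
//   class MyAddFOp
//       : public Op<MyAddFOp, OpTrait::OneResult, OpTrait::Elementwise,
//                   OpTrait::Scalarizable, OpTrait::Vectorizable,
//                   OpTrait::Tensorizable> {
//   public:
//     using Op::Op;
//     static StringRef getOperationName() { return "mydialect.addf"; }
//   };
//
// An op declared with all four traits satisfies
// `hasElementwiseMappableTraits`, so transformations may move it freely
// between scalar, vector, and tensor forms.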
1447
1448} // namespace OpTrait
1449
1450//===----------------------------------------------------------------------===//
1451// Internal Trait Utilities
1452//===----------------------------------------------------------------------===//
1453
1454namespace op_definition_impl {
1455//===----------------------------------------------------------------------===//
1456// Trait Existence
1457
1458/// Returns true if the given trait ID matches the ID of any of the provided
1459/// trait types `Traits`.
1460template <template <typename T> class... Traits>
1461static bool hasTrait(TypeID traitID) {
1462 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1463 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1464 if (traitIDs[i] == traitID)
1465 return true;
1466 return false;
1467}
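// Editor's note: this helper backs the type-erased trait query installed by
// `getHasTraitFn` in the `Op` class below. A hypothetical direct use:
//
//   bool b = hasTrait<OpTrait::OneResult, OpTrait::Elementwise>(
//       TypeID::get<OpTrait::OneResult>()); // true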
1468
1469//===----------------------------------------------------------------------===//
1470// Trait Folding
1471
1472/// Trait to check if T provides a 'foldTrait' method for single result
1473/// operations.
1474template <typename T, typename... Args>
1475using has_single_result_fold_trait = decltype(T::foldTrait(
1476 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1477template <typename T>
1478using detect_has_single_result_fold_trait =
1479 llvm::is_detected<has_single_result_fold_trait, T>;
1480/// Trait to check if T provides a general 'foldTrait' method.
1481template <typename T, typename... Args>
1482using has_fold_trait =
1483 decltype(T::foldTrait(std::declval<Operation *>(),
1484 std::declval<ArrayRef<Attribute>>(),
1485 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1486template <typename T>
1487using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1488/// Trait to check if T provides any `foldTrait` method.
1489/// NOTE: This should use std::disjunction when C++17 is available.
1490template <typename T>
1491using detect_has_any_fold_trait =
1492 std::conditional_t<bool(detect_has_fold_trait<T>::value),
1493 detect_has_fold_trait<T>,
1494 detect_has_single_result_fold_trait<T>>;
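// Editor's illustrative sketch (not in the original header): a minimal check
// of the detection idiom above. `ExampleFoldTrait` is hypothetical and
// exposes only the single-result `foldTrait` form, so exactly one detector
// fires.
struct ExampleFoldTrait {
  static OpFoldResult foldTrait(Operation *, ArrayRef<Attribute>) {
    return {};
  }
};
static_assert(detect_has_single_result_fold_trait<ExampleFoldTrait>::value,
              "the single-result foldTrait form is detected");
static_assert(!detect_has_fold_trait<ExampleFoldTrait>::value,
              "the general foldTrait form is absent");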
1495
1496/// Returns the result of folding a trait that implements a `foldTrait` function
1497/// that is specialized for operations that have a single result.
1498template <typename Trait>
1499static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1500 LogicalResult>
1501foldTrait(Operation *op, ArrayRef<Attribute> operands,
1502 SmallVectorImpl<OpFoldResult> &results) {
1503 assert(op->hasTrait<OpTrait::OneResult>() &&
1504        "expected trait on non single-result operation to implement the "
1505        "general `foldTrait` method");
1506 // If a previous trait has already folded and replaced this operation, fail
1507 // to fold this trait.
1508 if (!results.empty())
1509 return failure();
1510
1511 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1512 if (result.template dyn_cast<Value>() != op->getResult(0))
1513 results.push_back(result);
1514 return success();
1515 }
1516 return failure();
1517}
1518/// Returns the result of folding a trait that implements a generalized
1519/// `foldTrait` function that supports any operation type.
1520template <typename Trait>
1521static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1522foldTrait(Operation *op, ArrayRef<Attribute> operands,
1523 SmallVectorImpl<OpFoldResult> &results) {
1524 // If a previous trait has already folded and replaced this operation, fail
1525 // to fold this trait.
1526 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1527}
1528template <typename Trait>
1529static inline std::enable_if_t<!detect_has_any_fold_trait<Trait>::value,
1530 LogicalResult>
1531foldTrait(Operation *, ArrayRef<Attribute>, SmallVectorImpl<OpFoldResult> &) {
1532 return failure();
1533}
1534
1535/// Given a tuple type containing a set of traits, return the result of folding
1536/// the given operation.
1537template <typename... Ts>
1538static LogicalResult foldTraits(Operation *op, ArrayRef<Attribute> operands,
1539 SmallVectorImpl<OpFoldResult> &results) {
1540 bool anyFolded = false;
1541 (void)std::initializer_list<int>{
1542 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1543 return success(anyFolded);
1544}
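// Editor's note: the `(void)std::initializer_list<int>{...}` pattern above
// emulates a C++17 fold expression, since this header builds as C++14 (see
// the NOTE on `detect_has_any_fold_trait`). Under C++17 the loop body could
// be written as:
//
//   ((anyFolded |= succeeded(foldTrait<Ts>(op, operands, results))), ...);
//
// A `(succeeded(foldTrait<Ts>(...)) || ...)` fold would be subtly wrong: it
// short-circuits, while every trait must be offered a chance to fold.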
1545
1546//===----------------------------------------------------------------------===//
1547// Trait Verification
1548
1549/// Trait to check if T provides a `verifyTrait` method.
1550template <typename T, typename... Args>
1551using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1552template <typename T>
1553using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1554
1555/// Trait to check if T provides a `verifyRegionTrait` method.
1556template <typename T, typename... Args>
1557using has_verify_region_trait =
1558 decltype(T::verifyRegionTrait(std::declval<Operation *>()));
1559template <typename T>
1560using detect_has_verify_region_trait =
1561 llvm::is_detected<has_verify_region_trait, T>;
1562
1563/// Verify the given trait if it provides a verifier.
1564template <typename T>
1565std::enable_if_t<detect_has_verify_trait<T>::value, LogicalResult>
1566verifyTrait(Operation *op) {
1567 return T::verifyTrait(op);
1568}
1569template <typename T>
1570inline std::enable_if_t<!detect_has_verify_trait<T>::value, LogicalResult>
1571verifyTrait(Operation *) {
1572 return success();
1573}
1574
1575/// Given a set of traits, return the result of verifying the given operation.
1576template <typename... Ts>
1577LogicalResult verifyTraits(Operation *op) {
1578 LogicalResult result = success();
1579 (void)std::initializer_list<int>{
1580 (result = succeeded(result) ? verifyTrait<Ts>(op) : failure(), 0)...};
1581 return result;
1582}
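// Editor's note: unlike `foldTraits` above, verification short-circuits: once
// `result` is failure(), the ternary skips the remaining `verifyTrait<Ts>`
// calls, so traits after the first failing one are never run.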
1583
1584/// Verify the given trait if it provides a region verifier.
1585template <typename T>
1586std::enable_if_t<detect_has_verify_region_trait<T>::value, LogicalResult>
1587verifyRegionTrait(Operation *op) {
1588 return T::verifyRegionTrait(op);
1589}
1590template <typename T>
1591inline std::enable_if_t<!detect_has_verify_region_trait<T>::value,
1592 LogicalResult>
1593verifyRegionTrait(Operation *) {
1594 return success();
1595}
1596
1597/// Given a set of traits, return the result of verifying the regions of the
1598/// given operation.
1599template <typename... Ts>
1600LogicalResult verifyRegionTraits(Operation *op) {
1601 (void)op;
1602 LogicalResult result = success();
1603 (void)std::initializer_list<int>{
1604 (result = succeeded(result) ? verifyRegionTrait<Ts>(op) : failure(),
1605 0)...};
1606 return result;
1607}
1608} // namespace op_definition_impl
1609
1610//===----------------------------------------------------------------------===//
1611// Operation Definition classes
1612//===----------------------------------------------------------------------===//
1613
1614/// This provides public APIs that all operations should have. The template
1615/// argument 'ConcreteType' should be the concrete op type (CRTP), and the
1616/// remaining arguments are base classes supplied via the policy pattern.
1617template <typename ConcreteType, template <typename T> class... Traits>
1618class Op : public OpState, public Traits<ConcreteType>... {
1619public:
1620 /// Inherit getOperation from `OpState`.
1621 using OpState::getOperation;
1622 using OpState::verify;
1623 using OpState::verifyRegions;
1624
1625 /// Return whether this operation class was declared with the provided trait.
1626 template <template <typename T> class Trait>
1627 static constexpr bool hasTrait() {
1628 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1629 }
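// Editor's note: because the trait list is encoded in the type, this is a
// compile-time query; e.g., with the hypothetical `MyAddFOp` sketched
// earlier:
//
//   static_assert(MyAddFOp::hasTrait<OpTrait::Elementwise>(),
//                 "MyAddFOp was declared with the Elementwise trait");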
1630
1631 /// Create a deep copy of this operation.
1632 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1633
1634 /// Create a partial copy of this operation without traversing into attached
1635 /// regions. The new operation will have the same number of regions as the
1636 /// original one, but they will be left empty.
1637 ConcreteType cloneWithoutRegions() {
1638 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1639 }
1640
1641 /// Return true if this "op class" can match against the specified operation.
1642 static bool classof(Operation *op) {
1643 if (auto info = op->getRegisteredInfo())
1644 return TypeID::get<ConcreteType>() == info->getTypeID();
1645#ifndef NDEBUG
1646 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1647 llvm::report_fatal_error(
1648 "classof on '" + ConcreteType::getOperationName() +
1649 "' failed due to the operation not being registered");
1650#endif
1651 return false;
1652 }
1653 /// Provide `classof` support for other OpBase derived classes, such as
1654 /// Interfaces.
1655 template <typename T>
1656 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1657 classof(const T *op) {
1658 return classof(const_cast<T *>(op)->getOperation());
1659 }
1660
1661 /// Expose the type we are instantiated on to template machinery that may want
1662 /// to introspect traits on this operation.
1663 using ConcreteOpType = ConcreteType;
1664
1665 /// This is a public constructor. Any op can be initialized to null.
1666 explicit Op() : OpState(nullptr) {}
1667 Op(std::nullptr_t) : OpState(nullptr) {}
1668
1669 /// This is a public constructor to enable access via the llvm::cast family of
1670 /// methods. This should not be used directly.
1671 explicit Op(Operation *state) : OpState(state) {}
1672
1673 /// Methods for supporting PointerLikeTypeTraits.
1674 const void *getAsOpaquePointer() const {
1675 return static_cast<const void *>((Operation *)*this);
1676 }
1677 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1678 return ConcreteOpType(
1679 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1680 }
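// Editor's note: these two methods give op handles the pointer-like
// round-trip that the `DenseMapInfo` specialization at the end of this file
// relies on. A hypothetical round-trip:
//
//   const void *raw = myAddF.getAsOpaquePointer();
//   MyAddFOp same = MyAddFOp::getFromOpaquePointer(raw);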
1681
1682 /// Attach the given models as implementations of the corresponding interfaces
1683 /// for the concrete operation.
1684 template <typename... Models>
1685 static void attachInterface(MLIRContext &context) {
1686 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
1687 ConcreteType::getOperationName(), &context);
1688 if (!info)
1689 llvm::report_fatal_error(
1690 "Attempting to attach an interface to an unregistered operation " +
1691 ConcreteType::getOperationName() + ".");
1692 info->attachInterface<Models...>();
1693 }
1694
1695private:
1696 /// Trait to check if T provides a 'fold' method for a single result op.
1697 template <typename T, typename... Args>
1698 using has_single_result_fold =
1699 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1700 template <typename T>
1701 using detect_has_single_result_fold =
1702 llvm::is_detected<has_single_result_fold, T>;
1703 /// Trait to check if T provides a general 'fold' method.
1704 template <typename T, typename... Args>
1705 using has_fold = decltype(std::declval<T>().fold(
1706 std::declval<ArrayRef<Attribute>>(),
1707 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1708 template <typename T>
1709 using detect_has_fold = llvm::is_detected<has_fold, T>;
1710 /// Trait to check if T provides a 'print' method.
1711 template <typename T, typename... Args>
1712 using has_print =
1713 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1714 template <typename T>
1715 using detect_has_print = llvm::is_detected<has_print, T>;
1716
1717 /// Returns an interface map containing the interfaces registered to this
1718 /// operation.
1719 static detail::InterfaceMap getInterfaceMap() {
1720 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1721 }
1722
1723 /// Return the internal implementations of each of the OperationName
1724 /// hooks.
1725 /// Implementation of `FoldHookFn` OperationName hook.
1726 static OperationName::FoldHookFn getFoldHookFn() {
1727 return getFoldHookFnImpl<ConcreteType>();
1728 }
1729 /// The internal implementation of `getFoldHookFn` above that is invoked if
1730 /// the operation is single result and defines a `fold` method.
1731 template <typename ConcreteOpT>
1732 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1733 Traits<ConcreteOpT>...>::value &&
1734 detect_has_single_result_fold<ConcreteOpT>::value,
1735 OperationName::FoldHookFn>
1736 getFoldHookFnImpl() {
1737 return [](Operation *op, ArrayRef<Attribute> operands,
1738 SmallVectorImpl<OpFoldResult> &results) {
1739 return foldSingleResultHook<ConcreteOpT>(op, operands, results);
1740 };
1741 }
1742 /// The internal implementation of `getFoldHookFn` above that is invoked if
1743 /// the operation is not single result and defines a `fold` method.
1744 template <typename ConcreteOpT>
1745 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1746 Traits<ConcreteOpT>...>::value &&
1747 detect_has_fold<ConcreteOpT>::value,
1748 OperationName::FoldHookFn>
1749 getFoldHookFnImpl() {
1750 return [](Operation *op, ArrayRef<Attribute> operands,
1751 SmallVectorImpl<OpFoldResult> &results) {
1752 return foldHook<ConcreteOpT>(op, operands, results);
1753 };
1754 }
1755 /// The internal implementation of `getFoldHookFn` above that is invoked if
1756 /// the operation does not define a `fold` method.
1757 template <typename ConcreteOpT>
1758 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1759 !detect_has_fold<ConcreteOpT>::value,
1760 OperationName::FoldHookFn>
1761 getFoldHookFnImpl() {
1762 return [](Operation *op, ArrayRef<Attribute> operands,
1763 SmallVectorImpl<OpFoldResult> &results) {
1764 // In this case, we only need to fold the traits of the operation.
1765 return op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1766 op, operands, results);
1767 };
1768 }
1769 /// Return the result of folding a single result operation that defines a
1770 /// `fold` method.
1771 template <typename ConcreteOpT>
1772 static LogicalResult
1773 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1774 SmallVectorImpl<OpFoldResult> &results) {
1775 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1776
1777 // If the fold failed or was in-place, try to fold the traits of the
1778 // operation.
1779 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1780 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1781 op, operands, results)))
1782 return success();
1783 return success(static_cast<bool>(result));
1784 }
1785 results.push_back(result);
1786 return success();
1787 }
1788 /// Return the result of folding an operation that defines a `fold` method.
1789 template <typename ConcreteOpT>
1790 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1791 SmallVectorImpl<OpFoldResult> &results) {
1792 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1793
1794 // If the fold failed or was in-place, try to fold the traits of the
1795 // operation.
1796 if (failed(result) || results.empty()) {
1797 if (succeeded(op_definition_impl::foldTraits<Traits<ConcreteType>...>(
1798 op, operands, results)))
1799 return success();
1800 }
1801 return result;
1802 }
1803
1804 /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1805 static OperationName::GetCanonicalizationPatternsFn
1806 getGetCanonicalizationPatternsFn() {
1807 return &ConcreteType::getCanonicalizationPatterns;
1808 }
1809 /// Implementation of `GetHasTraitFn`
1810 static OperationName::HasTraitFn getHasTraitFn() {
1811 return
1812 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1813 }
1814 /// Implementation of `ParseAssemblyFn` OperationName hook.
1815 static OperationName::ParseAssemblyFn getParseAssemblyFn() {
1816 return &ConcreteType::parse;
1817 }
1818 /// Implementation of `PrintAssemblyFn` OperationName hook.
1819 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1820 return getPrintAssemblyFnImpl<ConcreteType>();
1821 }
1822 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1823 /// the concrete operation does not define a `print` method.
1824 template <typename ConcreteOpT>
1825 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1826 OperationName::PrintAssemblyFn>
1827 getPrintAssemblyFnImpl() {
1828 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1829 return OpState::print(op, printer, defaultDialect);
1830 };
1831 }
1832 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1833 /// the concrete operation defines a `print` method.
1834 template <typename ConcreteOpT>
1835 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1836 OperationName::PrintAssemblyFn>
1837 getPrintAssemblyFnImpl() {
1838 return &printAssembly;
1839 }
1840 static void printAssembly(Operation *op, OpAsmPrinter &p,
1841 StringRef defaultDialect) {
1842 OpState::printOpName(op, p, defaultDialect);
1843 return cast<ConcreteType>(op).print(p);
1844 }
1845 /// Implementation of `VerifyInvariantsFn` OperationName hook.
1846 static LogicalResult verifyInvariants(Operation *op) {
1847 static_assert(hasNoDataMembers(),
1848 "Op class shouldn't define new data members");
1849 return failure(
1850 failed(op_definition_impl::verifyTraits<Traits<ConcreteType>...>(op)) ||
1851 failed(cast<ConcreteType>(op).verify()));
1852 }
1853 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
1854 return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
1855 }
1856 /// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
1857 static LogicalResult verifyRegionInvariants(Operation *op) {
1858 static_assert(hasNoDataMembers(),
1859 "Op class shouldn't define new data members");
1860 return failure(
1861 failed(op_definition_impl::verifyRegionTraits<Traits<ConcreteType>...>(
1862 op)) ||
1863 failed(cast<ConcreteType>(op).verifyRegions()));
1864 }
1865 static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
1866 return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
1867 }
1868
1869 static constexpr bool hasNoDataMembers() {
1870 // Check that the derived class does not define any new data members by
1871 // comparing its size to an ad-hoc EmptyOp with the same traits.
1872 class EmptyOp : public Op<EmptyOp, Traits...> {};
1873 return sizeof(ConcreteType) == sizeof(EmptyOp);
1874 }
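// Editor's note: `EmptyOp` derives from `Op` with exactly the same traits as
// `ConcreteType`, so any size difference can only come from data members
// added by the concrete op class; such members would be lost whenever the
// handle is reconstructed from a raw `Operation *`, hence the static_assert.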
1875
1876 /// Allow access to internal implementation methods.
1877 friend RegisteredOperationName;
1878};
1879
1880/// This class represents the base of an operation interface. See the definition
1881/// of `detail::Interface` for requirements on the `Traits` type.
1882template <typename ConcreteType, typename Traits>
1883class OpInterface
1884 : public detail::Interface<ConcreteType, Operation *, Traits,
1885 Op<ConcreteType>, OpTrait::TraitBase> {
1886public:
1887 using Base = OpInterface<ConcreteType, Traits>;
1888 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1889 Op<ConcreteType>, OpTrait::TraitBase>;
1890
1891 /// Inherit the base class constructor.
1892 using InterfaceBase::InterfaceBase;
1893
1894protected:
1895 /// Returns the impl interface instance for the given operation.
1896 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1897 OperationName name = op->getName();
1898
1899 // Access the raw interface from the operation info.
1900 if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1901 if (auto *opIface = rInfo->getInterface<ConcreteType>())
1902 return opIface;
1903 // Fallback to the dialect to provide it with a chance to implement this
1904 // interface for this operation.
1905 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
1906 op->getName());
1907 }
1908 // Fallback to the dialect to provide it with a chance to implement this
1909 // interface for this operation.
1910 if (Dialect *dialect = name.getDialect())
1911 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1912 return nullptr;
1913 }
1914
1915 /// Allow access to `getInterfaceFor`.
1916 friend InterfaceBase;
1917};
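// Editor's note: a null return from `getInterfaceFor` is what makes
// `isa`/`dyn_cast` on interfaces fail gracefully. A hypothetical query
// (`MemoryEffectOpInterface` is an existing MLIR interface, used here only
// as an example):
//
//   if (auto effects = dyn_cast<MemoryEffectOpInterface>(op))
//     ...; // op implements the interface, directly or via its dialect.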
1918
1919//===----------------------------------------------------------------------===//
1920// CastOpInterface utilities
1921//===----------------------------------------------------------------------===//
1922
1923// These functions are out-of-line implementations of the methods in
1924// CastOpInterface, which avoids duplicating them in every template instantiation.
1925namespace impl {
1926/// Attempt to fold the given cast operation.
1927LogicalResult foldCastInterfaceOp(Operation *op,
1928 ArrayRef<Attribute> attrOperands,
1929 SmallVectorImpl<OpFoldResult> &foldResults);
1930/// Attempt to verify the given cast operation.
1931LogicalResult verifyCastInterfaceOp(
1932 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1933} // namespace impl
1934} // namespace mlir
1935
1936namespace llvm {
1937
1938template <typename T>
1939struct DenseMapInfo<
1940 T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
1941 static inline T getEmptyKey() {
1942 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1943 return T::getFromOpaquePointer(pointer);
1944 }
1945 static inline T getTombstoneKey() {
1946 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1947 return T::getFromOpaquePointer(pointer);
1948 }
1949 static unsigned getHashValue(T val) {
1950 return hash_value(val.getAsOpaquePointer());
1951 }
1952 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
1953};
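// Editor's note: with the specialization above, op handles can key a DenseMap
// directly, e.g. (hypothetical, reusing the `MyAddFOp` sketch):
//
//   llvm::DenseMap<MyAddFOp, unsigned> opOrdinals;
//   opOrdinals[someAddF] = 0;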
1954
1955} // namespace llvm
1956
1957#endif