Bug Summary

File: mlir/include/mlir/IR/OpDefinition.h
Warning: line 113, column 5
Called C++ object pointer is null

Annotated Source Code

clang -cc1 -cc1 -triple x86_64-pc-linux-gnu -analyze -disable-free -clear-ast-before-backend -disable-llvm-verifier -discard-value-names -main-file-name SuperVectorize.cpp -analyzer-store=region -analyzer-opt-analyze-nested-blocks -analyzer-checker=core -analyzer-checker=apiModeling -analyzer-checker=unix -analyzer-checker=deadcode -analyzer-checker=cplusplus -analyzer-checker=security.insecureAPI.UncheckedReturn -analyzer-checker=security.insecureAPI.getpw -analyzer-checker=security.insecureAPI.gets -analyzer-checker=security.insecureAPI.mktemp -analyzer-checker=security.insecureAPI.mkstemp -analyzer-checker=security.insecureAPI.vfork -analyzer-checker=nullability.NullPassedToNonnull -analyzer-checker=nullability.NullReturnedFromNonnull -analyzer-output plist -w -setup-static-analyzer -analyzer-config-compatibility-mode=true -mrelocation-model pic -pic-level 2 -mframe-pointer=none -fmath-errno -ffp-contract=on -fno-rounding-math -mconstructor-aliases -funwind-tables=2 -target-cpu x86-64 -tune-cpu generic -debugger-tuning=gdb -ffunction-sections -fdata-sections -fcoverage-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins -resource-dir /usr/lib/llvm-14/lib/clang/14.0.0 -D MLIR_CUDA_CONVERSIONS_ENABLED=1 -D MLIR_ROCM_CONVERSIONS_ENABLED=1 -D _DEBUG -D _GNU_SOURCE -D __STDC_CONSTANT_MACROS -D __STDC_FORMAT_MACROS -D __STDC_LIMIT_MACROS -I tools/mlir/lib/Dialect/Affine/Transforms -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Affine/Transforms -I include -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/llvm/include -I /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include -I tools/mlir/include -D _FORTIFY_SOURCE=2 -D NDEBUG -U NDEBUG -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/x86_64-linux-gnu/c++/10 -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../include/c++/10/backward -internal-isystem /usr/lib/llvm-14/lib/clang/14.0.0/include -internal-isystem /usr/local/include -internal-isystem /usr/lib/gcc/x86_64-linux-gnu/10/../../../../x86_64-linux-gnu/include -internal-externc-isystem /usr/include/x86_64-linux-gnu -internal-externc-isystem /include -internal-externc-isystem /usr/include -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fmacro-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fcoverage-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -O3 -Wno-unused-command-line-argument -Wno-unused-parameter -Wwrite-strings -Wno-missing-field-initializers -Wno-long-long -Wno-maybe-uninitialized -Wno-class-memaccess -Wno-redundant-move -Wno-pessimizing-move -Wno-noexcept-type -Wno-comment -std=c++14 -fdeprecated-macro -fdebug-compilation-dir=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/build-llvm/tools/clang/stage2-bins=build-llvm/tools/clang/stage2-bins -fdebug-prefix-map=/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/= -ferror-limit 19 -fvisibility-inlines-hidden -stack-protector 2 
-fgnuc-version=4.2.1 -fcolor-diagnostics -vectorize-loops -vectorize-slp -analyzer-output=html -analyzer-config stable-report-filename=true -faddrsig -D__GCC_HAVE_DWARF2_CFI_ASM=1 -o /tmp/scan-build-2022-01-19-134126-35450-1 -x c++ /build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

1//===- SuperVectorize.cpp - Vectorize Pass Impl ---------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements vectorization of loops, operations and data types to
10// a target-independent, n-D super-vector abstraction.
11//
12//===----------------------------------------------------------------------===//
13
14#include "PassDetail.h"
15#include "mlir/Dialect/Affine/Analysis/AffineAnalysis.h"
16#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
17#include "mlir/Dialect/Affine/Analysis/NestedMatcher.h"
18#include "mlir/Dialect/Affine/IR/AffineOps.h"
19#include "mlir/Dialect/Affine/Utils.h"
20#include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
21#include "mlir/Dialect/Vector/VectorOps.h"
22#include "mlir/Dialect/Vector/VectorUtils.h"
23#include "mlir/IR/BlockAndValueMapping.h"
24#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/Support/Debug.h"
27
28using namespace mlir;
29using namespace vector;
30
31///
32/// Implements a high-level vectorization strategy on a Function.
33/// The abstraction used is that of super-vectors, which burn into a single,
34/// compact representation in the vector types the information that is
35/// expected to reduce the impact of the phase-ordering problem.
36///
37/// Vector granularity:
38/// ===================
39/// This pass is designed to perform vectorization at a super-vector
40/// granularity. A super-vector is loosely defined as a vector type that is a
41/// multiple of a "good" vector size so the HW can efficiently implement a set
42/// of high-level primitives. Multiple is understood along any dimension; e.g.
43/// both vector<16xf32> and vector<2x8xf32> are valid super-vectors for a
44/// vector<8xf32> HW vector. Note that a "good vector size so the HW can
45/// efficiently implement a set of high-level primitives" is not necessarily an
46/// integer multiple of actual hardware registers. We leave details of this
47/// distinction unspecified for now.
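///
/// For instance, on hardware with native vector<8xf32> vectors, a single
/// super-vector operation such as the following hypothetical sketch:
/// ```mlir
/// %v = arith.addf %a, %b : vector<2x8xf32>
/// ```
/// stands for a 2x tile of HW vector adds that subsequent lowerings are free
/// to materialize as they see fit.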
48///
49/// Some may prefer the terminology a "tile of HW vectors". In this case, one
50/// should note that super-vectors implement an "always full tile" abstraction.
51/// They guarantee no partial-tile separation is necessary by relying on a
52/// high-level copy-reshape abstraction that we call vector.transfer. This
53/// copy-reshape operation is also responsible for performing layout
54/// transposition if necessary. In the general case this will require a scoped
55/// allocation in some notional local memory.
56///
57/// Whatever the mental model one prefers to use for this abstraction, the key
58/// point is that we burn into a single, compact, representation in the vector
59/// types, information that is expected to reduce the impact of the phase
60/// ordering problem. Indeed, a vector type conveys information that:
61/// 1. the associated loops have dependency semantics that do not prevent
62/// vectorization;
63/// 2. the associated loops have been sliced in chunks of static sizes that are
64/// compatible with vector sizes (i.e. similar to unroll-and-jam);
65/// 3. the inner loops, in the unroll-and-jam analogy of 2, are captured by
66/// the vector type, and no vectorization-hampering
67/// transformations can be
68/// applied to them anymore;
69/// 4. the underlying memrefs are accessed in some notional contiguous way
70/// that allows loading into vectors with some amount of spatial locality;
71/// In other words, super-vectorization provides a level of separation of
72/// concern by way of opacity to subsequent passes. This has the effect of
73/// encapsulating and propagating vectorization constraints down the list of
74/// passes until we are ready to lower further.
75///
76/// For a particular target, a notion of minimal n-d vector size will be
77/// specified and vectorization targets a multiple of those. In the following
78/// paragraph, let "k ." represent "a multiple of", to be understood as a
79/// multiple in the same dimension (e.g. vector<16 x k . 128> summarizes
80/// vector<16 x 128>, vector<16 x 256>, vector<16 x 1024>, etc).
81///
82/// Some non-exhaustive notable super-vector sizes of interest include:
83/// - CPU: vector<k . HW_vector_size>,
84/// vector<k' . core_count x k . HW_vector_size>,
85/// vector<socket_count x k' . core_count x k . HW_vector_size>;
86/// - GPU: vector<k . warp_size>,
87/// vector<k . warp_size x float2>,
88/// vector<k . warp_size x float4>,
89/// vector<k . warp_size x 4 x 4 x 4> (for tensor_core sizes).
90///
91/// Loops and operations are emitted that operate on those super-vector shapes.
92/// Subsequent lowering passes will materialize to actual HW vector sizes. These
93/// passes are expected to be (gradually) more target-specific.
94///
95/// At a high level, a vectorized load in a loop will resemble:
96/// ```mlir
97/// affine.for %i = ? to ? step ? {
98/// %v_a = vector.transfer_read A[%i] : memref<?xf32>, vector<128xf32>
99/// }
100/// ```
101/// It is the responsibility of the implementation of vector.transfer_read to
102/// materialize vector registers from the original scalar memrefs. A later (more
103/// target-dependent) lowering pass will materialize to actual HW vector sizes.
104/// This lowering may occur at different times:
105/// 1. at the MLIR level into a combination of loops, unrolling, DmaStartOp +
106/// DmaWaitOp + vectorized operations for data transformations and shuffle;
107/// thus opening opportunities for unrolling and pipelining. This is an
108/// instance of library call "whiteboxing"; or
109/// 2. later, in a target-specific lowering pass or a hand-written library
110/// call, achieving full separation of concerns. This is an instance of a
111/// library call; or
112/// 3. a mix of both, e.g. based on a model.
113/// In the future, these operations will expose a contract to constrain the
114/// search on vectorization patterns and sizes.
115///
116/// Occurrence of super-vectorization in the compiler flow:
117/// =======================================================
118/// This is an active area of investigation. We start with 2 remarks to position
119/// super-vectorization in the context of existing ongoing work: LLVM VPLAN
120/// and LLVM SLP Vectorizer.
121///
122/// LLVM VPLAN:
123/// -----------
124/// The astute reader may have noticed that in the limit, super-vectorization
125/// can be applied at a similar time and with similar objectives as VPLAN;
126/// for instance, in a traditional polyhedral compilation flow (the PPCG
127/// project, e.g., uses ISL to provide dependence analysis, multi-level
128/// scheduling + tiling, lifting of footprints to fast memory, communication
129/// synthesis, mapping, and register optimizations), it would run just before
130/// unrolling. When vectorization is applied at this *late* level in a typical
131/// polyhedral flow, and is instantiated with actual hardware vector sizes,
132/// super-vectorization is expected to match (or subsume) the type of patterns
133/// that LLVM's VPLAN aims at targeting. The main difference here is that MLIR
134/// is higher level and our implementation should be significantly simpler. Also
135/// note that in this mode, recursive patterns are probably a bit of an overkill
136/// although it is reasonable to expect that mixing a bit of outer loop and
137/// inner loop vectorization + unrolling will provide interesting choices to
138/// MLIR.
139///
140/// LLVM SLP Vectorizer:
141/// --------------------
142/// Super-vectorization, however, is not meant to be used in a similar fashion
143/// to the SLP vectorizer. The main difference lies in the information that
144/// both vectorizers use: super-vectorization examines contiguity of memory
145/// references along fastest varying dimensions and loops with recursive nested
146/// patterns capturing imperfectly-nested loop nests; the SLP vectorizer, on
147/// the other hand, performs flat pattern matching inside a single unrolled loop
148/// body and stitches together pieces of load and store operations into full
149/// 1-D vectors. We envision that the SLP vectorizer is a good way to capture
150/// innermost loop, control-flow dependent patterns that super-vectorization may
151/// not be able to capture easily. In other words, super-vectorization does not
152/// aim at replacing the SLP vectorizer and the two solutions are complementary.
153///
154/// Ongoing investigations:
155/// -----------------------
156/// We discuss the following *early* places where super-vectorization is
157/// applicable and touch on the expected benefits and risks. We list the
158/// opportunities in the context of the traditional polyhedral compiler flow
159/// described in PPCG. There are essentially 6 places in the MLIR pass pipeline
160/// we expect to experiment with super-vectorization:
161/// 1. Right after language lowering to MLIR: this is the earliest time where
162/// super-vectorization is expected to be applied. At this level, all the
163/// language/user/library-level annotations are available and can be fully
164/// exploited. Examples include loop-type annotations (such as parallel,
165/// reduction, scan, dependence distance vector, vectorizable) as well as
166/// memory access annotations (such as non-aliasing writes guaranteed,
167/// indirect accesses that are permutations by construction), or
168/// that a particular operation is prescribed atomic by the user. At this
169/// level, anything that enriches what dependence analysis can do should be
170/// aggressively exploited. At this level we are close to having explicit
171/// vector types in the language, except we do not impose that burden on the
172/// programmer/library: we derive information from scalar code + annotations.
173/// 2. After dependence analysis and before polyhedral scheduling: the
174/// information that supports vectorization does not need to be supplied by a
175/// higher level of abstraction. Traditional dependence analysis is available
176/// in MLIR and will be used to drive vectorization and cost models.
177///
178/// Let's pause here and remark that applying super-vectorization as described
179/// in 1. and 2. presents clear opportunities and risks:
180/// - the opportunity is that vectorization is burned in the type system and
181/// is protected from the adverse effect of loop scheduling, tiling, loop
182/// interchange and all passes downstream. Provided that subsequent passes are
183/// able to operate on vector types, the vector shapes, associated loop
184/// iterator properties, alignment, and contiguity of fastest varying
185/// dimensions are preserved until we lower the super-vector types. We expect
186/// this to significantly rein in the adverse effects of phase ordering.
187/// - the risks are that a. all passes after super-vectorization have to work
188/// on elemental vector types (note that this is always true, wherever
189/// vectorization is applied) and b. that imposing vectorization constraints
190/// too early may be overall detrimental to loop fusion, tiling and other
191/// transformations because the dependence distances are coarsened when
192/// operating on elemental vector types. For this reason, the pattern
193/// profitability analysis should include a component that also captures the
194/// maximal amount of fusion available under a particular pattern. This is
195/// still at the stage of rough ideas but in this context, search is our
196/// friend as the Tensor Comprehensions and auto-TVM contributions
197/// demonstrated previously.
198/// The bottom line is that we do not yet have good answers for the above but aim at
199/// making it easy to answer such questions.
200///
201/// Back to our listing, the last places where early super-vectorization makes
202/// sense are:
203/// 3. right after polyhedral-style scheduling: PLUTO-style algorithms are known
204/// to improve locality, parallelism and be configurable (e.g. max-fuse,
205/// smart-fuse, etc.). They can also have adverse effects on contiguity
206/// properties that are required for vectorization but the vector.transfer
207/// copy-reshape-pad-transpose abstraction is expected to help recapture
208/// these properties.
209/// 4. right after polyhedral-style scheduling+tiling;
210/// 5. right after scheduling+tiling+rescheduling: points 4 and 5 represent
211/// probably the most promising places because applying tiling achieves a
212/// separation of concerns that allows rescheduling to worry less about
213/// locality and more about parallelism and distribution (e.g. min-fuse).
214///
215/// At these levels the risk-reward looks different: on one hand we probably
216/// lost a good deal of language/user/library-level annotation; on the other
217/// hand we gained parallelism and locality through scheduling and tiling.
218/// However we probably want to ensure tiling is compatible with the
219/// full-tile-only abstraction used in super-vectorization or suffer the
220/// consequences. It is too early to place bets on what will win but we expect
221/// super-vectorization to be the right abstraction to allow exploring at all
222/// these levels. And again, search is our friend.
223///
224/// Lastly, we mention it again here:
225/// 6. as an MLIR-based alternative to VPLAN.
226///
227/// Lowering, unrolling, pipelining:
228/// ================================
229/// TODO: point to the proper places.
230///
231/// Algorithm:
232/// ==========
233/// The algorithm proceeds in a few steps:
234/// 1. defining super-vectorization patterns and matching them on the tree of
235/// AffineForOp. A super-vectorization pattern is defined as a recursive
236/// data structure that matches and captures nested, imperfectly-nested
237/// loops that have a. conformable loop annotations attached (e.g. parallel,
238/// reduction, vectorizable, ...) as well as b. all contiguous load/store
239/// operations along a specified minor dimension (not necessarily the
240/// fastest varying);
241/// 2. analyzing those patterns for profitability (TODO: and
242/// interference);
243/// 3. then, for each pattern in order:
244/// a. applying iterative rewriting of the loops and all their nested
245/// operations in topological order. Rewriting is implemented by
246/// coarsening the loops and converting operations and operands to their
247/// vector forms. Processing operations in topological order is relatively
248/// simple due to the structured nature of the control-flow
249/// representation. This order ensures that all the operands of a given
250/// operation have been vectorized before the operation itself in a single
251/// traversal, except for operands defined outside of the loop nest. The
252/// algorithm can convert the following operations to their vector form
253/// (see the sketch after this list):
253/// * Affine load and store operations are converted to opaque vector
254/// transfer read and write operations.
255/// * Scalar constant operations/operands are converted to vector
256/// constant operations (splat).
257/// * Uniform operands (only induction variables of loops not mapped to
258/// a vector dimension, or operands defined outside of the loop nest
259/// for now) are broadcasted to a vector.
260/// TODO: Support more uniform cases.
261/// * Affine for operations with 'iter_args' are vectorized by
262/// vectorizing their 'iter_args' operands and results.
263/// TODO: Support more complex loops with divergent lbs and/or ubs.
264/// * The remaining operations in the loop nest are vectorized by
265/// widening their scalar types to vector types.
266/// b. if everything under the root AffineForOp in the current pattern
267/// is vectorized properly, we commit that loop to the IR and remove the
268/// scalar loop. Otherwise, we discard the vectorized loop and keep the
269/// original scalar loop.
270/// c. vectorization is applied on the next pattern in the list. Because
271/// pattern interference avoidance is not yet implemented and we do
272/// not support further vectorizing an already vectorized load, we need to
273/// re-verify that the pattern is still vectorizable. This is expected to
274/// make cost models more difficult to write and is subject to improvement
275/// in the future.
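///
/// As an illustration of step 3.a, here is a minimal before/after sketch,
/// assuming a 1-D vectorization factor of 128 and using the simplified
/// transfer syntax of the examples in this comment (%A, %B, and %N are
/// hypothetical; this is not actual pass output):
/// ```mlir
/// // Before: scalar loop body.
/// affine.for %i = 0 to %N {
///   %cst = arith.constant 1.0 : f32
///   %a = affine.load %A[%i] : memref<?xf32>
///   %s = arith.addf %a, %cst : f32
///   affine.store %s, %B[%i] : memref<?xf32>
/// }
/// // After: the loop is coarsened; the constant is splatted, the load/store
/// // become vector transfers, and the addf is widened.
/// affine.for %i = 0 to %N step 128 {
///   %vcst = arith.constant dense<1.0> : vector<128xf32>
///   %va = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
///   %vs = arith.addf %va, %vcst : vector<128xf32>
///   vector.transfer_write %vs, %B[%i] : vector<128xf32>, memref<?xf32>
/// }
/// ```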
276///
277/// Choice of loop transformation to support the algorithm:
278/// =======================================================
279/// The choice of loop transformation to apply for coarsening vectorized loops
280/// is still subject to exploratory tradeoffs. In particular, say we want to
281/// vectorize by a factor of 128. That is, we want to transform the following input:
282/// ```mlir
283/// affine.for %i = %M to %N {
284/// %a = affine.load %A[%i] : memref<?xf32>
285/// }
286/// ```
287///
288/// Traditionally, one would vectorize late (after scheduling, tiling,
289/// memory promotion etc) say after stripmining (and potentially unrolling in
290/// the case of LLVM's SLP vectorizer):
291/// ```mlir
292/// affine.for %i = floor(%M, 128) to ceil(%N, 128) {
293/// affine.for %ii = max(%M, 128 * %i) to min(%N, 128*%i + 127) {
294/// %a = affine.load %A[%ii] : memref<?xf32>
295/// }
296/// }
297/// ```
298///
299/// Instead, we seek to vectorize early and freeze vector types before
300/// scheduling, so we want to generate a pattern that resembles:
301/// ```mlir
302/// affine.for %i = ? to ? step ? {
303/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
304/// }
305/// ```
306///
307/// i. simply dividing the lower / upper bounds by 128 creates issues
308/// when representing expressions such as ii + 1 because now we only
309/// have access to original values that have been divided. Additional
310/// information is needed to specify accesses at below-128 granularity;
311/// ii. another alternative is to coarsen the loop step but this may have
312/// consequences on dependence analysis and fusability of loops: fusable
313/// loops probably need to have the same step (because we don't want to
314/// stripmine/unroll to enable fusion).
315/// As a consequence, we choose to represent the coarsening using the loop
316/// step for now and reevaluate in the future. Note that we can renormalize
317/// loop steps later if/when we have evidence that they are problematic.
318///
319/// For the simple strawman example above, vectorizing for a 1-D vector
320/// abstraction of size 128 returns code similar to:
321/// ```mlir
322/// affine.for %i = %M to %N step 128 {
323/// %v_a = vector.transfer_read %A[%i] : memref<?xf32>, vector<128xf32>
324/// }
325/// ```
326///
327/// Unsupported cases, extensions, and work in progress (help welcome :-) ):
328/// ========================================================================
329/// 1. lowering to concrete vector types for various HW;
330/// 2. reduction support for n-D vectorization and non-unit steps;
331/// 3. non-effecting padding during vector.transfer_read and filter during
332/// vector.transfer_write;
333/// 4. misalignment support for vector.transfer_read / vector.transfer_write
334/// (hopefully without read-modify-writes);
335/// 5. control-flow support;
336/// 6. cost-models, heuristics and search;
337/// 7. Op implementation, extensions and implication on memref views;
338/// 8. many TODOs left around.
339///
340/// Examples:
341/// =========
342/// Consider the following Function:
343/// ```mlir
344/// func @vector_add_2d(%M : index, %N : index) -> f32 {
345/// %A = alloc (%M, %N) : memref<?x?xf32, 0>
346/// %B = alloc (%M, %N) : memref<?x?xf32, 0>
347/// %C = alloc (%M, %N) : memref<?x?xf32, 0>
348/// %f1 = arith.constant 1.0 : f32
349/// %f2 = arith.constant 2.0 : f32
350/// affine.for %i0 = 0 to %M {
351/// affine.for %i1 = 0 to %N {
352/// // non-scoped %f1
353/// affine.store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
354/// }
355/// }
356/// affine.for %i2 = 0 to %M {
357/// affine.for %i3 = 0 to %N {
358/// // non-scoped %f2
359/// affine.store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
360/// }
361/// }
362/// affine.for %i4 = 0 to %M {
363/// affine.for %i5 = 0 to %N {
364/// %a5 = affine.load %A[%i4, %i5] : memref<?x?xf32, 0>
365/// %b5 = affine.load %B[%i4, %i5] : memref<?x?xf32, 0>
366/// %s5 = arith.addf %a5, %b5 : f32
367/// // non-scoped %f1
368/// %s6 = arith.addf %s5, %f1 : f32
369/// // non-scoped %f2
370/// %s7 = arith.addf %s5, %f2 : f32
371/// // diamond dependency.
372/// %s8 = arith.addf %s7, %s6 : f32
373/// affine.store %s8, %C[%i4, %i5] : memref<?x?xf32, 0>
374/// }
375/// }
376/// %c7 = arith.constant 7 : index
377/// %c42 = arith.constant 42 : index
378/// %res = load %C[%c7, %c42] : memref<?x?xf32, 0>
379/// return %res : f32
380/// }
381/// ```
382///
383/// The -affine-vectorize pass with the following arguments:
384/// ```
385/// -affine-vectorize="virtual-vector-size=256 test-fastest-varying=0"
386/// ```
387///
388/// produces this standard innermost-loop vectorized code:
389/// ```mlir
390/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
391/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
392/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
393/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
394/// %cst = arith.constant 1.0 : f32
395/// %cst_0 = arith.constant 2.0 : f32
396/// affine.for %i0 = 0 to %arg0 {
397/// affine.for %i1 = 0 to %arg1 step 256 {
398/// %cst_1 = arith.constant dense<1.0> :
399/// vector<256xf32>
400/// vector.transfer_write %cst_1, %0[%i0, %i1] :
401/// vector<256xf32>, memref<?x?xf32>
402/// }
403/// }
404/// affine.for %i2 = 0 to %arg0 {
405/// affine.for %i3 = 0 to %arg1 step 256 {
406/// %cst_2 = arith.constant dense<2.0> :
407/// vector<256xf32>
408/// vector.transfer_write %cst_2, %1[%i2, %i3] :
409/// vector<256xf32>, memref<?x?xf32>
410/// }
411/// }
412/// affine.for %i4 = 0 to %arg0 {
413/// affine.for %i5 = 0 to %arg1 step 256 {
414/// %3 = vector.transfer_read %0[%i4, %i5] :
415/// memref<?x?xf32>, vector<256xf32>
416/// %4 = vector.transfer_read %1[%i4, %i5] :
417/// memref<?x?xf32>, vector<256xf32>
418/// %5 = arith.addf %3, %4 : vector<256xf32>
419/// %cst_3 = arith.constant dense<1.0> :
420/// vector<256xf32>
421/// %6 = arith.addf %5, %cst_3 : vector<256xf32>
422/// %cst_4 = arith.constant dense<2.0> :
423/// vector<256xf32>
424/// %7 = arith.addf %5, %cst_4 : vector<256xf32>
425/// %8 = arith.addf %7, %6 : vector<256xf32>
426/// vector.transfer_write %8, %2[%i4, %i5] :
427/// vector<256xf32>, memref<?x?xf32>
428/// }
429/// }
430/// %c7 = arith.constant 7 : index
431/// %c42 = arith.constant 42 : index
432/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
433/// return %9 : f32
434/// }
435/// ```
436///
437/// The -affine-vectorize pass with the following arguments:
438/// ```
439/// -affine-vectorize="virtual-vector-size=32,256 test-fastest-varying=1,0"
440/// ```
441///
442/// produces this more interesting mixed outer-innermost-loop vectorized code:
443/// ```mlir
444/// func @vector_add_2d(%arg0 : index, %arg1 : index) -> f32 {
445/// %0 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
446/// %1 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
447/// %2 = memref.alloc(%arg0, %arg1) : memref<?x?xf32>
448/// %cst = arith.constant 1.0 : f32
449/// %cst_0 = arith.constant 2.0 : f32
450/// affine.for %i0 = 0 to %arg0 step 32 {
451/// affine.for %i1 = 0 to %arg1 step 256 {
452/// %cst_1 = arith.constant dense<1.0> :
453/// vector<32x256xf32>
454/// vector.transfer_write %cst_1, %0[%i0, %i1] :
455/// vector<32x256xf32>, memref<?x?xf32>
456/// }
457/// }
458/// affine.for %i2 = 0 to %arg0 step 32 {
459/// affine.for %i3 = 0 to %arg1 step 256 {
460/// %cst_2 = arith.constant dense<2.0> :
461/// vector<32x256xf32>
462/// vector.transfer_write %cst_2, %1[%i2, %i3] :
463/// vector<32x256xf32>, memref<?x?xf32>
464/// }
465/// }
466/// affine.for %i4 = 0 to %arg0 step 32 {
467/// affine.for %i5 = 0 to %arg1 step 256 {
468/// %3 = vector.transfer_read %0[%i4, %i5] :
469/// memref<?x?xf32>, vector<32x256xf32>
470/// %4 = vector.transfer_read %1[%i4, %i5] :
471/// memref<?x?xf32>, vector<32x256xf32>
472/// %5 = arith.addf %3, %4 : vector<32x256xf32>
473/// %cst_3 = arith.constant dense<1.0> :
474/// vector<32x256xf32>
475/// %6 = arith.addf %5, %cst_3 : vector<32x256xf32>
476/// %cst_4 = arith.constant dense<2.0> :
477/// vector<32x256xf32>
478/// %7 = arith.addf %5, %cst_4 : vector<32x256xf32>
479/// %8 = arith.addf %7, %6 : vector<32x256xf32>
480/// vector.transfer_write %8, %2[%i4, %i5] :
481/// vector<32x256xf32>, memref<?x?xf32>
482/// }
483/// }
484/// %c7 = arith.constant 7 : index
485/// %c42 = arith.constant 42 : index
486/// %9 = load %2[%c7, %c42] : memref<?x?xf32>
487/// return %9 : f32
488/// }
489/// ```
490///
491/// Of course, much more intricate n-D imperfectly-nested patterns can be
492/// vectorized too and specified in a fully declarative fashion.
493///
494/// Reduction:
495/// ==========
496/// Vectorizing reduction loops along the reduction dimension is supported if:
497/// - the reduction kind is supported,
498/// - the vectorization is 1-D, and
499/// - the step size of the loop equals one.
500///
501/// Compared to the non-vector-dimension case, two additional things are done
502/// during vectorization of such loops:
503/// - The resulting vector returned from the loop is reduced to a scalar using
504/// `vector.reduction`.
505/// - In some cases a mask is applied to the vector yielded at the end of the
506/// loop to prevent garbage values from being written to the accumulator.
507///
508/// Reduction vectorization is switched off by default; it can be enabled by
509/// passing a map from loops to reductions to utility functions, or by passing
510/// `vectorize-reductions=true` to the vectorization pass.
511///
512/// Consider the following example:
513/// ```mlir
514/// func @vecred(%in: memref<512xf32>) -> f32 {
515/// %cst = arith.constant 0.000000e+00 : f32
516/// %sum = affine.for %i = 0 to 500 iter_args(%part_sum = %cst) -> (f32) {
517/// %ld = affine.load %in[%i] : memref<512xf32>
518/// %cos = math.cos %ld : f32
519/// %add = arith.addf %part_sum, %cos : f32
520/// affine.yield %add : f32
521/// }
522/// return %sum : f32
523/// }
524/// ```
525///
526/// The -affine-vectorize pass with the following arguments:
527/// ```
528/// -affine-vectorize="virtual-vector-size=128 test-fastest-varying=0 \
529/// vectorize-reductions=true"
530/// ```
531/// produces the following output:
532/// ```mlir
533/// #map = affine_map<(d0) -> (-d0 + 500)>
534/// func @vecred(%arg0: memref<512xf32>) -> f32 {
535/// %cst = arith.constant 0.000000e+00 : f32
536/// %cst_0 = arith.constant dense<0.000000e+00> : vector<128xf32>
537/// %0 = affine.for %arg1 = 0 to 500 step 128 iter_args(%arg2 = %cst_0)
538/// -> (vector<128xf32>) {
539/// // %2 is the number of iterations left in the original loop.
540/// %2 = affine.apply #map(%arg1)
541/// %3 = vector.create_mask %2 : vector<128xi1>
542/// %cst_1 = arith.constant 0.000000e+00 : f32
543/// %4 = vector.transfer_read %arg0[%arg1], %cst_1 :
544/// memref<512xf32>, vector<128xf32>
545/// %5 = math.cos %4 : vector<128xf32>
546/// %6 = arith.addf %arg2, %5 : vector<128xf32>
547/// // We filter out the effect of the last 12 elements using the mask.
548/// %7 = select %3, %6, %arg2 : vector<128xi1>, vector<128xf32>
549/// affine.yield %7 : vector<128xf32>
550/// }
551/// %1 = vector.reduction "add", %0 : vector<128xf32> into f32
552/// return %1 : f32
553/// }
554/// ```
555///
556/// Note that because of loop misalignment we needed to apply a mask to prevent
557/// the last 12 elements from affecting the final result. The mask is full of
558/// ones in every iteration except for the last one, where it has the form
559/// `11...100...0` with 116 ones (500 - 3 * 128 = 116) and 12 zeros.
560
561#define DEBUG_TYPE "early-vect"
562
563using llvm::dbgs;
564
565/// Forward declaration.
566static FilterFunctionType
567isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
568 int fastestVaryingMemRefDimension);
569
570/// Creates a vectorization pattern from the command line arguments.
571/// Up to 3-D patterns are supported.
572/// If the command line argument requests a pattern of higher order, returns
573/// llvm::None, which conservatively results in no vectorization.
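/// E.g., with vectorRank == 2, the returned pattern matches a two-deep nest
/// of vectorizable AffineForOps, filtering the outer loop on
/// fastestVaryingPattern[0] and the inner loop on fastestVaryingPattern[1].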
574static Optional<NestedPattern>
575makePattern(const DenseSet<Operation *> &parallelLoops, int vectorRank,
576 ArrayRef<int64_t> fastestVaryingPattern) {
577 using matcher::For;
578 int64_t d0 = fastestVaryingPattern.empty() ? -1 : fastestVaryingPattern[0];
579 int64_t d1 = fastestVaryingPattern.size() < 2 ? -1 : fastestVaryingPattern[1];
580 int64_t d2 = fastestVaryingPattern.size() < 3 ? -1 : fastestVaryingPattern[2];
581 switch (vectorRank) {
582 case 1:
583 return For(isVectorizableLoopPtrFactory(parallelLoops, d0));
584 case 2:
585 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
586 For(isVectorizableLoopPtrFactory(parallelLoops, d1)));
587 case 3:
588 return For(isVectorizableLoopPtrFactory(parallelLoops, d0),
589 For(isVectorizableLoopPtrFactory(parallelLoops, d1),
590 For(isVectorizableLoopPtrFactory(parallelLoops, d2))));
591 default: {
592 return llvm::None;
593 }
594 }
595}
596
597static NestedPattern &vectorTransferPattern() {
598 static auto pattern = matcher::Op([](Operation &op) {
599 return isa<vector::TransferReadOp, vector::TransferWriteOp>(op);
600 });
601 return pattern;
602}
603
604namespace {
605
606/// Base state for the vectorize pass.
607/// Command line arguments are preempted by non-empty pass arguments.
608struct Vectorize : public AffineVectorizeBase<Vectorize> {
609 Vectorize() = default;
610 Vectorize(ArrayRef<int64_t> virtualVectorSize);
611 void runOnOperation() override;
612};
613
614} // namespace
615
616Vectorize::Vectorize(ArrayRef<int64_t> virtualVectorSize) {
617 vectorSizes = virtualVectorSize;
618}
619
620static void vectorizeLoopIfProfitable(Operation *loop, unsigned depthInPattern,
621 unsigned patternDepth,
622 VectorizationStrategy *strategy) {
623 assert(patternDepth > depthInPattern &&
624        "patternDepth is greater than depthInPattern");
625 if (patternDepth - depthInPattern > strategy->vectorSizes.size()) {
626 // Don't vectorize this loop
627 return;
628 }
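  // E.g., for a 2-D strategy (vectorSizes = {32, 256}) with patternDepth == 2,
  // the outermost loop (depthInPattern == 0) maps to vector dimension 0 and
  // the innermost loop (depthInPattern == 1) maps to vector dimension 1.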
629 strategy->loopToVectorDim[loop] =
630 strategy->vectorSizes.size() - (patternDepth - depthInPattern);
631}
632
633/// Implements a simple strawman strategy for vectorization.
634/// Given a matched pattern `matches` of depth `patternDepth`, this strategy
635/// greedily assigns the fastest varying dimension ** of the vector ** to the
636/// innermost loop in the pattern.
637/// When coupled with a pattern that looks for the fastest varying dimension in
638/// load/store MemRefs, this creates a generic vectorization strategy that works
639/// for any loop in a hierarchy (outermost, innermost or intermediate).
640///
641/// TODO: In the future we should additionally increase the power of the
642/// profitability analysis along 3 directions:
643/// 1. account for loop extents (both static and parametric + annotations);
644/// 2. account for data layout permutations;
645/// 3. account for impact of vectorization on maximal loop fusion.
646/// Then we can quantify the above to build a cost model and search over
647/// strategies.
648static LogicalResult analyzeProfitability(ArrayRef<NestedMatch> matches,
649 unsigned depthInPattern,
650 unsigned patternDepth,
651 VectorizationStrategy *strategy) {
652 for (auto m : matches) {
653 if (failed(analyzeProfitability(m.getMatchedChildren(), depthInPattern + 1,
654 patternDepth, strategy))) {
655 return failure();
656 }
657 vectorizeLoopIfProfitable(m.getMatchedOperation(), depthInPattern,
658 patternDepth, strategy);
659 }
660 return success();
661}
662
663///// end TODO: Hoist to a VectorizationStrategy.cpp when appropriate /////
664
665namespace {
666
667struct VectorizationState {
668
669 VectorizationState(MLIRContext *context) : builder(context) {}
670
671 /// Registers the vector replacement of a scalar operation and its result
672 /// values. Both operations must have the same number of results.
673 ///
674 /// This utility is used to register the replacement for the vast majority of
675 /// the vectorized operations.
676 ///
677 /// Example:
678 /// * 'replaced': %0 = arith.addf %1, %2 : f32
679 /// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
680 void registerOpVectorReplacement(Operation *replaced, Operation *replacement);
681
682 /// Registers the vector replacement of a scalar value. The replacement
683 /// operation should have a single result, which replaces the scalar value.
684 ///
685 /// This utility is used to register the vector replacement of block arguments
686 /// and operation results which are not directly vectorized (i.e., their
687 /// scalar version still exists after vectorization), like uniforms.
688 ///
689 /// Example:
690 /// * 'replaced': block argument or operation outside of the vectorized
691 /// loop.
692 /// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
693 void registerValueVectorReplacement(Value replaced, Operation *replacement);
694
695 /// Registers the vector replacement of a block argument (e.g., iter_args).
696 ///
697 /// Example:
698 /// * 'replaced': 'iter_arg' block argument.
699 /// * 'replacement': vectorized 'iter_arg' block argument.
700 void registerBlockArgVectorReplacement(BlockArgument replaced,
701 BlockArgument replacement);
702
703 /// Registers the scalar replacement of a scalar value. 'replacement' must be
704 /// scalar. Both values must be block arguments. Operation results should be
705 /// replaced using the 'registerOp*' utilities.
706 ///
707 /// This utility is used to register the replacement of block arguments
708 /// that are within the loop to be vectorized and will continue being scalar
709 /// within the vector loop.
710 ///
711 /// Example:
712 /// * 'replaced': induction variable of a loop to be vectorized.
713 /// * 'replacement': new induction variable in the new vector loop.
714 void registerValueScalarReplacement(BlockArgument replaced,
715 BlockArgument replacement);
716
717 /// Registers the scalar replacement of a scalar result returned from a
718 /// reduction loop. 'replacement' must be scalar.
719 ///
720 /// This utility is used to register the replacement for scalar results of
721 /// vectorized reduction loops with iter_args.
722 ///
723 /// Example:
724 /// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
725 /// * 'replacement': %1 = vector.reduction "add" %0 : vector<4xf32> into f32
726 void registerLoopResultScalarReplacement(Value replaced, Value replacement);
727
728 /// Returns in 'replacedVals' the scalar replacement for values in
729 /// 'inputVals'.
730 void getScalarValueReplacementsFor(ValueRange inputVals,
731 SmallVectorImpl<Value> &replacedVals);
732
733 /// Erases the scalar loop nest after its successful vectorization.
734 void finishVectorizationPattern(AffineForOp rootLoop);
735
736 // Used to build and insert all the new operations created. The insertion
737 // point is preserved and updated along the vectorization process.
738 OpBuilder builder;
739
740 // Maps input scalar operations to their vector counterparts.
741 DenseMap<Operation *, Operation *> opVectorReplacement;
742 // Maps input scalar values to their vector counterparts.
743 BlockAndValueMapping valueVectorReplacement;
744 // Maps input scalar values to their new scalar counterparts in the vector
745 // loop nest.
746 BlockAndValueMapping valueScalarReplacement;
747 // Maps results of reduction loops to their new scalar counterparts.
748 DenseMap<Value, Value> loopResultScalarReplacement;
749
750 // Maps the newly created vector loops to their vector dimension.
751 DenseMap<Operation *, unsigned> vecLoopToVecDim;
752
753 // Maps the new vectorized loops to the corresponding vector masks if it is
754 // required.
755 DenseMap<Operation *, Value> vecLoopToMask;
756
757 // The strategy drives which loop to vectorize by which amount.
758 const VectorizationStrategy *strategy = nullptr;
759
760private:
761 /// Internal implementation to map input scalar values to new vector or scalar
762 /// values.
763 void registerValueVectorReplacementImpl(Value replaced, Value replacement);
764 void registerValueScalarReplacementImpl(Value replaced, Value replacement);
765};
766
767} // namespace
768
769/// Registers the vector replacement of a scalar operation and its result
770/// values. Both operations must have the same number of results.
771///
772/// This utility is used to register the replacement for the vast majority of
773/// the vectorized operations.
774///
775/// Example:
776/// * 'replaced': %0 = arith.addf %1, %2 : f32
777/// * 'replacement': %0 = arith.addf %1, %2 : vector<128xf32>
778void VectorizationState::registerOpVectorReplacement(Operation *replaced,
779 Operation *replacement) {
780 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ commit vectorized op:\n");
781 LLVM_DEBUG(dbgs() << *replaced << "\n");
782 LLVM_DEBUG(dbgs() << "into\n");
783 LLVM_DEBUG(dbgs() << *replacement << "\n");
784
785 assert(replaced->getNumResults() == replacement->getNumResults() &&
786        "Unexpected replaced and replacement results");
787 assert(opVectorReplacement.count(replaced) == 0 && "already registered");
788 opVectorReplacement[replaced] = replacement;
789
790 for (auto resultTuple :
791 llvm::zip(replaced->getResults(), replacement->getResults()))
792 registerValueVectorReplacementImpl(std::get<0>(resultTuple),
793 std::get<1>(resultTuple));
794}
795
796/// Registers the vector replacement of a scalar value. The replacement
797/// operation should have a single result, which replaces the scalar value.
798///
799/// This utility is used to register the vector replacement of block arguments
800/// and operation results which are not directly vectorized (i.e., their
801/// scalar version still exists after vectorization), like uniforms.
802///
803/// Example:
804/// * 'replaced': block argument or operation outside of the vectorized loop.
805/// * 'replacement': %0 = vector.broadcast %1 : f32 to vector<128xf32>
806void VectorizationState::registerValueVectorReplacement(
807 Value replaced, Operation *replacement) {
808 assert(replacement->getNumResults() == 1 &&
809        "Expected single-result replacement");
810 if (Operation *defOp = replaced.getDefiningOp())
811 registerOpVectorReplacement(defOp, replacement);
812 else
813 registerValueVectorReplacementImpl(replaced, replacement->getResult(0));
814}
815
816/// Registers the vector replacement of a block argument (e.g., iter_args).
817///
818/// Example:
819/// * 'replaced': 'iter_arg' block argument.
820/// * 'replacement': vectorized 'iter_arg' block argument.
821void VectorizationState::registerBlockArgVectorReplacement(
822 BlockArgument replaced, BlockArgument replacement) {
823 registerValueVectorReplacementImpl(replaced, replacement);
824}
825
826void VectorizationState::registerValueVectorReplacementImpl(Value replaced,
827 Value replacement) {
828 assert(!valueVectorReplacement.contains(replaced) &&
829        "Vector replacement already registered");
830 assert(replacement.getType().isa<VectorType>() &&
831        "Expected vector type in vector replacement");
832 valueVectorReplacement.map(replaced, replacement);
833}
834
835/// Registers the scalar replacement of a scalar value. 'replacement' must be
836/// scalar. Both values must be block arguments. Operation results should be
837/// replaced using the 'registerOp*' utilities.
838///
839/// This utility is used to register the replacement of block arguments
840/// that are within the loop to be vectorized and will continue being scalar
841/// within the vector loop.
842///
843/// Example:
844/// * 'replaced': induction variable of a loop to be vectorized.
845/// * 'replacement': new induction variable in the new vector loop.
846void VectorizationState::registerValueScalarReplacement(
847 BlockArgument replaced, BlockArgument replacement) {
848 registerValueScalarReplacementImpl(replaced, replacement);
849}
850
851/// Registers the scalar replacement of a scalar result returned from a
852/// reduction loop. 'replacement' must be scalar.
853///
854/// This utility is used to register the replacement for scalar results of
855/// vectorized reduction loops with iter_args.
856///
857/// Example:
858/// * 'replaced': %0 = affine.for %i = 0 to 512 iter_args(%x = ...) -> (f32)
859/// * 'replacement': %1 = vector.reduction "add" %0 : vector<4xf32> into f32
860void VectorizationState::registerLoopResultScalarReplacement(
861 Value replaced, Value replacement) {
862 assert(isa<AffineForOp>(replaced.getDefiningOp()));
863 assert(loopResultScalarReplacement.count(replaced) == 0 &&
864        "already registered");
865 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ will replace a result of the loop "
866                      "with scalar: "
867                   << replacement);
868 loopResultScalarReplacement[replaced] = replacement;
869}
870
871void VectorizationState::registerValueScalarReplacementImpl(Value replaced,
872 Value replacement) {
873 assert(!valueScalarReplacement.contains(replaced) &&
874        "Scalar value replacement already registered");
875 assert(!replacement.getType().isa<VectorType>() &&
876        "Expected scalar type in scalar replacement");
877 valueScalarReplacement.map(replaced, replacement);
878}
879
880/// Returns in 'replacedVals' the scalar replacement for values in 'inputVals'.
881void VectorizationState::getScalarValueReplacementsFor(
882 ValueRange inputVals, SmallVectorImpl<Value> &replacedVals) {
883 for (Value inputVal : inputVals)
884 replacedVals.push_back(valueScalarReplacement.lookupOrDefault(inputVal));
885}
886
887/// Erases a loop nest, including all its nested operations.
888static void eraseLoopNest(AffineForOp forOp) {
889 LLVM_DEBUG(dbgs() << "[early-vect]+++++ erasing:\n" << forOp << "\n");
890 forOp.erase();
891}
892
893/// Erases the scalar loop nest after its successful vectorization.
894void VectorizationState::finishVectorizationPattern(AffineForOp rootLoop) {
895 LLVM_DEBUG(dbgs() << "\n[early-vect] Finalizing vectorization\n");
896 eraseLoopNest(rootLoop);
897}
898
899// Applies 'map' to 'mapOperands', returning the resulting values in 'results'.
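// E.g., a two-result map such as (d0, d1) -> (d0 + 1, d1) yields two
// single-result affine.apply ops, one per result expression of 'map'.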
900static void computeMemoryOpIndices(Operation *op, AffineMap map,
901 ValueRange mapOperands,
902 VectorizationState &state,
903 SmallVectorImpl<Value> &results) {
904 for (auto resultExpr : map.getResults()) {
905 auto singleResMap =
906 AffineMap::get(map.getNumDims(), map.getNumSymbols(), resultExpr);
907 auto afOp = state.builder.create<AffineApplyOp>(op->getLoc(), singleResMap,
908 mapOperands);
909 results.push_back(afOp);
910 }
911}
912
913/// Returns a FilterFunctionType that can be used in NestedPattern to match a
914/// loop whose underlying load/store accesses are either invariant or all
915/// varying along the `fastestVaryingMemRefDimension`.
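/// E.g., with `fastestVaryingMemRefDimension` == 0, a loop matches if all its
/// load/store accesses are invariant or vary along memref dimension 0; a
/// value of -1 disables the dimension filter.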
916static FilterFunctionType
917isVectorizableLoopPtrFactory(const DenseSet<Operation *> &parallelLoops,
918 int fastestVaryingMemRefDimension) {
919 return [&parallelLoops, fastestVaryingMemRefDimension](Operation &forOp) {
920 auto loop = cast<AffineForOp>(forOp);
921 auto parallelIt = parallelLoops.find(loop);
922 if (parallelIt == parallelLoops.end())
923 return false;
924 int memRefDim = -1;
925 auto vectorizableBody =
926 isVectorizableLoopBody(loop, &memRefDim, vectorTransferPattern());
927 if (!vectorizableBody)
928 return false;
929 return memRefDim == -1 || fastestVaryingMemRefDimension == -1 ||
930 memRefDim == fastestVaryingMemRefDimension;
931 };
932}
933
934/// Returns the vector type resulting from applying the provided vectorization
935/// strategy on the scalar type.
936static VectorType getVectorType(Type scalarTy,
937 const VectorizationStrategy *strategy) {
938 assert(!scalarTy.isa<VectorType>() && "Expected scalar type");
939 return VectorType::get(strategy->vectorSizes, scalarTy);
940}
941
942/// Tries to transform a scalar constant into a vector constant. Returns the
943/// vector constant if the scalar type is a valid vector element type. Returns
944/// nullptr otherwise.
945static arith::ConstantOp vectorizeConstant(arith::ConstantOp constOp,
946 VectorizationState &state) {
947 Type scalarTy = constOp.getType();
948 if (!VectorType::isValidElementType(scalarTy))
949 return nullptr;
950
951 auto vecTy = getVectorType(scalarTy, state.strategy);
952 auto vecAttr = DenseElementsAttr::get(vecTy, constOp.getValue());
953
954 OpBuilder::InsertionGuard guard(state.builder);
955 Operation *parentOp = state.builder.getInsertionBlock()->getParentOp();
956 // Find the innermost vectorized ancestor loop to insert the vector constant.
957 while (parentOp && !state.vecLoopToVecDim.count(parentOp))
958 parentOp = parentOp->getParentOp();
959 assert(parentOp && state.vecLoopToVecDim.count(parentOp) &&
960        isa<AffineForOp>(parentOp) && "Expected a vectorized for op");
961 auto vecForOp = cast<AffineForOp>(parentOp);
962 state.builder.setInsertionPointToStart(vecForOp.getBody());
963 auto newConstOp =
964 state.builder.create<arith::ConstantOp>(constOp.getLoc(), vecAttr);
965
966 // Register vector replacement for future uses in the scope.
967 state.registerOpVectorReplacement(constOp, newConstOp);
968 return newConstOp;
969}
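A before/after sketch, assuming a hypothetical 1-D strategy of size 128:

  %cst = arith.constant 7.0 : f32
    ==>  // splat inserted at the start of the innermost vectorized loop
  %cst_vec = arith.constant dense<7.0> : vector<128xf32>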
970
971/// Creates a constant vector filled with the neutral elements of the given
972/// reduction. The scalar type of vector elements will be taken from
973/// `oldOperand`.
974static arith::ConstantOp createInitialVector(arith::AtomicRMWKind reductionKind,
975 Value oldOperand,
976 VectorizationState &state) {
977 Type scalarTy = oldOperand.getType();
978 if (!VectorType::isValidElementType(scalarTy))
979 return nullptr;
980
981 Attribute valueAttr = getIdentityValueAttr(
982 reductionKind, scalarTy, state.builder, oldOperand.getLoc());
983 auto vecTy = getVectorType(scalarTy, state.strategy);
984 auto vecAttr = DenseElementsAttr::get(vecTy, valueAttr);
985 auto newConstOp =
986 state.builder.create<arith::ConstantOp>(oldOperand.getLoc(), vecAttr);
987
988 return newConstOp;
989}
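For example (a sketch assuming an f32 addf reduction and vector size 128), the neutral element is 0.0, so the vectorized accumulator starts as:

  %acc_init = arith.constant dense<0.0> : vector<128xf32>
  // A mulf reduction would use dense<1.0> instead.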
990
991/// Creates a mask used to filter out garbage elements in the last iteration
992/// of unaligned loops. If a mask is not required then `nullptr` is returned.
993/// The mask will be a vector of booleans representing meaningful vector
994/// elements in the current iteration. It is filled with ones for each iteration
995/// except for the last one, where it has the form `11...100...0` with the
996/// number of ones equal to the number of meaningful elements (i.e. the number
997/// of iterations that would be left in the original loop).
998static Value createMask(AffineForOp vecForOp, VectorizationState &state) {
999 assert(state.strategy->vectorSizes.size() == 1 &&
1000 "Creating a mask for non-1-D vectors is not supported.");
1001 assert(vecForOp.getStep() == state.strategy->vectorSizes[0] &&
1002 "Creating a mask for loops with non-unit original step size is not "
1003 "supported.");
1004
1005 // Check if we have already created the mask.
1006 if (Value mask = state.vecLoopToMask.lookup(vecForOp))
1007 return mask;
1008
1009 // If the loop has constant bounds and the original number of iterations is
1010 // divisible by the vector size then we don't need a mask.
1011 if (vecForOp.hasConstantBounds()) {
1012 int64_t originalTripCount =
1013 vecForOp.getConstantUpperBound() - vecForOp.getConstantLowerBound();
1014 if (originalTripCount % vecForOp.getStep() == 0)
1015 return nullptr;
1016 }
1017
1018 OpBuilder::InsertionGuard guard(state.builder);
1019 state.builder.setInsertionPointToStart(vecForOp.getBody());
1020
1021 // We generate the mask using the `vector.create_mask` operation which accepts
1022 // the number of meaningful elements (i.e. the length of the prefix of 1s).
1023 // To compute the number of meaningful elements we subtract the current value
1024 // of the iteration variable from the upper bound of the loop. Example:
1025 //
1026 // // 500 is the upper bound of the loop
1027 // #map = affine_map<(d0) -> (500 - d0)>
1028 // %elems_left = affine.apply #map(%iv)
1029 // %mask = vector.create_mask %elems_left : vector<128xi1>
1030
1031 Location loc = vecForOp.getLoc();
1032
1033 // First we get the upper bound of the loop using `affine.apply` or
1034 // `affine.min`.
1035 AffineMap ubMap = vecForOp.getUpperBoundMap();
1036 Value ub;
1037 if (ubMap.getNumResults() == 1)
1038 ub = state.builder.create<AffineApplyOp>(loc, vecForOp.getUpperBoundMap(),
1039 vecForOp.getUpperBoundOperands());
1040 else
1041 ub = state.builder.create<AffineMinOp>(loc, vecForOp.getUpperBoundMap(),
1042 vecForOp.getUpperBoundOperands());
1043 // Then we compute the number of (original) iterations left in the loop.
1044 AffineExpr subExpr =
1045 state.builder.getAffineDimExpr(0) - state.builder.getAffineDimExpr(1);
1046 Value itersLeft =
1047 makeComposedAffineApply(state.builder, loc, AffineMap::get(2, 0, subExpr),
1048 {ub, vecForOp.getInductionVar()});
1049 // If the affine maps were successfully composed then `ub` is unneeded.
1050 if (ub.use_empty())
1051 ub.getDefiningOp()->erase();
1052 // Finally we create the mask.
1053 Type maskTy = VectorType::get(state.strategy->vectorSizes,
1054 state.builder.getIntegerType(1));
1055 Value mask =
1056 state.builder.create<vector::CreateMaskOp>(loc, maskTy, itersLeft);
1057
1058 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a mask:\n"
1059 << itersLeft << "\n"
1060 << mask << "\n");
1061
1062 state.vecLoopToMask[vecForOp] = mask;
1063 return mask;
1064}
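To make the arithmetic concrete (hypothetical bounds): for a loop from 0 to 500 vectorized by 128, the last vector iteration starts at %iv = 384 and only 500 - 384 = 116 lanes carry meaningful data, so the generated IR is:

  %elems_left = affine.apply affine_map<(d0) -> (500 - d0)>(%iv)
  %mask = vector.create_mask %elems_left : vector<128xi1>
  // On the last iteration the mask is 116 ones followed by 12 zeros.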
1065
1066/// Returns true if the provided value is vector uniform given the vectorization
1067/// strategy.
1068// TODO: For now, only values that are induction variables of loops not in
1069// `loopToVectorDim` or invariants to all the loops in the vectorization
1070// strategy are considered vector uniforms.
1071static bool isUniformDefinition(Value value,
1072 const VectorizationStrategy *strategy) {
1073 AffineForOp forOp = getForInductionVarOwner(value);
1074 if (forOp && strategy->loopToVectorDim.count(forOp) == 0)
1075 return true;
1076
1077 for (auto loopToDim : strategy->loopToVectorDim) {
1078 auto loop = cast<AffineForOp>(loopToDim.first);
1079 if (!loop.isDefinedOutsideOfLoop(value))
1080 return false;
1081 }
1082 return true;
1083}
1084
1085/// Generates a broadcast op for the provided uniform value using the
1086/// vectorization strategy in 'state'.
1087static Operation *vectorizeUniform(Value uniformVal,
1088 VectorizationState &state) {
1089 OpBuilder::InsertionGuard guard(state.builder);
1090 Value uniformScalarRepl =
1091 state.valueScalarReplacement.lookupOrDefault(uniformVal);
1092 state.builder.setInsertionPointAfterValue(uniformScalarRepl);
1093
1094 auto vectorTy = getVectorType(uniformVal.getType(), state.strategy);
1095 auto bcastOp = state.builder.create<BroadcastOp>(uniformVal.getLoc(),
1096 vectorTy, uniformScalarRepl);
1097 state.registerValueVectorReplacement(uniformVal, bcastOp);
1098 return bcastOp;
1099}
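A sketch with hypothetical names: a scalar %x defined outside every vectorized loop is splatted right after its scalar replacement:

  %x_vec = vector.broadcast %x : f32 to vector<128xf32>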
1100
1101/// Tries to vectorize a given `operand` by applying the following logic:
1102/// 1. if the defining operation has been already vectorized, `operand` is
1103/// already in the proper vector form;
1104/// 2. if the `operand` is a constant, returns the vectorized form of the
1105/// constant;
1106/// 3. if the `operand` is uniform, returns a vector broadcast of the `operand`;
1107/// 4. otherwise, the vectorization of `operand` is not supported.
1108/// Newly created vector operations are registered in `state` as replacement
1109/// for their scalar counterparts.
1110/// In particular this logic captures some of the use cases where definitions
1111/// that are not scoped under the current pattern are needed to vectorize.
1112/// One such example is top level function constants that need to be splatted.
1113///
1114/// Returns an operand that has been vectorized to match `state`'s strategy if
1115/// vectorization is possible with the above logic. Returns nullptr otherwise.
1116///
1117/// TODO: handle more complex cases.
1118static Value vectorizeOperand(Value operand, VectorizationState &state) {
1119 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorize operand: " << operand);
15: Assuming 'DebugFlag' is false
16: Loop condition is false. Exiting loop
1120 // If this value is already vectorized, we are done.
1121 if (Value vecRepl = state.valueVectorReplacement.lookupOrNull(operand)) {
1122 LLVM_DEBUG(dbgs() << " -> already vectorized: " << vecRepl);
1123 return vecRepl;
1124 }
1125
1126 // A vector operand that is not in the replacement map should never reach
1127 // this point. Reaching this point could mean that the code was already
1128 // vectorized and we shouldn't try to vectorize already vectorized code.
1129 assert(!operand.getType().isa<VectorType>() &&
1130 "Vector op not found in replacement map");
17: Taking false branch
18: '?' condition is true
1131
1132 // Vectorize constant.
1133 if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
19: Taking true branch
1134 auto vecConstant = vectorizeConstant(constOp, state);
1135 LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
20: Assuming 'DebugFlag' is true
21: Assuming the condition is true
22: Taking true branch
23: Null pointer value stored to 'op.state'
24: Calling 'operator<<'
1136 return vecConstant.getResult();
1137 }
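This is the path the analyzer flags: when the scalar type is not a valid vector element type, vectorizeConstant returns a null arith::ConstantOp (lines 948-949 above), and streaming that null op into dbgs() at line 1135 stores a null pointer into 'op.state' (event 23) and then dereferences it in operator<< (event 24, in OpDefinition.h below). A minimal guard, offered here as a sketch rather than as the upstream fix, would bail out before printing or taking the result of the null op:

  if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
    auto vecConstant = vectorizeConstant(constOp, state);
    if (!vecConstant)
      return nullptr; // don't print or call getResult() on a null op
    LLVM_DEBUG(dbgs() << "-> constant: " << vecConstant);
    return vecConstant.getResult();
  }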
1138
1139 // Vectorize uniform values.
1140 if (isUniformDefinition(operand, state.strategy)) {
1141 Operation *vecUniform = vectorizeUniform(operand, state);
1142 LLVM_DEBUG(dbgs() << "-> uniform: " << *vecUniform);
1143 return vecUniform->getResult(0);
1144 }
1145
1146 // Check for unsupported block argument scenarios. A supported block argument
1147 // should have been vectorized already.
1148 if (!operand.getDefiningOp())
1149 LLVM_DEBUG(dbgs() << "-> unsupported block argument\n");
1150 else
1151 // Generic unsupported case.
1152 LLVM_DEBUG(dbgs() << "-> non-vectorizable\n");
1153
1154 return nullptr;
1155}
1156
1157/// Vectorizes an affine load with the vectorization strategy in 'state' by
1158/// generating a 'vector.transfer_read' op with the proper permutation map
1159/// inferred from the indices of the load. The new 'vector.transfer_read' is
1160/// registered as replacement of the scalar load. Returns the newly created
1161/// 'vector.transfer_read' if vectorization was successful. Returns nullptr,
1162/// otherwise.
1163static Operation *vectorizeAffineLoad(AffineLoadOp loadOp,
1164 VectorizationState &state) {
1165 MemRefType memRefType = loadOp.getMemRefType();
1166 Type elementType = memRefType.getElementType();
1167 auto vectorType = VectorType::get(state.strategy->vectorSizes, elementType);
1168
1169 // Replace map operands with operands from the vector loop nest.
1170 SmallVector<Value, 8> mapOperands;
1171 state.getScalarValueReplacementsFor(loadOp.getMapOperands(), mapOperands);
1172
1173 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1174 SmallVector<Value, 8> indices;
1175 indices.reserve(memRefType.getRank());
1176 if (loadOp.getAffineMap() !=
1177 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1178 computeMemoryOpIndices(loadOp, loadOp.getAffineMap(), mapOperands, state,
1179 indices);
1180 else
1181 indices.append(mapOperands.begin(), mapOperands.end());
1182
1183 // Compute permutation map using the information of new vector loops.
1184 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1185 indices, state.vecLoopToVecDim);
1186 if (!permutationMap) {
1187 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ can't compute permutationMap\n");
1188 return nullptr;
1189 }
1190 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
1191 LLVM_DEBUG(permutationMap.print(dbgs()));
1192
1193 auto transfer = state.builder.create<vector::TransferReadOp>(
1194 loadOp.getLoc(), vectorType, loadOp.getMemRef(), indices, permutationMap);
1195
1196 // Register replacement for future uses in the scope.
1197 state.registerOpVectorReplacement(loadOp, transfer);
1198 return transfer;
1199}
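A rewrite sketch for a 1-D strategy of size 128 (hypothetical shapes; %pad stands for the padding value the transfer op requires):

  %0 = affine.load %A[%i] : memref<500xf32>
    ==>
  %0 = vector.transfer_read %A[%i], %pad
         {permutation_map = affine_map<(d0) -> (d0)>}
         : memref<500xf32>, vector<128xf32>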
1200
1201/// Vectorizes an affine store with the vectorization strategy in 'state' by
1202/// generating a 'vector.transfer_write' op with the proper permutation map
1203/// inferred from the indices of the store. The new 'vector.transfer_write' is
1204/// registered as replacement of the scalar store. Returns the newly created
1205/// 'vector.transfer_write' if vectorization was successful. Returns nullptr,
1206/// otherwise.
1207static Operation *vectorizeAffineStore(AffineStoreOp storeOp,
1208 VectorizationState &state) {
1209 MemRefType memRefType = storeOp.getMemRefType();
1210 Value vectorValue = vectorizeOperand(storeOp.getValueToStore(), state);
1211 if (!vectorValue)
1212 return nullptr;
1213
1214 // Replace map operands with operands from the vector loop nest.
1215 SmallVector<Value, 8> mapOperands;
1216 state.getScalarValueReplacementsFor(storeOp.getMapOperands(), mapOperands);
1217
1218 // Compute indices for the transfer op. AffineApplyOp's may be generated.
1219 SmallVector<Value, 8> indices;
1220 indices.reserve(memRefType.getRank());
1221 if (storeOp.getAffineMap() !=
1222 state.builder.getMultiDimIdentityMap(memRefType.getRank()))
1223 computeMemoryOpIndices(storeOp, storeOp.getAffineMap(), mapOperands, state,
1224 indices);
1225 else
1226 indices.append(mapOperands.begin(), mapOperands.end());
1227
1228 // Compute permutation map using the information of new vector loops.
1229 auto permutationMap = makePermutationMap(state.builder.getInsertionBlock(),
1230 indices, state.vecLoopToVecDim);
1231 if (!permutationMap)
1232 return nullptr;
1233 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ permutationMap: ");
1234 LLVM_DEBUG(permutationMap.print(dbgs()));
1235
1236 auto transfer = state.builder.create<vector::TransferWriteOp>(
1237 storeOp.getLoc(), vectorValue, storeOp.getMemRef(), indices,
1238 permutationMap);
1239 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorized store: " << transfer);
1240
1241 // Register replacement for future uses in the scope.
1242 state.registerOpVectorReplacement(storeOp, transfer);
1243 return transfer;
1244}
1245
1246/// Returns true if `value` is a constant equal to the neutral element of the
1247/// given vectorizable reduction.
1248static bool isNeutralElementConst(arith::AtomicRMWKind reductionKind,
1249 Value value, VectorizationState &state) {
1250 Type scalarTy = value.getType();
1251 if (!VectorType::isValidElementType(scalarTy))
1252 return false;
1253 Attribute valueAttr = getIdentityValueAttr(reductionKind, scalarTy,
1254 state.builder, value.getLoc());
1255 if (auto constOp = dyn_cast_or_null<arith::ConstantOp>(value.getDefiningOp()))
1256 return constOp.getValue() == valueAttr;
1257 return false;
1258}
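For example (hypothetical IR), with an addf reduction the check succeeds for

  %init = arith.constant 0.0 : f32

so the later combine of the reduced vector with the original initial value can be skipped.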
1259
1260/// Vectorizes a loop with the vectorization strategy in 'state'. A new loop is
1261/// created and registered as replacement for the scalar loop. The builder's
1262/// insertion point is set to the new loop's body so that subsequent vectorized
1263/// operations are inserted into the new loop. If the loop is a vector
1264/// dimension, the step of the newly created loop will reflect the vectorization
1265/// factor used to vectorize that dimension.
1266static Operation *vectorizeAffineForOp(AffineForOp forOp,
1267 VectorizationState &state) {
1268 const VectorizationStrategy &strategy = *state.strategy;
1269 auto loopToVecDimIt = strategy.loopToVectorDim.find(forOp);
1270 bool isLoopVecDim = loopToVecDimIt != strategy.loopToVectorDim.end();
1271
1272 // TODO: Vectorization of reduction loops is not supported for non-unit steps.
1273 if (isLoopVecDim && forOp.getNumIterOperands() > 0 && forOp.getStep() != 1) {
11.1: 'isLoopVecDim' is false
1274 LLVM_DEBUG(
1275 dbgs()
1276 << "\n[early-vect]+++++ unsupported step size for reduction loop: "
1277 << forOp.getStep() << "\n");
1278 return nullptr;
1279 }
1280
1281 // If we are vectorizing a vector dimension, compute a new step for the new
1282 // vectorized loop using the vectorization factor for the vector dimension.
1283 // Otherwise, propagate the step of the scalar loop.
1284 unsigned newStep;
1285 if (isLoopVecDim) {
11.2: 'isLoopVecDim' is false
12: Taking false branch
1286 unsigned vectorDim = loopToVecDimIt->second;
1287 assert(vectorDim < strategy.vectorSizes.size() && "vector dim overflow");
1288 int64_t forOpVecFactor = strategy.vectorSizes[vectorDim];
1289 newStep = forOp.getStep() * forOpVecFactor;
1290 } else {
1291 newStep = forOp.getStep();
1292 }
1293
1294 // Get information about reduction kinds.
1295 ArrayRef<LoopReduction> reductions;
1296 if (isLoopVecDim && forOp.getNumIterOperands() > 0) {
12.1: 'isLoopVecDim' is false
1297 auto it = strategy.reductionLoops.find(forOp);
1298 assert(it != strategy.reductionLoops.end() &&
1299 "Reduction descriptors not found when vectorizing a reduction loop");
1300 reductions = it->second;
1301 assert(reductions.size() == forOp.getNumIterOperands() &&
1302 "The size of reductions array must match the number of iter_args");
1303 }
1304
1305 // Vectorize 'iter_args'.
1306 SmallVector<Value, 8> vecIterOperands;
1307 if (!isLoopVecDim) {
12.2: 'isLoopVecDim' is false
13: Taking true branch
1308 for (auto operand : forOp.getIterOperands())
1309 vecIterOperands.push_back(vectorizeOperand(operand, state));
14: Calling 'vectorizeOperand'
1310 } else {
1311 // For reduction loops we need to pass a vector of neutral elements as an
1312 // initial value of the accumulator. We will add the original initial value
1313 // later.
1314 for (auto redAndOperand : llvm::zip(reductions, forOp.getIterOperands())) {
1315 vecIterOperands.push_back(createInitialVector(
1316 std::get<0>(redAndOperand).kind, std::get<1>(redAndOperand), state));
1317 }
1318 }
1319
1320 auto vecForOp = state.builder.create<AffineForOp>(
1321 forOp.getLoc(), forOp.getLowerBoundOperands(), forOp.getLowerBoundMap(),
1322 forOp.getUpperBoundOperands(), forOp.getUpperBoundMap(), newStep,
1323 vecIterOperands,
1324 /*bodyBuilder=*/[](OpBuilder &, Location, Value, ValueRange) {
1325 // Make sure we don't create a default terminator in the loop body as
1326 // the proper terminator will be added during vectorization.
1327 });
1328
1329 // Register loop-related replacements:
1330 // 1) The new vectorized loop is registered as vector replacement of the
1331 // scalar loop.
1332 // 2) The new iv of the vectorized loop is registered as scalar replacement
1333 // since a scalar copy of the iv will prevail in the vectorized loop.
1334 // TODO: A vector replacement will also be added in the future when
1335 // vectorization of linear ops is supported.
1336 // 3) The new 'iter_args' region arguments are registered as vector
1337 // replacements since they have been vectorized.
1338 // 4) If the loop performs a reduction along the vector dimension, a
1339 // `vector.reduction` or similar op is inserted for each resulting value
1340 // of the loop and its scalar value replaces the corresponding scalar
1341 // result of the loop.
1342 state.registerOpVectorReplacement(forOp, vecForOp);
1343 state.registerValueScalarReplacement(forOp.getInductionVar(),
1344 vecForOp.getInductionVar());
1345 for (auto iterTuple :
1346 llvm::zip(forOp.getRegionIterArgs(), vecForOp.getRegionIterArgs()))
1347 state.registerBlockArgVectorReplacement(std::get<0>(iterTuple),
1348 std::get<1>(iterTuple));
1349
1350 if (isLoopVecDim) {
1351 for (unsigned i = 0; i < vecForOp.getNumIterOperands(); ++i) {
1352 // First, we reduce the vector returned from the loop into a scalar.
1353 Value reducedRes =
1354 getVectorReductionOp(reductions[i].kind, state.builder,
1355 vecForOp.getLoc(), vecForOp.getResult(i));
1356 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ creating a vector reduction: "
1357 << reducedRes);
1358 // Then we combine it with the original (scalar) initial value unless it
1359 // is equal to the neutral element of the reduction.
1360 Value origInit = forOp.getOperand(forOp.getNumControlOperands() + i);
1361 Value finalRes = reducedRes;
1362 if (!isNeutralElementConst(reductions[i].kind, origInit, state))
1363 finalRes =
1364 arith::getReductionOp(reductions[i].kind, state.builder,
1365 reducedRes.getLoc(), reducedRes, origInit);
1366 state.registerLoopResultScalarReplacement(forOp.getResult(i), finalRes);
1367 }
1368 }
1369
1370 if (isLoopVecDim)
1371 state.vecLoopToVecDim[vecForOp] = loopToVecDimIt->second;
1372
1373 // Change insertion point so that upcoming vectorized instructions are
1374 // inserted into the vectorized loop's body.
1375 state.builder.setInsertionPointToStart(vecForOp.getBody());
1376
1377 // If this is a reduction loop then we may need to create a mask to filter out
1378 // garbage in the last iteration.
1379 if (isLoopVecDim && forOp.getNumIterOperands() > 0)
1380 createMask(vecForOp, state);
1381
1382 return vecForOp;
1383}
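A condensed before/after sketch for a loop that is a vector dimension, with a hypothetical factor of 128:

  affine.for %i = 0 to 500 { ... }           // scalar loop, step 1
    ==>
  affine.for %i = 0 to 500 step 128 { ... }  // vector loop; its body is vectorized next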
1384
1385/// Vectorizes an arbitrary operation by plain widening. We apply generic type
1386/// widening of all its results and retrieve the vector counterparts for all its
1387/// operands.
1388static Operation *widenOp(Operation *op, VectorizationState &state) {
1389 SmallVector<Type, 8> vectorTypes;
1390 for (Value result : op->getResults())
1391 vectorTypes.push_back(
1392 VectorType::get(state.strategy->vectorSizes, result.getType()));
1393
1394 SmallVector<Value, 8> vectorOperands;
1395 for (Value operand : op->getOperands()) {
1396 Value vecOperand = vectorizeOperand(operand, state);
1397 if (!vecOperand) {
1398 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ an operand failed vectorize\n");
1399 return nullptr;
1400 }
1401 vectorOperands.push_back(vecOperand);
1402 }
1403
1404 // Create a clone of the op with the proper operands and return types.
1405 // TODO: The following assumes there is always an op with a fixed
1406 // name that works both in scalar mode and vector mode.
1407 // TODO: Is it worth considering an Operation.clone operation which
1408 // changes the type so we can promote an Operation with less boilerplate?
1409 OperationState vecOpState(op->getLoc(), op->getName(), vectorOperands,
1410 vectorTypes, op->getAttrs(), /*successors=*/{},
1411 /*regions=*/{});
1412 Operation *vecOp = state.builder.createOperation(vecOpState);
1413 state.registerOpVectorReplacement(op, vecOp);
1414 return vecOp;
1415}
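A widening sketch for an elementwise op, assuming its operands were already vectorized to 128 lanes:

  %r = arith.addf %a, %b : f32
    ==>
  %r_vec = arith.addf %a_vec, %b_vec : vector<128xf32>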
1416
1417/// Vectorizes a yield operation by widening its types. The builder's insertion
1418/// point is set after the vectorized parent op to continue vectorizing the
1419/// operations after the parent op. When vectorizing a reduction loop a mask may
1420/// be used to prevent adding garbage values to the accumulator.
1421static Operation *vectorizeAffineYieldOp(AffineYieldOp yieldOp,
1422 VectorizationState &state) {
1423 Operation *newYieldOp = widenOp(yieldOp, state);
1424 Operation *newParentOp = state.builder.getInsertionBlock()->getParentOp();
1425
1426 // If there is a mask for this loop then we must prevent garbage values from
1427 // being added to the accumulator by inserting `select` operations, for
1428 // example:
1429 //
1430 // %res = arith.addf %acc, %val : vector<128xf32>
1431 // %res_masked = select %mask, %res, %acc : vector<128xi1>, vector<128xf32>
1432 // affine.yield %res_masked : vector<128xf32>
1433 //
1434 if (Value mask = state.vecLoopToMask.lookup(newParentOp)) {
1435 state.builder.setInsertionPoint(newYieldOp);
1436 for (unsigned i = 0; i < newYieldOp->getNumOperands(); ++i) {
1437 Value result = newYieldOp->getOperand(i);
1438 Value iterArg = cast<AffineForOp>(newParentOp).getRegionIterArgs()[i];
1439 Value maskedResult = state.builder.create<SelectOp>(result.getLoc(), mask,
1440 result, iterArg);
1441 LLVM_DEBUG(
1442 dbgs() << "\n[early-vect]+++++ masking a yielded vector value: "
1443 << maskedResult);
1444 newYieldOp->setOperand(i, maskedResult);
1445 }
1446 }
1447
1448 state.builder.setInsertionPointAfter(newParentOp);
1449 return newYieldOp;
1450}
1451
1452/// Encodes Operation-specific behavior for vectorization. In general we
1453/// assume that all operands of an op must be vectorized but this is not
1454/// always true. In the future, it would be nice to have a trait that
1455/// describes how a particular operation vectorizes. For now we implement the
1456/// case distinction here. Returns a vectorized form of an operation or
1457/// nullptr if vectorization fails.
1458// TODO: consider adding a trait to Op to describe how it gets vectorized.
1459// Maybe some Ops are not vectorizable or require some tricky logic, we cannot
1460// do one-off logic here; ideally it would be TableGen'd.
1461static Operation *vectorizeOneOperation(Operation *op,
1462 VectorizationState &state) {
1463 // Sanity checks.
1464 assert(!isa<vector::TransferReadOp>(op) &&
1465 "vector.transfer_read cannot be further vectorized");
4: Assuming 'op' is not a 'TransferReadOp'
5: '?' condition is true
1466 assert(!isa<vector::TransferWriteOp>(op) &&
1467 "vector.transfer_write cannot be further vectorized");
6: Assuming 'op' is not a 'TransferWriteOp'
7: '?' condition is true
1468
1469 if (auto loadOp = dyn_cast<AffineLoadOp>(op))
8: Taking false branch
1470 return vectorizeAffineLoad(loadOp, state);
1471 if (auto storeOp = dyn_cast<AffineStoreOp>(op))
9: Taking false branch
1472 return vectorizeAffineStore(storeOp, state);
1473 if (auto forOp = dyn_cast<AffineForOp>(op))
10: Taking true branch
1474 return vectorizeAffineForOp(forOp, state);
11: Calling 'vectorizeAffineForOp'
1475 if (auto yieldOp = dyn_cast<AffineYieldOp>(op))
1476 return vectorizeAffineYieldOp(yieldOp, state);
1477 if (auto constant = dyn_cast<arith::ConstantOp>(op))
1478 return vectorizeConstant(constant, state);
1479
1480 // Other ops with regions are not supported.
1481 if (op->getNumRegions() != 0)
1482 return nullptr;
1483
1484 return widenOp(op, state);
1485}
1486
1487/// Recursive implementation to convert all the nested loops in 'match' to a 2D
1488/// vector container that preserves the relative nesting level of each loop with
1489/// respect to the others in 'match'. 'currentLevel' is the nesting level that
1490/// will be assigned to the loop in the current 'match'.
1491static void
1492getMatchedAffineLoopsRec(NestedMatch match, unsigned currentLevel,
1493 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1494 // Add a new empty level to the output if it doesn't exist already.
1495 assert(currentLevel <= loops.size() && "Unexpected currentLevel");
1496 if (currentLevel == loops.size())
1497 loops.emplace_back();
1498
1499 // Add current match and recursively visit its children.
1500 loops[currentLevel].push_back(cast<AffineForOp>(match.getMatchedOperation()));
1501 for (auto childMatch : match.getMatchedChildren()) {
1502 getMatchedAffineLoopsRec(childMatch, currentLevel + 1, loops);
1503 }
1504}
1505
1506/// Converts all the nested loops in 'match' to a 2D vector container that
1507/// preserves the relative nesting level of each loop with respect to the others
1508/// in 'match'. This means that every loop in 'loops[i]' will have a parent loop
1509/// in 'loops[i-1]'. A loop in 'loops[i]' may or may not have a child loop in
1510/// 'loops[i+1]'.
1511static void
1512getMatchedAffineLoops(NestedMatch match,
1513 std::vector<SmallVector<AffineForOp, 2>> &loops) {
1514 getMatchedAffineLoopsRec(match, /*currentLevel=*/0, loops);
1515}
1516
1517/// Internal implementation to vectorize affine loops from a single loop nest
1518/// using an n-D vectorization strategy.
1519static LogicalResult
1520vectorizeLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1521 const VectorizationStrategy &strategy) {
1522 assert(loops[0].size() == 1 && "Expected single root loop");
1523 AffineForOp rootLoop = loops[0][0];
1524 VectorizationState state(rootLoop.getContext());
1525 state.builder.setInsertionPointAfter(rootLoop);
1526 state.strategy = &strategy;
1527
1528 // Since patterns are recursive, they can very well intersect.
1529 // Since we do not want a fully greedy strategy in general, we decouple
1530 // pattern matching from profitability analysis and from application.
1531 // As a consequence we must check that each root pattern is still
1532 // vectorizable. If a pattern is not vectorizable anymore, we just skip it.
1533 // TODO: implement a non-greedy profitability analysis that keeps only
1534 // non-intersecting patterns.
1535 if (!isVectorizableLoopBody(rootLoop, vectorTransferPattern())) {
1536 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ loop is not vectorizable");
1537 return failure();
1538 }
1539
1540 //////////////////////////////////////////////////////////////////////////////
1541 // Vectorize the scalar loop nest following a topological order. A new vector
1542 // loop nest with the vectorized operations is created along the process. If
1543 // vectorization succeeds, the scalar loop nest is erased. If vectorization
1544 // fails, the vector loop nest is erased and the scalar loop nest is not
1545 // modified.
1546 //////////////////////////////////////////////////////////////////////////////
1547
1548 auto opVecResult = rootLoop.walk<WalkOrder::PreOrder>([&](Operation *op) {
1549 LLVM_DEBUG(dbgs() << "[early-vect]+++++ Vectorizing: " << *op);
1: Assuming 'DebugFlag' is false
2: Loop condition is false. Exiting loop
1550 Operation *vectorOp = vectorizeOneOperation(op, state);
3: Calling 'vectorizeOneOperation'
1551 if (!vectorOp) {
1552 LLVM_DEBUG(
1553 dbgs() << "[early-vect]+++++ failed vectorizing the operation: "
1554 << *op << "\n");
1555 return WalkResult::interrupt();
1556 }
1557
1558 return WalkResult::advance();
1559 });
1560
1561 if (opVecResult.wasInterrupted()) {
1562 LLVM_DEBUG(dbgs() << "[early-vect]+++++ failed vectorization for: "
1563 << rootLoop << "\n");
1564 // Erase vector loop nest if it was created.
1565 auto vecRootLoopIt = state.opVectorReplacement.find(rootLoop);
1566 if (vecRootLoopIt != state.opVectorReplacement.end())
1567 eraseLoopNest(cast<AffineForOp>(vecRootLoopIt->second));
1568
1569 return failure();
1570 }
1571
1572 // Replace results of reduction loops with the scalar values computed using
1573 // `vector.reduce` or similar ops.
1574 for (auto resPair : state.loopResultScalarReplacement)
1575 resPair.first.replaceAllUsesWith(resPair.second);
1576
1577 assert(state.opVectorReplacement.count(rootLoop) == 1 &&
1578 "Expected vector replacement for loop nest");
1579 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ success vectorizing pattern");
1580 LLVM_DEBUG(dbgs() << "\n[early-vect]+++++ vectorization result:\n"
1581 << *state.opVectorReplacement[rootLoop]);
1582
1583 // Finish this vectorization pattern.
1584 state.finishVectorizationPattern(rootLoop);
1585 return success();
1586}
1587
1588/// Extracts the matched loops and vectorizes them following a topological
1589/// order. A new vector loop nest will be created if vectorization succeeds. The
1590/// original loop nest won't be modified in any case.
1591static LogicalResult vectorizeRootMatch(NestedMatch m,
1592 const VectorizationStrategy &strategy) {
1593 std::vector<SmallVector<AffineForOp, 2>> loopsToVectorize;
1594 getMatchedAffineLoops(m, loopsToVectorize);
1595 return vectorizeLoopNest(loopsToVectorize, strategy);
1596}
1597
1598/// Traverses all the loop matches and classifies them into intersection
1599/// buckets. Two matches intersect if any of them encloses the other one. A
1600/// match intersects with a bucket if the match intersects with the root
1601/// (outermost) loop in that bucket.
1602static void computeIntersectionBuckets(
1603 ArrayRef<NestedMatch> matches,
1604 std::vector<SmallVector<NestedMatch, 8>> &intersectionBuckets) {
1605 assert(intersectionBuckets.empty() && "Expected empty output");
;
1606 // Keeps track of the root (outermost) loop of each bucket.
1607 SmallVector<AffineForOp, 8> bucketRoots;
1608
1609 for (const NestedMatch &match : matches) {
1610 AffineForOp matchRoot = cast<AffineForOp>(match.getMatchedOperation());
1611 bool intersects = false;
1612 for (int i = 0, end = intersectionBuckets.size(); i < end; ++i) {
1613 AffineForOp bucketRoot = bucketRoots[i];
1614 // Add match to the bucket if the bucket root encloses the match root.
1615 if (bucketRoot->isAncestor(matchRoot)) {
1616 intersectionBuckets[i].push_back(match);
1617 intersects = true;
1618 break;
1619 }
1620 // Add match to the bucket if the match root encloses the bucket root. The
1621 // match root becomes the new bucket root.
1622 if (matchRoot->isAncestor(bucketRoot)) {
1623 bucketRoots[i] = matchRoot;
1624 intersectionBuckets[i].push_back(match);
1625 intersects = true;
1626 break;
1627 }
1628 }
1629
1630 // Match doesn't intersect with any existing bucket. Create a new bucket for
1631 // it.
1632 if (!intersects) {
1633 bucketRoots.push_back(matchRoot);
1634 intersectionBuckets.emplace_back();
1635 intersectionBuckets.back().push_back(match);
1636 }
1637 }
1638}
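For instance (hypothetical matches): with matches rooted at %i0, %i1 and %i2, where %i0 encloses %i1 and %i2 is a disjoint sibling nest, the result is two buckets, {%i0, %i1} with root %i0 and {%i2} with root %i2.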
1639
1640/// Internal implementation to vectorize affine loops in 'loops' using the n-D
1641/// vectorization factors in 'vectorSizes'. By default, each vectorization
1642/// factor is applied inner-to-outer to the loops of each loop nest.
1643/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1644/// vectorization order. `reductionLoops` can be provided to specify loops which
1645/// can be vectorized along the reduction dimension.
1646static void vectorizeLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1647 ArrayRef<int64_t> vectorSizes,
1648 ArrayRef<int64_t> fastestVaryingPattern,
1649 const ReductionLoopMap &reductionLoops) {
1650 assert((reductionLoops.empty() || vectorSizes.size() == 1) &&
1651 "Vectorizing reductions is supported only for 1-D vectors");
1652
1653 // Compute 1-D, 2-D or 3-D loop pattern to be matched on the target loops.
1654 Optional<NestedPattern> pattern =
1655 makePattern(loops, vectorSizes.size(), fastestVaryingPattern);
1656 if (!pattern.hasValue()) {
1657 LLVM_DEBUG(dbgs() << "\n[early-vect] pattern couldn't be computed\n");
1658 return;
1659 }
1660
1661 LLVM_DEBUG(dbgs() << "\n******************************************");
1662 LLVM_DEBUG(dbgs() << "\n******************************************");
1663 LLVM_DEBUG(dbgs() << "\n[early-vect] new pattern on parent op\n");
1664 LLVM_DEBUG(dbgs() << *parentOp << "\n");
1665
1666 unsigned patternDepth = pattern->getDepth();
1667
1668 // Compute all the pattern matches and classify them into buckets of
1669 // intersecting matches.
1670 SmallVector<NestedMatch, 32> allMatches;
1671 pattern->match(parentOp, &allMatches);
1672 std::vector<SmallVector<NestedMatch, 8>> intersectionBuckets;
1673 computeIntersectionBuckets(allMatches, intersectionBuckets);
1674
1675 // Iterate over all buckets and vectorize the matches eagerly. We can only
1676 // vectorize one match from each bucket since all the matches within a bucket
1677 // intersect.
1678 for (auto &intersectingMatches : intersectionBuckets) {
1679 for (NestedMatch &match : intersectingMatches) {
1680 VectorizationStrategy strategy;
1681 // TODO: depending on profitability, elect to reduce the vector size.
1682 strategy.vectorSizes.assign(vectorSizes.begin(), vectorSizes.end());
1683 strategy.reductionLoops = reductionLoops;
1684 if (failed(analyzeProfitability(match.getMatchedChildren(), 1,
1685 patternDepth, &strategy))) {
1686 continue;
1687 }
1688 vectorizeLoopIfProfitable(match.getMatchedOperation(), 0, patternDepth,
1689 &strategy);
1690 // Vectorize match. Skip the rest of intersecting matches in the bucket if
1691 // vectorization succeeded.
1692 // TODO: if pattern does not apply, report it; alter the cost/benefit.
1693 // TODO: some diagnostics if failure to vectorize occurs.
1694 if (succeeded(vectorizeRootMatch(match, strategy)))
1695 break;
1696 }
1697 }
1698
1699 LLVM_DEBUG(dbgs() << "\n");
1700}
1701
1702std::unique_ptr<OperationPass<FuncOp>>
1703createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
1704 return std::make_unique<Vectorize>(virtualVectorSize);
1705}
1706std::unique_ptr<OperationPass<FuncOp>> createSuperVectorizePass() {
1707 return std::make_unique<Vectorize>();
1708}
1709
1710/// Applies vectorization to the current function by searching over a set of
1711/// predetermined patterns.
1712void Vectorize::runOnOperation() {
1713 FuncOp f = getOperation();
1714 if (!fastestVaryingPattern.empty() &&
1715 fastestVaryingPattern.size() != vectorSizes.size()) {
1716 f.emitRemark("Fastest varying pattern specified with different size than "
1717 "the vector size.");
1718 return signalPassFailure();
1719 }
1720
1721 if (vectorizeReductions && vectorSizes.size() != 1) {
1722 f.emitError("Vectorizing reductions is supported only for 1-D vectors.");
1723 return signalPassFailure();
1724 }
1725
1726 DenseSet<Operation *> parallelLoops;
1727 ReductionLoopMap reductionLoops;
1728
1729 // If 'vectorize-reduction=true' is provided, we also populate the
1730 // `reductionLoops` map.
1731 if (vectorizeReductions) {
1732 f.walk([&parallelLoops, &reductionLoops](AffineForOp loop) {
1733 SmallVector<LoopReduction, 2> reductions;
1734 if (isLoopParallel(loop, &reductions)) {
1735 parallelLoops.insert(loop);
1736 // If it's not a reduction loop, adding it to the map is not necessary.
1737 if (!reductions.empty())
1738 reductionLoops[loop] = reductions;
1739 }
1740 });
1741 } else {
1742 f.walk([&parallelLoops](AffineForOp loop) {
1743 if (isLoopParallel(loop))
1744 parallelLoops.insert(loop);
1745 });
1746 }
1747
1748 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1749 NestedPatternContext mlContext;
1750 vectorizeLoops(f, parallelLoops, vectorSizes, fastestVaryingPattern,
1751 reductionLoops);
1752}
1753
1754/// Verify that affine loops in 'loops' meet the nesting criteria expected by
1755/// SuperVectorizer:
1756/// * There must be at least one loop.
1757/// * There must be a single root loop (nesting level 0).
1758/// * Each loop at a given nesting level must be nested in a loop from a
1759/// previous nesting level.
1760static LogicalResult
1761verifyLoopNesting(const std::vector<SmallVector<AffineForOp, 2>> &loops) {
1762 // Expected at least one loop.
1763 if (loops.empty())
1764 return failure();
1765
1766 // Expected only one root loop.
1767 if (loops[0].size() != 1)
1768 return failure();
1769
1770 // Traverse loops outer-to-inner to check some invariants.
1771 for (int i = 1, end = loops.size(); i < end; ++i) {
1772 for (AffineForOp loop : loops[i]) {
1773 // Check that each loop at this level is nested in one of the loops from
1774 // the previous level.
1775 if (none_of(loops[i - 1], [&](AffineForOp maybeParent) {
1776 return maybeParent->isProperAncestor(loop);
1777 }))
1778 return failure();
1779
1780 // Check that each loop at this level is not nested in another loop from
1781 // this level.
1782 for (AffineForOp sibling : loops[i]) {
1783 if (sibling->isProperAncestor(loop))
1784 return failure();
1785 }
1786 }
1787 }
1788
1789 return success();
1790}
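Applied to the vec2d example further below (see vectorizeAffineLoopNest): loops = {{%i0}, {%i2, %i3}} passes these checks, whereas loops = {{%i2, %i3}} fails because nesting level 0 has two root loops.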
1791
1792namespace mlir {
1793
1794/// External utility to vectorize affine loops in 'loops' using the n-D
1795/// vectorization factors in 'vectorSizes'. By default, each vectorization
1796/// factor is applied inner-to-outer to the loops of each loop nest.
1797/// 'fastestVaryingPattern' can be optionally used to provide a different loop
1798/// vectorization order.
1799/// If `reductionLoops` is not empty, the given reduction loops may be
1800/// vectorized along the reduction dimension.
1801/// TODO: Vectorizing reductions is supported only for 1-D vectorization.
1802void vectorizeAffineLoops(Operation *parentOp, DenseSet<Operation *> &loops,
1803 ArrayRef<int64_t> vectorSizes,
1804 ArrayRef<int64_t> fastestVaryingPattern,
1805 const ReductionLoopMap &reductionLoops) {
1806 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1807 NestedPatternContext mlContext;
1808 vectorizeLoops(parentOp, loops, vectorSizes, fastestVaryingPattern,
1809 reductionLoops);
1810}
1811
1812/// External utility to vectorize affine loops from a single loop nest using an
1813/// n-D vectorization strategy (see doc in VectorizationStrategy definition).
1814/// Loops are provided in a 2D vector container. The first dimension represents
1815/// the nesting level relative to the loops to be vectorized. The second
1816/// dimension contains the loops. This means that:
1817/// a) every loop in 'loops[i]' must have a parent loop in 'loops[i-1]',
1818/// b) a loop in 'loops[i]' may or may not have a child loop in 'loops[i+1]'.
1819///
1820/// For example, for the following loop nest:
1821///
1822/// func @vec2d(%in0: memref<64x128x512xf32>, %in1: memref<64x128x128xf32>,
1823/// %out0: memref<64x128x512xf32>,
1824/// %out1: memref<64x128x128xf32>) {
1825/// affine.for %i0 = 0 to 64 {
1826/// affine.for %i1 = 0 to 128 {
1827/// affine.for %i2 = 0 to 512 {
1828/// %ld = affine.load %in0[%i0, %i1, %i2] : memref<64x128x512xf32>
1829/// affine.store %ld, %out0[%i0, %i1, %i2] : memref<64x128x512xf32>
1830/// }
1831/// affine.for %i3 = 0 to 128 {
1832/// %ld = affine.load %in1[%i0, %i1, %i3] : memref<64x128x128xf32>
1833/// affine.store %ld, %out1[%i0, %i1, %i3] : memref<64x128x128xf32>
1834/// }
1835/// }
1836/// }
1837/// return
1838/// }
1839///
1840/// loops = {{%i0}, {%i2, %i3}}, to vectorize the outermost and the two
1841/// innermost loops;
1842/// loops = {{%i1}, {%i2, %i3}}, to vectorize the middle and the two innermost
1843/// loops;
1844/// loops = {{%i2}}, to vectorize only the first innermost loop;
1845/// loops = {{%i3}}, to vectorize only the second innermost loop;
1846/// loops = {{%i1}}, to vectorize only the middle loop.
1847LogicalResult
1848vectorizeAffineLoopNest(std::vector<SmallVector<AffineForOp, 2>> &loops,
1849 const VectorizationStrategy &strategy) {
1850 // Thread-safe RAII local context, BumpPtrAllocator freed on exit.
1851 NestedPatternContext mlContext;
1852 if (failed(verifyLoopNesting(loops)))
1853 return failure();
1854 return vectorizeLoopNest(loops, strategy);
1855}
1856
1857std::unique_ptr<OperationPass<FuncOp>>
1858createSuperVectorizePass(ArrayRef<int64_t> virtualVectorSize) {
1859 return std::make_unique<Vectorize>(virtualVectorSize);
1860}
1861std::unique_ptr<OperationPass<FuncOp>> createSuperVectorizePass() {
1862 return std::make_unique<Vectorize>();
1863}
1864
1865} // namespace mlir

/build/llvm-toolchain-snapshot-14~++20220119111520+da61cb019eb2/mlir/include/mlir/IR/OpDefinition.h

1//===- OpDefinition.h - Classes for defining concrete Op types --*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements helper classes for implementing the "Op" types. This
10// includes the Op type, which is the base class for Op class definitions,
11// as well as a number of traits in the OpTrait namespace that provide a
12// declarative way to specify properties of Ops.
13//
14// The purpose of these types is to allow light-weight implementation of
15// concrete ops (like DimOp) with very little boilerplate.
16//
17//===----------------------------------------------------------------------===//
18
19#ifndef MLIR_IR_OPDEFINITION_H
20#define MLIR_IR_OPDEFINITION_H
21
22#include "mlir/IR/Dialect.h"
23#include "mlir/IR/Operation.h"
24#include "llvm/Support/PointerLikeTypeTraits.h"
25
26#include <type_traits>
27
28namespace mlir {
29class Builder;
30class OpBuilder;
31
32/// This class represents success/failure for operation parsing. It is
33/// essentially a simple wrapper class around LogicalResult that allows for
34/// explicit conversion to bool. This allows for the parser to chain together
35/// parse rules without the clutter of "failed/succeeded".
36class ParseResult : public LogicalResult {
37public:
38 ParseResult(LogicalResult result = success()) : LogicalResult(result) {}
39
40 // Allow diagnostics emitted during parsing to be converted to failure.
41 ParseResult(const InFlightDiagnostic &) : LogicalResult(failure()) {}
42 ParseResult(const Diagnostic &) : LogicalResult(failure()) {}
43
44 /// Failure is true in a boolean context.
45 explicit operator bool() const { return failed(); }
46};
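
As a sketch of the chaining this enables, a hypothetical FooOp parser (OpAsmParser::OperandType is this snapshot's operand placeholder type):

    static ParseResult parse(OpAsmParser &parser, OperationState &result) {
      OpAsmParser::OperandType operand;
      Type type;
      // Each ParseResult converts to true on failure, so the chain
      // short-circuits at the first error.
      if (parser.parseOperand(operand) || parser.parseColonType(type) ||
          parser.resolveOperand(operand, type, result.operands))
        return failure();
      return success();
    }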
47/// This class implements `Optional` functionality for ParseResult. We don't
48/// directly use Optional here, because it provides an implicit conversion
49/// to 'bool' which we want to avoid. This class is used to implement tri-state
50/// 'parseOptional' functions that may have a failure mode when parsing that
51/// shouldn't be attributed to "not present".
52class OptionalParseResult {
53public:
54 OptionalParseResult() = default;
55 OptionalParseResult(LogicalResult result) : impl(result) {}
56 OptionalParseResult(ParseResult result) : impl(result) {}
57 OptionalParseResult(const InFlightDiagnostic &)
58 : OptionalParseResult(failure()) {}
59 OptionalParseResult(llvm::NoneType) : impl(llvm::None) {}
60
61 /// Returns true if we contain a valid ParseResult value.
62 bool hasValue() const { return impl.hasValue(); }
63
64 /// Access the internal ParseResult value.
65 ParseResult getValue() const { return impl.getValue(); }
66 ParseResult operator*() const { return getValue(); }
67
68private:
69 Optional<ParseResult> impl;
70};
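
A tri-state consumption sketch; `parseOptionalFoo` is a hypothetical stand-in for any 'parseOptional' style hook returning this type:

    OptionalParseResult parsed = parser.parseOptionalFoo(); // hypothetical
    if (parsed.hasValue()) {
      if (failed(*parsed))
        return failure(); // present but malformed: a real parse error
      // present and parsed successfully
    }
    // else: construct absent entirely, which is not an error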
71
72// These functions are out-of-line utilities, which avoids them being template
73// instantiated/duplicated.
74namespace impl {
75/// Insert an operation, generated by `buildTerminatorOp`, at the end of the
76/// region's only block if it does not have a terminator already. If the region
77/// is empty, insert a new block first. `buildTerminatorOp` should return the
78/// terminator operation to insert.
79void ensureRegionTerminator(
80 Region &region, OpBuilder &builder, Location loc,
81 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
82void ensureRegionTerminator(
83 Region &region, Builder &builder, Location loc,
84 function_ref<Operation *(OpBuilder &, Location)> buildTerminatorOp);
85
86} // namespace impl
87
88/// This is the concrete base class that holds the operation pointer and has
89/// non-generic methods that only depend on State (to avoid having them
90/// instantiated on template types that don't affect them).
91///
92/// This also has the fallback implementations of customization hooks for when
93/// they aren't customized.
94class OpState {
95public:
96 /// Ops are pointer-like, so we allow conversion to bool.
97 explicit operator bool() { return getOperation() != nullptr; }
98
99 /// This implicitly converts to Operation*.
100 operator Operation *() const { return state; }
101
102 /// Shortcut of `->` to access a member of Operation.
103 Operation *operator->() const { return state; }
104
105 /// Return the operation that this refers to.
106 Operation *getOperation() { return state; }
107
108 /// Return the context this operation belongs to.
109 MLIRContext *getContext() { return getOperation()->getContext(); }
110
111 /// Print the operation to the given stream.
112 void print(raw_ostream &os, OpPrintingFlags flags = llvm::None) {
113 state->print(os, flags);
[analyzer step 26] Called C++ object pointer is null
114 }
115 void print(raw_ostream &os, AsmState &asmState,
116 OpPrintingFlags flags = llvm::None) {
117 state->print(os, asmState, flags);
118 }
119
120 /// Dump this operation.
121 void dump() { state->dump(); }
122
123 /// The source location the operation was defined or derived from.
124 Location getLoc() { return state->getLoc(); }
125
126 /// Return true if there are no users of any results of this operation.
127 bool use_empty() { return state->use_empty(); }
128
129 /// Remove this operation from its parent block and delete it.
130 void erase() { state->erase(); }
131
132 /// Emit an error with the op name prefixed, like "'dim' op " which is
133 /// convenient for verifiers.
134 InFlightDiagnostic emitOpError(const Twine &message = {});
135
136 /// Emit an error about fatal conditions with this operation, reporting up to
137 /// any diagnostic handlers that may be listening.
138 InFlightDiagnostic emitError(const Twine &message = {});
139
140 /// Emit a warning about this operation, reporting up to any diagnostic
141 /// handlers that may be listening.
142 InFlightDiagnostic emitWarning(const Twine &message = {});
143
144 /// Emit a remark about this operation, reporting up to any diagnostic
145 /// handlers that may be listening.
146 InFlightDiagnostic emitRemark(const Twine &message = {});
147
148 /// Walk the operation by calling the callback for each nested operation
149 /// (including this one), block or region, depending on the callback provided.
150 /// Regions, blocks and operations at the same nesting level are visited in
151 /// lexicographical order. The walk order for enclosing regions, blocks and
152 /// operations with respect to their nested ones is specified by 'Order'
153 /// (post-order by default). A callback on a block or operation is allowed to
154 /// erase that block or operation if either:
155 /// * the walk is in post-order, or
156 /// * the walk is in pre-order and the walk is skipped after the erasure.
157 /// See Operation::walk for more details.
158 template <WalkOrder Order = WalkOrder::PostOrder, typename FnT,
159 typename RetT = detail::walkResultType<FnT>>
160 RetT walk(FnT &&callback) {
161 return state->walk<Order>(std::forward<FnT>(callback));
162 }
163
164 // These are default implementations of customization hooks.
165public:
166 /// This hook returns any canonicalization pattern rewrites that the operation
167 /// supports, for use by the canonicalization pass.
168 static void getCanonicalizationPatterns(RewritePatternSet &results,
169 MLIRContext *context) {}
170
171protected:
172 /// If the concrete type didn't implement a custom verifier hook, just fall
173 /// back to this one which accepts everything.
174 LogicalResult verify() { return success(); }
175
176 /// Parse the custom form of an operation. Unless overridden, this method will
177 /// first try to get an operation parser from the op's dialect. Otherwise the
178 /// custom assembly form of an op is always rejected. Op implementations
179 /// should implement this to return failure. On success, they should fill in
180 /// result with the fields to use.
181 static ParseResult parse(OpAsmParser &parser, OperationState &result);
182
183 /// Print the operation. Unless overridden, this method will first try to get
184 /// an operation printer from the dialect. Otherwise, it prints the operation
185 /// in generic form.
186 static void print(Operation *op, OpAsmPrinter &p, StringRef defaultDialect);
187
188 /// Print an operation name, eliding the dialect prefix if necessary.
189 static void printOpName(Operation *op, OpAsmPrinter &p,
190 StringRef defaultDialect);
191
192 /// Mutability management is handled by the OpWrapper/OpConstWrapper classes,
193 /// so we can cast it away here.
194 explicit OpState(Operation *state) : state(state) {}
195
196private:
197 Operation *state;
198
199 /// Allow access to internal hook implementation methods.
200 friend RegisteredOperationName;
201};
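
The warning flagged at line 113 follows directly from this layout: `state` is stored unchecked and print() dereferences it, so a null op, e.g. the result of a failed dyn_cast, must be tested before it is streamed. A minimal sketch (`someOp` is hypothetical):

    AffineForOp forOp = dyn_cast<AffineForOp>(someOp); // null-state op on failure
    if (!forOp)            // OpState's explicit bool catches the null state...
      return;
    llvm::errs() << forOp; // ...otherwise operator<< -> print() -> state->print()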
202
203// Allow comparing operators.
204inline bool operator==(OpState lhs, OpState rhs) {
205 return lhs.getOperation() == rhs.getOperation();
206}
207inline bool operator!=(OpState lhs, OpState rhs) {
208 return lhs.getOperation() != rhs.getOperation();
209}
210
211raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr);
212
213/// This class represents a single result from folding an operation.
214class OpFoldResult : public PointerUnion<Attribute, Value> {
215 using PointerUnion<Attribute, Value>::PointerUnion;
216
217public:
218 void dump() { llvm::errs() << *this << "\n"; }
219};
220
221/// Allow printing to a stream.
222inline raw_ostream &operator<<(raw_ostream &os, OpFoldResult ofr) {
223 if (Value value = ofr.dyn_cast<Value>())
224 value.print(os);
225 else
226 ofr.dyn_cast<Attribute>().print(os);
227 return os;
228}
229
230/// Allow printing to a stream.
231inline raw_ostream &operator<<(raw_ostream &os, OpState op) {
232 op.print(os, OpPrintingFlags().useLocalScope());
[analyzer step 25] Calling 'OpState::print'
233 return os;
234}
235
236//===----------------------------------------------------------------------===//
237// Operation Trait Types
238//===----------------------------------------------------------------------===//
239
240namespace OpTrait {
241
242// These functions are out-of-line implementations of the methods in the
243// corresponding trait classes. This avoids them being template
244// instantiated/duplicated.
245namespace impl {
246OpFoldResult foldIdempotent(Operation *op);
247OpFoldResult foldInvolution(Operation *op);
248LogicalResult verifyZeroOperands(Operation *op);
249LogicalResult verifyOneOperand(Operation *op);
250LogicalResult verifyNOperands(Operation *op, unsigned numOperands);
251LogicalResult verifyIsIdempotent(Operation *op);
252LogicalResult verifyIsInvolution(Operation *op);
253LogicalResult verifyAtLeastNOperands(Operation *op, unsigned numOperands);
254LogicalResult verifyOperandsAreFloatLike(Operation *op);
255LogicalResult verifyOperandsAreSignlessIntegerLike(Operation *op);
256LogicalResult verifySameTypeOperands(Operation *op);
257LogicalResult verifyZeroRegion(Operation *op);
258LogicalResult verifyOneRegion(Operation *op);
259LogicalResult verifyNRegions(Operation *op, unsigned numRegions);
260LogicalResult verifyAtLeastNRegions(Operation *op, unsigned numRegions);
261LogicalResult verifyZeroResult(Operation *op);
262LogicalResult verifyOneResult(Operation *op);
263LogicalResult verifyNResults(Operation *op, unsigned numOperands);
264LogicalResult verifyAtLeastNResults(Operation *op, unsigned numOperands);
265LogicalResult verifySameOperandsShape(Operation *op);
266LogicalResult verifySameOperandsAndResultShape(Operation *op);
267LogicalResult verifySameOperandsElementType(Operation *op);
268LogicalResult verifySameOperandsAndResultElementType(Operation *op);
269LogicalResult verifySameOperandsAndResultType(Operation *op);
270LogicalResult verifyResultsAreBoolLike(Operation *op);
271LogicalResult verifyResultsAreFloatLike(Operation *op);
272LogicalResult verifyResultsAreSignlessIntegerLike(Operation *op);
273LogicalResult verifyIsTerminator(Operation *op);
274LogicalResult verifyZeroSuccessor(Operation *op);
275LogicalResult verifyOneSuccessor(Operation *op);
276LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
277LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
278LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
279 StringRef valueGroupName,
280 size_t expectedCount);
281LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
282LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
283LogicalResult verifyNoRegionArguments(Operation *op);
284LogicalResult verifyElementwise(Operation *op);
285LogicalResult verifyIsIsolatedFromAbove(Operation *op);
286} // namespace impl
287
288/// Helper class for implementing traits. Clients are not expected to interact
289/// with this directly, so its members are all protected.
290template <typename ConcreteType, template <typename> class TraitType>
291class TraitBase {
292protected:
293 /// Return the ultimate Operation being worked on.
294 Operation *getOperation() {
295 // We have to cast up to the trait type, then to the concrete type, then to
296 // the BaseState class in explicit hops because the concrete type will
297 // multiply derive from the (content free) TraitBase class, and we need to
298 // be able to disambiguate the path for the C++ compiler.
299 auto *trait = static_cast<TraitType<ConcreteType> *>(this);
300 auto *concrete = static_cast<ConcreteType *>(trait);
301 auto *base = static_cast<OpState *>(concrete);
302 return base->getOperation();
303 }
304};
305
306//===----------------------------------------------------------------------===//
307// Operand Traits
308
309namespace detail {
310/// Utility trait base that provides accessors for derived traits that have
311/// multiple operands.
312template <typename ConcreteType, template <typename> class TraitType>
313struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
314 using operand_iterator = Operation::operand_iterator;
315 using operand_range = Operation::operand_range;
316 using operand_type_iterator = Operation::operand_type_iterator;
317 using operand_type_range = Operation::operand_type_range;
318
319 /// Return the number of operands.
320 unsigned getNumOperands() { return this->getOperation()->getNumOperands(); }
321
322 /// Return the operand at index 'i'.
323 Value getOperand(unsigned i) { return this->getOperation()->getOperand(i); }
324
325 /// Set the operand at index 'i' to 'value'.
326 void setOperand(unsigned i, Value value) {
327 this->getOperation()->setOperand(i, value);
328 }
329
330 /// Operand iterator access.
331 operand_iterator operand_begin() {
332 return this->getOperation()->operand_begin();
333 }
334 operand_iterator operand_end() { return this->getOperation()->operand_end(); }
335 operand_range getOperands() { return this->getOperation()->getOperands(); }
336
337 /// Operand type access.
338 operand_type_iterator operand_type_begin() {
339 return this->getOperation()->operand_type_begin();
340 }
341 operand_type_iterator operand_type_end() {
342 return this->getOperation()->operand_type_end();
343 }
344 operand_type_range getOperandTypes() {
345 return this->getOperation()->getOperandTypes();
346 }
347};
348} // namespace detail
349
350/// This class provides the API for ops that are known to have no
351/// SSA operand.
352template <typename ConcreteType>
353class ZeroOperands : public TraitBase<ConcreteType, ZeroOperands> {
354public:
355 static LogicalResult verifyTrait(Operation *op) {
356 return impl::verifyZeroOperands(op);
357 }
358
359private:
360 // Disable these.
361 void getOperand() {}
362 void setOperand() {}
363};
364
365/// This class provides the API for ops that are known to have exactly one
366/// SSA operand.
367template <typename ConcreteType>
368class OneOperand : public TraitBase<ConcreteType, OneOperand> {
369public:
370 Value getOperand() { return this->getOperation()->getOperand(0); }
371
372 void setOperand(Value value) { this->getOperation()->setOperand(0, value); }
373
374 static LogicalResult verifyTrait(Operation *op) {
375 return impl::verifyOneOperand(op);
376 }
377};
378
379/// This class provides the API for ops that are known to have a specified
380/// number of operands. This is used as a trait like this:
381///
382/// class FooOp : public Op<FooOp, OpTrait::NOperands<2>::Impl> {
383///
384template <unsigned N>
385class NOperands {
386public:
387 static_assert(N > 1, "use ZeroOperands/OneOperand for N < 2");
388
389 template <typename ConcreteType>
390 class Impl
391 : public detail::MultiOperandTraitBase<ConcreteType, NOperands<N>::Impl> {
392 public:
393 static LogicalResult verifyTrait(Operation *op) {
394 return impl::verifyNOperands(op, N);
395 }
396 };
397};
398
399/// This class provides the API for ops that are known to have at least a
400/// specified number of operands. This is used as a trait like this:
401///
402/// class FooOp : public Op<FooOp, OpTrait::AtLeastNOperands<2>::Impl> {
403///
404template <unsigned N>
405class AtLeastNOperands {
406public:
407 template <typename ConcreteType>
408 class Impl : public detail::MultiOperandTraitBase<ConcreteType,
409 AtLeastNOperands<N>::Impl> {
410 public:
411 static LogicalResult verifyTrait(Operation *op) {
412 return impl::verifyAtLeastNOperands(op, N);
413 }
414 };
415};
416
417/// This class provides the API for ops which have an unknown number of
418/// SSA operands.
419template <typename ConcreteType>
420class VariadicOperands
421 : public detail::MultiOperandTraitBase<ConcreteType, VariadicOperands> {};
422
423//===----------------------------------------------------------------------===//
424// Region Traits
425
426/// This class provides verification for ops that are known to have zero
427/// regions.
428template <typename ConcreteType>
429class ZeroRegion : public TraitBase<ConcreteType, ZeroRegion> {
430public:
431 static LogicalResult verifyTrait(Operation *op) {
432 return impl::verifyZeroRegion(op);
433 }
434};
435
436namespace detail {
437/// Utility trait base that provides accessors for derived traits that have
438/// multiple regions.
439template <typename ConcreteType, template <typename> class TraitType>
440struct MultiRegionTraitBase : public TraitBase<ConcreteType, TraitType> {
441 using region_iterator = MutableArrayRef<Region>;
442 using region_range = RegionRange;
443
444 /// Return the number of regions.
445 unsigned getNumRegions() { return this->getOperation()->getNumRegions(); }
446
447 /// Return the region at `index`.
448 Region &getRegion(unsigned i) { return this->getOperation()->getRegion(i); }
449
450 /// Region iterator access.
451 region_iterator region_begin() {
452 return this->getOperation()->region_begin();
453 }
454 region_iterator region_end() { return this->getOperation()->region_end(); }
455 region_range getRegions() { return this->getOperation()->getRegions(); }
456};
457} // namespace detail
458
459/// This class provides APIs for ops that are known to have a single region.
460template <typename ConcreteType>
461class OneRegion : public TraitBase<ConcreteType, OneRegion> {
462public:
463 Region &getRegion() { return this->getOperation()->getRegion(0); }
464
465 /// Returns a range of operations within the region of this operation.
466 auto getOps() { return getRegion().getOps(); }
467 template <typename OpT>
468 auto getOps() {
469 return getRegion().template getOps<OpT>();
470 }
471
472 static LogicalResult verifyTrait(Operation *op) {
473 return impl::verifyOneRegion(op);
474 }
475};
476
477/// This class provides the API for ops that are known to have a specified
478/// number of regions.
479template <unsigned N>
480class NRegions {
481public:
482 static_assert(N > 1, "use ZeroRegion/OneRegion for N < 2");
483
484 template <typename ConcreteType>
485 class Impl
486 : public detail::MultiRegionTraitBase<ConcreteType, NRegions<N>::Impl> {
487 public:
488 static LogicalResult verifyTrait(Operation *op) {
489 return impl::verifyNRegions(op, N);
490 }
491 };
492};
493
494/// This class provides APIs for ops that are known to have at least a specified
495/// number of regions.
496template <unsigned N>
497class AtLeastNRegions {
498public:
499 template <typename ConcreteType>
500 class Impl : public detail::MultiRegionTraitBase<ConcreteType,
501 AtLeastNRegions<N>::Impl> {
502 public:
503 static LogicalResult verifyTrait(Operation *op) {
504 return impl::verifyAtLeastNRegions(op, N);
505 }
506 };
507};
508
509/// This class provides the API for ops which have an unknown number of
510/// regions.
511template <typename ConcreteType>
512class VariadicRegions
513 : public detail::MultiRegionTraitBase<ConcreteType, VariadicRegions> {};
514
515//===----------------------------------------------------------------------===//
516// Result Traits
517
518/// This class provides return value APIs for ops that are known to have
519/// zero results.
520template <typename ConcreteType>
521class ZeroResult : public TraitBase<ConcreteType, ZeroResult> {
522public:
523 static LogicalResult verifyTrait(Operation *op) {
524 return impl::verifyZeroResult(op);
525 }
526};
527
528namespace detail {
529/// Utility trait base that provides accessors for derived traits that have
530/// multiple results.
531template <typename ConcreteType, template <typename> class TraitType>
532struct MultiResultTraitBase : public TraitBase<ConcreteType, TraitType> {
533 using result_iterator = Operation::result_iterator;
534 using result_range = Operation::result_range;
535 using result_type_iterator = Operation::result_type_iterator;
536 using result_type_range = Operation::result_type_range;
537
538 /// Return the number of results.
539 unsigned getNumResults() { return this->getOperation()->getNumResults(); }
540
541 /// Return the result at index 'i'.
542 Value getResult(unsigned i) { return this->getOperation()->getResult(i); }
543
544 /// Replace all uses of results of this operation with the provided 'values'.
545 /// 'values' may correspond to an existing operation, or a range of 'Value'.
546 template <typename ValuesT>
547 void replaceAllUsesWith(ValuesT &&values) {
548 this->getOperation()->replaceAllUsesWith(std::forward<ValuesT>(values));
549 }
550
551 /// Return the type of the `i`-th result.
552 Type getType(unsigned i) { return getResult(i).getType(); }
553
554 /// Result iterator access.
555 result_iterator result_begin() {
556 return this->getOperation()->result_begin();
557 }
558 result_iterator result_end() { return this->getOperation()->result_end(); }
559 result_range getResults() { return this->getOperation()->getResults(); }
560
561 /// Result type access.
562 result_type_iterator result_type_begin() {
563 return this->getOperation()->result_type_begin();
564 }
565 result_type_iterator result_type_end() {
566 return this->getOperation()->result_type_end();
567 }
568 result_type_range getResultTypes() {
569 return this->getOperation()->getResultTypes();
570 }
571};
572} // namespace detail
573
574/// This class provides return value APIs for ops that are known to have a
575/// single result. See OneTypedResult for a concretely typed getType().
576template <typename ConcreteType>
577class OneResult : public TraitBase<ConcreteType, OneResult> {
578public:
579 Value getResult() { return this->getOperation()->getResult(0); }
580
581 /// If the operation returns a single value, then the Op can be implicitly
582 /// converted to a Value. This yields the value of the only result.
583 operator Value() { return getResult(); }
584
585 /// Replace all uses of 'this' value with the new value, updating anything
586 /// in the IR that uses 'this' to use the other value instead. When this
587 /// returns there are zero uses of 'this'.
588 void replaceAllUsesWith(Value newValue) {
589 getResult().replaceAllUsesWith(newValue);
590 }
591
592 /// Replace all uses of 'this' value with the result of 'op'.
593 void replaceAllUsesWith(Operation *op) {
594 this->getOperation()->replaceAllUsesWith(op);
595 }
596
597 static LogicalResult verifyTrait(Operation *op) {
598 return impl::verifyOneResult(op);
599 }
600};
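
A short sketch of what the trait buys at call sites (`addOp` and `newVal` are hypothetical):

    Value v = addOp;                  // implicit operator Value() on the result
    addOp.replaceAllUsesWith(newVal); // rewires every use of that one result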
601
602/// This trait is used for return value APIs for ops that are known to have a
603/// specific type other than `Type`. This allows the "getType()" member to be
604/// more specific for an op. This should be used in conjunction with OneResult,
605/// and occur in the trait list before OneResult.
606template <typename ResultType>
607class OneTypedResult {
608public:
609 /// This class provides return value APIs for ops that are known to have a
610 /// single result. ResultType is the concrete type returned by getType().
611 template <typename ConcreteType>
612 class Impl
613 : public TraitBase<ConcreteType, OneTypedResult<ResultType>::Impl> {
614 public:
615 ResultType getType() {
616 auto resultTy = this->getOperation()->getResult(0).getType();
617 return resultTy.template cast<ResultType>();
618 }
619 };
620};
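
For example, a hypothetical op could declare (mirroring the FooOp pattern used in the comments above, with OneTypedResult listed before OneResult):

    class AllocLikeOp
        : public Op<AllocLikeOp, OpTrait::OneTypedResult<MemRefType>::Impl,
                    OpTrait::OneResult> {...};

    MemRefType ty = allocLikeOp.getType(); // no .cast<MemRefType>() at call sites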
621
622/// This class provides the API for ops that are known to have a specified
623/// number of results. This is used as a trait like this:
624///
625/// class FooOp : public Op<FooOp, OpTrait::NResults<2>::Impl> {
626///
627template <unsigned N>
628class NResults {
629public:
630 static_assert(N > 1, "use ZeroResult/OneResult for N < 2");
631
632 template <typename ConcreteType>
633 class Impl
634 : public detail::MultiResultTraitBase<ConcreteType, NResults<N>::Impl> {
635 public:
636 static LogicalResult verifyTrait(Operation *op) {
637 return impl::verifyNResults(op, N);
638 }
639 };
640};
641
642/// This class provides the API for ops that are known to have at least a
643/// specified number of results. This is used as a trait like this:
644///
645/// class FooOp : public Op<FooOp, OpTrait::AtLeastNResults<2>::Impl> {
646///
647template <unsigned N>
648class AtLeastNResults {
649public:
650 template <typename ConcreteType>
651 class Impl : public detail::MultiResultTraitBase<ConcreteType,
652 AtLeastNResults<N>::Impl> {
653 public:
654 static LogicalResult verifyTrait(Operation *op) {
655 return impl::verifyAtLeastNResults(op, N);
656 }
657 };
658};
659
660/// This class provides the API for ops which have an unknown number of
661/// results.
662template <typename ConcreteType>
663class VariadicResults
664 : public detail::MultiResultTraitBase<ConcreteType, VariadicResults> {};
665
666//===----------------------------------------------------------------------===//
667// Terminator Traits
668
669/// This class indicates that the regions associated with this op don't have
670/// terminators.
671template <typename ConcreteType>
672class NoTerminator : public TraitBase<ConcreteType, NoTerminator> {};
673
674/// This class provides the API for ops that are known to be terminators.
675template <typename ConcreteType>
676class IsTerminator : public TraitBase<ConcreteType, IsTerminator> {
677public:
678 static LogicalResult verifyTrait(Operation *op) {
679 return impl::verifyIsTerminator(op);
680 }
681};
682
683/// This class provides verification for ops that are known to have zero
684/// successors.
685template <typename ConcreteType>
686class ZeroSuccessor : public TraitBase<ConcreteType, ZeroSuccessor> {
687public:
688 static LogicalResult verifyTrait(Operation *op) {
689 return impl::verifyZeroSuccessor(op);
690 }
691};
692
693namespace detail {
694/// Utility trait base that provides accessors for derived traits that have
695/// multiple successors.
696template <typename ConcreteType, template <typename> class TraitType>
697struct MultiSuccessorTraitBase : public TraitBase<ConcreteType, TraitType> {
698 using succ_iterator = Operation::succ_iterator;
699 using succ_range = SuccessorRange;
700
701 /// Return the number of successors.
702 unsigned getNumSuccessors() {
703 return this->getOperation()->getNumSuccessors();
704 }
705
706 /// Return the successor at `index`.
707 Block *getSuccessor(unsigned i) {
708 return this->getOperation()->getSuccessor(i);
709 }
710
711 /// Set the successor at `index`.
712 void setSuccessor(Block *block, unsigned i) {
713 return this->getOperation()->setSuccessor(block, i);
714 }
715
716 /// Successor iterator access.
717 succ_iterator succ_begin() { return this->getOperation()->succ_begin(); }
718 succ_iterator succ_end() { return this->getOperation()->succ_end(); }
719 succ_range getSuccessors() { return this->getOperation()->getSuccessors(); }
720};
721} // namespace detail
722
723/// This class provides APIs for ops that are known to have a single successor.
724template <typename ConcreteType>
725class OneSuccessor : public TraitBase<ConcreteType, OneSuccessor> {
726public:
727 Block *getSuccessor() { return this->getOperation()->getSuccessor(0); }
728 void setSuccessor(Block *succ) {
729 this->getOperation()->setSuccessor(succ, 0);
730 }
731
732 static LogicalResult verifyTrait(Operation *op) {
733 return impl::verifyOneSuccessor(op);
734 }
735};
736
737/// This class provides the API for ops that are known to have a specified
738/// number of successors.
739template <unsigned N>
740class NSuccessors {
741public:
742 static_assert(N > 1, "use ZeroSuccessor/OneSuccessor for N < 2");
743
744 template <typename ConcreteType>
745 class Impl : public detail::MultiSuccessorTraitBase<ConcreteType,
746 NSuccessors<N>::Impl> {
747 public:
748 static LogicalResult verifyTrait(Operation *op) {
749 return impl::verifyNSuccessors(op, N);
750 }
751 };
752};
753
754/// This class provides APIs for ops that are known to have at least a specified
755/// number of successors.
756template <unsigned N>
757class AtLeastNSuccessors {
758public:
759 template <typename ConcreteType>
760 class Impl
761 : public detail::MultiSuccessorTraitBase<ConcreteType,
762 AtLeastNSuccessors<N>::Impl> {
763 public:
764 static LogicalResult verifyTrait(Operation *op) {
765 return impl::verifyAtLeastNSuccessors(op, N);
766 }
767 };
768};
769
770/// This class provides the API for ops which have an unknown number of
771/// successors.
772template <typename ConcreteType>
773class VariadicSuccessors
774 : public detail::MultiSuccessorTraitBase<ConcreteType, VariadicSuccessors> {
775};
776
777//===----------------------------------------------------------------------===//
778// SingleBlock
779
780/// This class provides APIs and verifiers for ops with regions having a single
781/// block.
782template <typename ConcreteType>
783struct SingleBlock : public TraitBase<ConcreteType, SingleBlock> {
784public:
785 static LogicalResult verifyTrait(Operation *op) {
786 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
787 Region &region = op->getRegion(i);
788
789 // Empty regions are fine.
790 if (region.empty())
791 continue;
792
793 // Non-empty regions must contain a single basic block.
794 if (!llvm::hasSingleElement(region))
795 return op->emitOpError("expects region #")
796 << i << " to have 0 or 1 blocks";
797
798 if (!ConcreteType::template hasTrait<NoTerminator>()) {
799 Block &block = region.front();
800 if (block.empty())
801 return op->emitOpError() << "expects a non-empty block";
802 }
803 }
804 return success();
805 }
806
807 Block *getBody(unsigned idx = 0) {
808 Region &region = this->getOperation()->getRegion(idx);
809 assert(!region.empty() && "unexpected empty region");
810 return &region.front();
811 }
812 Region &getBodyRegion(unsigned idx = 0) {
813 return this->getOperation()->getRegion(idx);
814 }
815
816 //===------------------------------------------------------------------===//
817 // Single Region Utilities
818 //===------------------------------------------------------------------===//
819
820 /// The following are a set of methods only enabled when the parent
821 /// operation has a single region. Each of these methods takes an additional
822 /// template parameter that represents the concrete operation so that we
823 /// can use SFINAE to disable the methods for non-single region operations.
824 template <typename OpT, typename T = void>
825 using enable_if_single_region =
826 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
827
828 template <typename OpT = ConcreteType>
829 enable_if_single_region<OpT, Block::iterator> begin() {
830 return getBody()->begin();
831 }
832 template <typename OpT = ConcreteType>
833 enable_if_single_region<OpT, Block::iterator> end() {
834 return getBody()->end();
835 }
836 template <typename OpT = ConcreteType>
837 enable_if_single_region<OpT, Operation &> front() {
838 return *begin();
839 }
840
841 /// Insert the operation into the back of the body.
842 template <typename OpT = ConcreteType>
843 enable_if_single_region<OpT> push_back(Operation *op) {
844 insert(Block::iterator(getBody()->end()), op);
845 }
846
847 /// Insert the operation at the given insertion point.
848 template <typename OpT = ConcreteType>
849 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
850 insert(Block::iterator(insertPt), op);
851 }
852 template <typename OpT = ConcreteType>
853 enable_if_single_region<OpT> insert(Block::iterator insertPt, Operation *op) {
854 getBody()->getOperations().insert(insertPt, op);
855 }
856};
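
A body-manipulation sketch for a single-region op carrying this trait (`moduleLikeOp` and `newOp` are hypothetical):

    Block *body = moduleLikeOp.getBody(); // asserts the region is non-empty
    for (Operation &nested : *body)       // begin()/end() come from the trait
      (void)nested;
    moduleLikeOp.push_back(newOp);        // appends at the end of the body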
857
858//===----------------------------------------------------------------------===//
859// SingleBlockImplicitTerminator
860
861/// This class provides APIs and verifiers for ops with regions having a single
862/// block that must terminate with `TerminatorOpType`.
863template <typename TerminatorOpType>
864struct SingleBlockImplicitTerminator {
865 template <typename ConcreteType>
866 class Impl : public SingleBlock<ConcreteType> {
867 private:
868 using Base = SingleBlock<ConcreteType>;
869 /// Builds a terminator operation without relying on OpBuilder APIs to avoid
870 /// cyclic header inclusion.
871 static Operation *buildTerminator(OpBuilder &builder, Location loc) {
872 OperationState state(loc, TerminatorOpType::getOperationName());
873 TerminatorOpType::build(builder, state);
874 return Operation::create(state);
875 }
876
877 public:
878 /// The type of the operation used as the implicit terminator type.
879 using ImplicitTerminatorOpT = TerminatorOpType;
880
881 static LogicalResult verifyTrait(Operation *op) {
882 if (failed(Base::verifyTrait(op)))
883 return failure();
884 for (unsigned i = 0, e = op->getNumRegions(); i < e; ++i) {
885 Region &region = op->getRegion(i);
886 // Empty regions are fine.
887 if (region.empty())
888 continue;
889 Operation &terminator = region.front().back();
890 if (isa<TerminatorOpType>(terminator))
891 continue;
892
893 return op->emitOpError("expects regions to end with '" +
894 TerminatorOpType::getOperationName() +
895 "', found '" +
896 terminator.getName().getStringRef() + "'")
897 .attachNote()
898 << "in custom textual format, the absence of terminator implies "
899 "'"
900 << TerminatorOpType::getOperationName() << '\'';
901 }
902
903 return success();
904 }
905
906 /// Ensure that the given region has the terminator required by this trait.
907 /// If OpBuilder is provided, use it to build the terminator and notify the
908 /// OpBuilder listeners accordingly. If only a Builder is provided, locally
909 /// construct an OpBuilder with no listeners; this should only be used if no
910 /// OpBuilder is available at the call site, e.g., in the parser.
911 static void ensureTerminator(Region &region, Builder &builder,
912 Location loc) {
913 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
914 buildTerminator);
915 }
916 static void ensureTerminator(Region &region, OpBuilder &builder,
917 Location loc) {
918 ::mlir::impl::ensureRegionTerminator(region, builder, loc,
919 buildTerminator);
920 }
921
922 //===------------------------------------------------------------------===//
923 // Single Region Utilities
924 //===------------------------------------------------------------------===//
925 using Base::getBody;
926
927 template <typename OpT, typename T = void>
928 using enable_if_single_region =
929 typename std::enable_if_t<OpT::template hasTrait<OneRegion>(), T>;
930
931 /// Insert the operation into the back of the body, before the terminator.
932 template <typename OpT = ConcreteType>
933 enable_if_single_region<OpT> push_back(Operation *op) {
934 insert(Block::iterator(getBody()->getTerminator()), op);
935 }
936
937 /// Insert the operation at the given insertion point. Note: The operation
938 /// is never inserted after the terminator, even if the insertion point is
939 /// end().
940 template <typename OpT = ConcreteType>
941 enable_if_single_region<OpT> insert(Operation *insertPt, Operation *op) {
942 insert(Block::iterator(insertPt), op);
943 }
944 template <typename OpT = ConcreteType>
945 enable_if_single_region<OpT> insert(Block::iterator insertPt,
946 Operation *op) {
947 auto *body = getBody();
948 if (insertPt == body->end())
949 insertPt = Block::iterator(body->getTerminator());
950 body->getOperations().insert(insertPt, op);
951 }
952 };
953};
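
Typical use is inside a custom builder or parser once the region exists; `IfLikeOp` with an implicit `YieldOp` terminator is hypothetical here:

    // Inside IfLikeOp::build(OpBuilder &builder, OperationState &result, ...):
    Region *thenRegion = result.addRegion();
    thenRegion->push_back(new Block());
    IfLikeOp::ensureTerminator(*thenRegion, builder, result.location);
    // The block now ends with a YieldOp even if the caller appended none.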
954
955/// Check if an op defines the `ImplicitTerminatorOpT` member. This is intended
956/// to be used with `llvm::is_detected`.
957template <class T>
958using has_implicit_terminator_t = typename T::ImplicitTerminatorOpT;
959
960/// Support to check if an operation has the SingleBlockImplicitTerminator
961/// trait. We can't just use `hasTrait` because this class is templated on a
962/// specific terminator op.
963template <class Op, bool hasTerminator =
964 llvm::is_detected<has_implicit_terminator_t, Op>::value>
965struct hasSingleBlockImplicitTerminator {
966 static constexpr bool value = std::is_base_of<
967 typename OpTrait::SingleBlockImplicitTerminator<
968 typename Op::ImplicitTerminatorOpT>::template Impl<Op>,
969 Op>::value;
970};
971template <class Op>
972struct hasSingleBlockImplicitTerminator<Op, false> {
973 static constexpr bool value = false;
974};
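
For instance, a compile-time check over a hypothetical op type:

    static_assert(
        OpTrait::hasSingleBlockImplicitTerminator<MySingleBlockOp>::value,
        "expected an op with a single-block implicit terminator");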
975
976//===----------------------------------------------------------------------===//
977// Misc Traits
978
979/// This class provides verification for ops that are known to have the same
980/// operand shape: all operands are scalars, vectors/tensors of the same
981/// shape.
982template <typename ConcreteType>
983class SameOperandsShape : public TraitBase<ConcreteType, SameOperandsShape> {
984public:
985 static LogicalResult verifyTrait(Operation *op) {
986 return impl::verifySameOperandsShape(op);
987 }
988};
989
990/// This class provides verification for ops that are known to have the same
991/// operand and result shape: both are scalars, vectors/tensors of the same
992/// shape.
993template <typename ConcreteType>
994class SameOperandsAndResultShape
995 : public TraitBase<ConcreteType, SameOperandsAndResultShape> {
996public:
997 static LogicalResult verifyTrait(Operation *op) {
998 return impl::verifySameOperandsAndResultShape(op);
999 }
1000};
1001
1002/// This class provides verification for ops that are known to have the same
1003/// operand element type (or the type itself if it is scalar).
1004///
1005template <typename ConcreteType>
1006class SameOperandsElementType
1007 : public TraitBase<ConcreteType, SameOperandsElementType> {
1008public:
1009 static LogicalResult verifyTrait(Operation *op) {
1010 return impl::verifySameOperandsElementType(op);
1011 }
1012};
1013
1014/// This class provides verification for ops that are known to have the same
1015/// operand and result element type (or the type itself if it is scalar).
1016///
1017template <typename ConcreteType>
1018class SameOperandsAndResultElementType
1019 : public TraitBase<ConcreteType, SameOperandsAndResultElementType> {
1020public:
1021 static LogicalResult verifyTrait(Operation *op) {
1022 return impl::verifySameOperandsAndResultElementType(op);
1023 }
1024};
1025
1026/// This class provides verification for ops that are known to have the same
1027/// operand and result type.
1028///
1029/// Note: this trait subsumes the SameOperandsAndResultShape and
1030/// SameOperandsAndResultElementType traits.
1031template <typename ConcreteType>
1032class SameOperandsAndResultType
1033 : public TraitBase<ConcreteType, SameOperandsAndResultType> {
1034public:
1035 static LogicalResult verifyTrait(Operation *op) {
1036 return impl::verifySameOperandsAndResultType(op);
1037 }
1038};
1039
1040/// This class verifies that any results of the specified op have a boolean
1041/// type, a vector thereof, or a tensor thereof.
1042template <typename ConcreteType>
1043class ResultsAreBoolLike : public TraitBase<ConcreteType, ResultsAreBoolLike> {
1044public:
1045 static LogicalResult verifyTrait(Operation *op) {
1046 return impl::verifyResultsAreBoolLike(op);
1047 }
1048};
1049
1050/// This class verifies that any results of the specified op have a floating
1051/// point type, a vector thereof, or a tensor thereof.
1052template <typename ConcreteType>
1053class ResultsAreFloatLike
1054 : public TraitBase<ConcreteType, ResultsAreFloatLike> {
1055public:
1056 static LogicalResult verifyTrait(Operation *op) {
1057 return impl::verifyResultsAreFloatLike(op);
1058 }
1059};
1060
1061/// This class verifies that any results of the specified op have a signless
1062/// integer or index type, a vector thereof, or a tensor thereof.
1063template <typename ConcreteType>
1064class ResultsAreSignlessIntegerLike
1065 : public TraitBase<ConcreteType, ResultsAreSignlessIntegerLike> {
1066public:
1067 static LogicalResult verifyTrait(Operation *op) {
1068 return impl::verifyResultsAreSignlessIntegerLike(op);
1069 }
1070};
1071
1072/// This class adds property that the operation is commutative.
1073template <typename ConcreteType>
1074class IsCommutative : public TraitBase<ConcreteType, IsCommutative> {};
1075
1076/// This class adds property that the operation is an involution.
1077/// This means a unary to unary operation "f" that satisfies f(f(x)) = x
1078template <typename ConcreteType>
1079class IsInvolution : public TraitBase<ConcreteType, IsInvolution> {
1080public:
1081 static LogicalResult verifyTrait(Operation *op) {
1082 static_assert(ConcreteType::template hasTrait<OneResult>(),
1083 "expected operation to produce one result");
1084 static_assert(ConcreteType::template hasTrait<OneOperand>(),
1085 "expected operation to take one operand");
1086 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1087 "expected operation to preserve type");
1088 // Involution requires the operation to be side effect free as well
1089 // but currently this check is under a FIXME and is not actually done.
1090 return impl::verifyIsInvolution(op);
1091 }
1092
1093 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1094 return impl::foldInvolution(op);
1095 }
1096};
1097
1098/// This class adds property that the operation is idempotent.
1099/// This means a unary to unary operation "f" that satisfies f(f(x)) = f(x),
1100/// or a binary operation "g" that satisfies g(x, x) = x.
1101template <typename ConcreteType>
1102class IsIdempotent : public TraitBase<ConcreteType, IsIdempotent> {
1103public:
1104 static LogicalResult verifyTrait(Operation *op) {
1105 static_assert(ConcreteType::template hasTrait<OneResult>(),
1106 "expected operation to produce one result");
1107 static_assert(ConcreteType::template hasTrait<OneOperand>() ||
1108 ConcreteType::template hasTrait<NOperands<2>::Impl>(),
1109 "expected operation to take one or two operands");
1110 static_assert(ConcreteType::template hasTrait<SameOperandsAndResultType>(),
1111 "expected operation to preserve type");
1112 // Idempotent requires the operation to be side effect free as well
1113 // but currently this check is under a FIXME and is not actually done.
1114 return impl::verifyIsIdempotent(op);
1115 }
1116
1117 static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands) {
1118 return impl::foldIdempotent(op);
1119 }
1120};
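
Concretely, the folds these two traits enable look like this (op names illustrative):

    // Involution:  not(not(%x))   folds to %x
    // Idempotent:  or(%x, %x)     folds to %x
    //              maxsi(%y, %y)  folds to %y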
1121
1122/// This class verifies that all operands of the specified op have a float type,
1123/// a vector thereof, or a tensor thereof.
1124template <typename ConcreteType>
1125class OperandsAreFloatLike
1126 : public TraitBase<ConcreteType, OperandsAreFloatLike> {
1127public:
1128 static LogicalResult verifyTrait(Operation *op) {
1129 return impl::verifyOperandsAreFloatLike(op);
1130 }
1131};
1132
1133/// This class verifies that all operands of the specified op have a signless
1134/// integer or index type, a vector thereof, or a tensor thereof.
1135template <typename ConcreteType>
1136class OperandsAreSignlessIntegerLike
1137 : public TraitBase<ConcreteType, OperandsAreSignlessIntegerLike> {
1138public:
1139 static LogicalResult verifyTrait(Operation *op) {
1140 return impl::verifyOperandsAreSignlessIntegerLike(op);
1141 }
1142};
1143
1144/// This class verifies that all operands of the specified op have the same
1145/// type.
1146template <typename ConcreteType>
1147class SameTypeOperands : public TraitBase<ConcreteType, SameTypeOperands> {
1148public:
1149 static LogicalResult verifyTrait(Operation *op) {
1150 return impl::verifySameTypeOperands(op);
1151 }
1152};
1153
1154/// This class provides the API for a sub-set of ops that are known to be
1155/// constant-like. These are non-side effecting operations with one result and
1156/// zero operands that can always be folded to a specific attribute value.
1157template <typename ConcreteType>
1158class ConstantLike : public TraitBase<ConcreteType, ConstantLike> {
1159public:
1160 static LogicalResult verifyTrait(Operation *op) {
1161 static_assert(ConcreteType::template hasTrait<OneResult>(),
1162 "expected operation to produce one result");
1163 static_assert(ConcreteType::template hasTrait<ZeroOperands>(),
1164 "expected operation to take zero operands");
1165 // TODO: We should verify that the operation can always be folded, but this
1166 // requires that the attributes of the op already be verified. We should add
1167 // support for verifying traits "after" the operation to enable this use
1168 // case.
1169 return success();
1170 }
1171};
1172
1173/// This class provides the API for ops that are known to be isolated from
1174/// above.
1175template <typename ConcreteType>
1176class IsIsolatedFromAbove
1177 : public TraitBase<ConcreteType, IsIsolatedFromAbove> {
1178public:
1179 static LogicalResult verifyTrait(Operation *op) {
1180 return impl::verifyIsIsolatedFromAbove(op);
1181 }
1182};
1183
1184/// A trait of region holding operations that defines a new scope for polyhedral
1185/// optimization purposes. Any SSA values of 'index' type that either dominate
1186/// such an operation or are used at the top-level of such an operation
1187/// automatically become valid symbols for the polyhedral scope defined by that
1188/// operation. For more details, see `Traits.md#AffineScope`.
1189template <typename ConcreteType>
1190class AffineScope : public TraitBase<ConcreteType, AffineScope> {
1191public:
1192 static LogicalResult verifyTrait(Operation *op) {
1193 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1194 "expected operation to have one or more regions");
1195 return success();
1196 }
1197};
1198
1199/// A trait of region holding operations that define a new scope for automatic
1200/// allocations, i.e., allocations that are freed when control is transferred
1201/// back from the operation's region. Any operations performing such allocations
1202/// (e.g., memref.alloca) will have their allocations automatically freed at
1203/// their closest enclosing operation with this trait.
1204template <typename ConcreteType>
1205class AutomaticAllocationScope
1206 : public TraitBase<ConcreteType, AutomaticAllocationScope> {
1207public:
1208 static LogicalResult verifyTrait(Operation *op) {
1209 static_assert(!ConcreteType::template hasTrait<ZeroRegion>(),
1210 "expected operation to have one or more regions");
1211 return success();
1212 }
1213};
1214
1215/// This class provides a verifier for ops that are expecting their parent
1216/// to be one of the given parent ops.
1217template <typename... ParentOpTypes>
1218struct HasParent {
1219 template <typename ConcreteType>
1220 class Impl : public TraitBase<ConcreteType, Impl> {
1221 public:
1222 static LogicalResult verifyTrait(Operation *op) {
1223 if (llvm::isa<ParentOpTypes...>(op->getParentOp()))
1224 return success();
1225
1226 return op->emitOpError()
1227 << "expects parent op "
1228 << (sizeof...(ParentOpTypes) != 1 ? "to be one of '" : "'")
1229 << llvm::makeArrayRef({ParentOpTypes::getOperationName()...})
1230 << "'";
1231 }
1232 };
1233};
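
Usage mirrors the other parameterized traits; the op names here are hypothetical:

    class LoopYieldOp
        : public Op<LoopYieldOp, OpTrait::HasParent<LoopOp>::Impl,
                    OpTrait::IsTerminator> {...};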
1234
1235/// A trait for operations that have an attribute specifying operand segments.
1236///
1237/// Certain operations can have multiple variadic operands and their size
1238/// relationship is not always known statically. For such cases, we need
1239/// a per-op-instance specification to divide the operands into logical groups
1240/// or segments. This can be modeled by attributes. The attribute is named
1241/// `operand_segment_sizes`.
1242///
1243/// This trait verifies the attribute for specifying operand segments has
1244/// the correct type (1D vector) and values (non-negative), etc.
1245template <typename ConcreteType>
1246class AttrSizedOperandSegments
1247 : public TraitBase<ConcreteType, AttrSizedOperandSegments> {
1248public:
1249 static StringRef getOperandSegmentSizeAttr() {
1250 return "operand_segment_sizes";
1251 }
1252
1253 static LogicalResult verifyTrait(Operation *op) {
1254 return ::mlir::OpTrait::impl::verifyOperandSizeAttr(
1255 op, getOperandSegmentSizeAttr());
1256 }
1257};
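
In this snapshot the attribute is a dense i32 vector; e.g., splitting three operands into groups of two and one on a hypothetical op:

    "test.grouped"(%a, %b, %c)
        {operand_segment_sizes = dense<[2, 1]> : vector<2xi32>}
        : (i32, i32, i32) -> ()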
1258
1259/// Similar to AttrSizedOperandSegments but used for results.
1260template <typename ConcreteType>
1261class AttrSizedResultSegments
1262 : public TraitBase<ConcreteType, AttrSizedResultSegments> {
1263public:
1264 static StringRef getResultSegmentSizeAttr() { return "result_segment_sizes"; }
1265
1266 static LogicalResult verifyTrait(Operation *op) {
1267 return ::mlir::OpTrait::impl::verifyResultSizeAttr(
1268 op, getResultSegmentSizeAttr());
1269 }
1270};
1271
1272/// This trait provides a verifier for ops that are expecting their regions to
1273/// not have any arguments.
1274template <typename ConcreteType>
1275struct NoRegionArguments : public TraitBase<ConcreteType, NoRegionArguments> {
1276 static LogicalResult verifyTrait(Operation *op) {
1277 return ::mlir::OpTrait::impl::verifyNoRegionArguments(op);
1278 }
1279};
1280
1281// This trait is used to flag operations that consume or produce
1282// values of `MemRef` type where those references can be 'normalized'.
1283// TODO: Right now, the operands of an operation are either all normalizable,
1284// or not. In the future, we may want to allow some of the operands to be
1285// normalizable.
1286template <typename ConcreteType>
1287struct MemRefsNormalizable
1288 : public TraitBase<ConcreteType, MemRefsNormalizable> {};
1289
1290/// This trait tags element-wise ops on vectors or tensors.
1291///
1292/// NOTE: Not all ops that are "elementwise" in some abstract sense satisfy this
1293/// trait. In particular, broadcasting behavior is not allowed.
1294///
1295/// An `Elementwise` op must satisfy the following properties:
1296///
1297/// 1. If any result is a vector/tensor then at least one operand must also be a
1298/// vector/tensor.
1299/// 2. If any operand is a vector/tensor then there must be at least one result
1300/// and all results must be vectors/tensors.
1301/// 3. All operand and result vector/tensor types must be of the same shape. The
1302/// shape may be dynamic, in which case the op's behaviour is undefined for
1303/// non-matching shapes.
1304/// 4. The operation must be elementwise on its vector/tensor operands and
1305/// results. When applied to single-element vectors/tensors, the result must
1306/// be the same per element.
1307///
1308/// TODO: Avoid hardcoding vector/tensor, and generalize this trait to a new
1309/// interface `ElementwiseTypeInterface` that describes the container types for
1310/// which the operation is elementwise.
1311///
1312/// Rationale:
1313/// - 1. and 2. guarantee a well-defined iteration space and exclude the cases
1314/// of 0 non-scalar operands or 0 non-scalar results, which complicate a
1315/// generic definition of the iteration space.
1316/// - 3. guarantees that folding can be done across scalars/vectors/tensors with
1317/// the same pattern, as otherwise lots of special handling for type
1318/// mismatches would be needed.
1319/// - 4. guarantees that no error handling is needed. Higher-level dialects
1320/// should reify any needed guards or error handling code before lowering to
1321/// an `Elementwise` op.
1322template <typename ConcreteType>
1323struct Elementwise : public TraitBase<ConcreteType, Elementwise> {
1324 static LogicalResult verifyTrait(Operation *op) {
1325 return ::mlir::OpTrait::impl::verifyElementwise(op);
1326 }
1327};
1328
1329/// This trait tags `Elementwise` operations that can be systematically
1330/// scalarized. All vector/tensor operands and results are then replaced by
1331/// scalars of the respective element type. Semantically, this is the operation
1332/// on a single element of the vector/tensor.
1333///
1334/// Rationale:
1335/// Allows defining the vector/tensor semantics of elementwise operations based
1336/// on the same op's behavior on scalars. This provides a constructive procedure
1337/// for IR transformations to, e.g., create scalar loop bodies from tensor ops.
1338///
1339/// Example:
1340/// ```
1341/// %tensor_select = "std.select"(%pred_tensor, %true_val, %false_val)
1342/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1343/// -> tensor<?xf32>
1344/// ```
1345/// can be scalarized to
1346///
1347/// ```
1348/// %scalar_select = "std.select"(%pred, %true_val_scalar, %false_val_scalar)
1349/// : (i1, f32, f32) -> f32
1350/// ```
1351template <typename ConcreteType>
1352struct Scalarizable : public TraitBase<ConcreteType, Scalarizable> {
1353 static LogicalResult verifyTrait(Operation *op) {
1354 static_assert(
1355 ConcreteType::template hasTrait<Elementwise>(),
1356 "`Scalarizable` trait is only applicable to `Elementwise` ops.");
1357 return success();
1358 }
1359};
1360
1361/// This trait tags `Elementwise` operations that can be systematically
1362/// vectorized. All scalar operands and results are then replaced by vectors
1363/// with the respective element type. Semantically, this is the operation on
1364/// multiple elements simultaneously. See also `Tensorizable`.
1365///
1366/// Rationale:
1367/// Provide the reverse to `Scalarizable` which, when chained together, allows
1368/// reasoning about the relationship between the tensor and vector case.
1369/// Additionally, it permits reasoning about promoting scalars to vectors via
1370/// broadcasting in cases like `%select_scalar_pred` below.
1371template <typename ConcreteType>
1372struct Vectorizable : public TraitBase<ConcreteType, Vectorizable> {
1373 static LogicalResult verifyTrait(Operation *op) {
1374 static_assert(
1375 ConcreteType::template hasTrait<Elementwise>(),
1376 "`Vectorizable` trait is only applicable to `Elementwise` ops.");
1377 return success();
1378 }
1379};
1380
1381/// This trait tags `Elementwise` operations that can be systematically
1382/// tensorized. All scalar operands and results are then replaced by tensors
1383/// with the respective element type. Semantically, this is the operation on
1384/// multiple elements simultaneously. See also `Vectorizable`.
1385///
1386/// Rationale:
1387/// Provide the reverse to `Scalarizable` which, when chained together, allows
1388/// reasoning about the relationship between the tensor and vector case.
1389/// Additionally, it permits reasoning about promoting scalars to tensors via
1390/// broadcasting in cases like `%select_scalar_pred` below.
1391///
1392/// Examples:
1393/// ```
1394/// %scalar = "arith.addf"(%a, %b) : (f32, f32) -> f32
1395/// ```
1396/// can be tensorized to
1397/// ```
1398/// %tensor = "arith.addf"(%a, %b) : (tensor<?xf32>, tensor<?xf32>)
1399/// -> tensor<?xf32>
1400/// ```
1401///
1402/// ```
1403/// %scalar_pred = "std.select"(%pred, %true_val, %false_val)
1404/// : (i1, tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
1405/// ```
1406/// can be tensorized to
1407/// ```
1408/// %tensor_pred = "std.select"(%pred, %true_val, %false_val)
1409/// : (tensor<?xi1>, tensor<?xf32>, tensor<?xf32>)
1410/// -> tensor<?xf32>
1411/// ```
1412template <typename ConcreteType>
1413struct Tensorizable : public TraitBase<ConcreteType, Tensorizable> {
1414 static LogicalResult verifyTrait(Operation *op) {
1415 static_assert(
1416 ConcreteType::template hasTrait<Elementwise>(),
1417 "`Tensorizable` trait is only applicable to `Elementwise` ops.");
1418 return success();
1419 }
1420};
1421
1422/// Together, `Elementwise`, `Scalarizable`, `Vectorizable`, and `Tensorizable`
1423/// provide an easy way for scalar operations to conveniently generalize their
1424/// behavior to vectors/tensors, and systematize conversion between these forms.
1425bool hasElementwiseMappableTraits(Operation *op);
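//
// A minimal sketch (not part of this header) of an op opting into all four
// traits; `HypotheticalAddOp` and "hypo.add" are illustrative names only, and
// the remaining hooks a registered op needs (parser, verifier, ...) are
// omitted:
//
//   class HypotheticalAddOp
//       : public Op<HypotheticalAddOp, OpTrait::OneResult,
//                   OpTrait::Elementwise, OpTrait::Scalarizable,
//                   OpTrait::Vectorizable, OpTrait::Tensorizable> {
//   public:
//     using Op::Op;
//     static StringRef getOperationName() { return "hypo.add"; }
//   };
//
// For such an op, `hasElementwiseMappableTraits` returns true, so the generic
// scalar/vector/tensor conversions described above may be applied to it.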
1426
1427} // namespace OpTrait
1428
1429//===----------------------------------------------------------------------===//
1430// Internal Trait Utilities
1431//===----------------------------------------------------------------------===//
1432
1433namespace op_definition_impl {
1434//===----------------------------------------------------------------------===//
1435// Trait Existence
1436
1437 /// Returns true if the given trait ID matches the ID of any of the provided
1438 /// trait types `Traits`.
1439template <template <typename T> class... Traits>
1440static bool hasTrait(TypeID traitID) {
1441 TypeID traitIDs[] = {TypeID::get<Traits>()...};
1442 for (unsigned i = 0, e = sizeof...(Traits); i != e; ++i)
1443 if (traitIDs[i] == traitID)
1444 return true;
1445 return false;
1446}
1447
1448//===----------------------------------------------------------------------===//
1449// Trait Folding
1450
1451/// Trait to check if T provides a 'foldTrait' method for single result
1452/// operations.
1453template <typename T, typename... Args>
1454using has_single_result_fold_trait = decltype(T::foldTrait(
1455 std::declval<Operation *>(), std::declval<ArrayRef<Attribute>>()));
1456template <typename T>
1457using detect_has_single_result_fold_trait =
1458 llvm::is_detected<has_single_result_fold_trait, T>;
1459/// Trait to check if T provides a general 'foldTrait' method.
1460template <typename T, typename... Args>
1461using has_fold_trait =
1462 decltype(T::foldTrait(std::declval<Operation *>(),
1463 std::declval<ArrayRef<Attribute>>(),
1464 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1465template <typename T>
1466using detect_has_fold_trait = llvm::is_detected<has_fold_trait, T>;
1467/// Trait to check if T provides any `foldTrait` method.
1468/// NOTE: This should use std::disjunction when C++17 is available.
1469template <typename T>
1470using detect_has_any_fold_trait =
1471 std::conditional_t<bool(detect_has_fold_trait<T>::value),
1472 detect_has_fold_trait<T>,
1473 detect_has_single_result_fold_trait<T>>;
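//
// The two detected shapes, sketched on hypothetical traits (illustrative
// names; real foldable traits are defined elsewhere in MLIR):
//
//   template <typename ConcreteType>
//   struct SingleResultFoldable
//       : public OpTrait::TraitBase<ConcreteType, SingleResultFoldable> {
//     // Matches `detect_has_single_result_fold_trait`.
//     static OpFoldResult foldTrait(Operation *op, ArrayRef<Attribute> operands);
//   };
//
//   template <typename ConcreteType>
//   struct GeneralFoldable
//       : public OpTrait::TraitBase<ConcreteType, GeneralFoldable> {
//     // Matches `detect_has_fold_trait`.
//     static LogicalResult foldTrait(Operation *op, ArrayRef<Attribute> operands,
//                                    SmallVectorImpl<OpFoldResult> &results);
//   };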
1474
1475/// Returns the result of folding a trait that implements a `foldTrait` function
1476/// that is specialized for operations that have a single result.
1477template <typename Trait>
1478static std::enable_if_t<detect_has_single_result_fold_trait<Trait>::value,
1479 LogicalResult>
1480foldTrait(Operation *op, ArrayRef<Attribute> operands,
1481 SmallVectorImpl<OpFoldResult> &results) {
1482 assert(op->hasTrait<OpTrait::OneResult>() &&
1483 "expected trait on non single-result operation to implement the "
1484 "general `foldTrait` method");
1485 // If a previous trait has already folded and replaced this operation, we
1486 // fail to fold this trait.
1487 if (!results.empty())
1488 return failure();
1489
1490 if (OpFoldResult result = Trait::foldTrait(op, operands)) {
1491 if (result.template dyn_cast<Value>() != op->getResult(0))
1492 results.push_back(result);
1493 return success();
1494 }
1495 return failure();
1496}
1497/// Returns the result of folding a trait that implements a generalized
1498 /// `foldTrait` function that supports any operation type.
1499template <typename Trait>
1500static std::enable_if_t<detect_has_fold_trait<Trait>::value, LogicalResult>
1501foldTrait(Operation *op, ArrayRef<Attribute> operands,
1502 SmallVectorImpl<OpFoldResult> &results) {
1503 // If a previous trait has already folded and replaced this operation, we
1504 // fail to fold this trait.
1505 return results.empty() ? Trait::foldTrait(op, operands, results) : failure();
1506}
1507
1508/// The internal implementation of `foldTraits` below that returns the result of
1509/// folding a set of trait types `Ts` that implement a `foldTrait` method.
1510template <typename... Ts>
1511static LogicalResult foldTraitsImpl(Operation *op, ArrayRef<Attribute> operands,
1512 SmallVectorImpl<OpFoldResult> &results,
1513 std::tuple<Ts...> *) {
1514 bool anyFolded = false;
1515 (void)std::initializer_list<int>{
1516 (anyFolded |= succeeded(foldTrait<Ts>(op, operands, results)), 0)...};
1517 return success(anyFolded);
1518}
1519
1520/// Given a tuple type containing a set of traits that contain a `foldTrait`
1521/// method, return the result of folding the given operation.
1522template <typename TraitTupleT>
1523static std::enable_if_t<std::tuple_size<TraitTupleT>::value != 0, LogicalResult>
1524foldTraits(Operation *op, ArrayRef<Attribute> operands,
1525 SmallVectorImpl<OpFoldResult> &results) {
1526 return foldTraitsImpl(op, operands, results, (TraitTupleT *)nullptr);
1527}
1528/// A variant of the method above that is specialized when there are no traits
1529/// that contain a `foldTrait` method.
1530template <typename TraitTupleT>
1531static std::enable_if_t<std::tuple_size<TraitTupleT>::value == 0, LogicalResult>
1532foldTraits(Operation *op, ArrayRef<Attribute> operands,
1533 SmallVectorImpl<OpFoldResult> &results) {
1534 return failure();
1535}
1536
1537//===----------------------------------------------------------------------===//
1538// Trait Verification
1539
1540/// Trait to check if T provides a `verifyTrait` method.
1541template <typename T, typename... Args>
1542using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
1543template <typename T>
1544using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
1545
1546/// The internal implementation of `verifyTraits` below that returns the result
1547/// of verifying the current operation with all of the provided trait types
1548/// `Ts`.
1549template <typename... Ts>
1550static LogicalResult verifyTraitsImpl(Operation *op, std::tuple<Ts...> *) {
1551 LogicalResult result = success();
1552 (void)std::initializer_list<int>{
1553 (result = succeeded(result) ? Ts::verifyTrait(op) : failure(), 0)...};
1554 return result;
1555}
1556
1557/// Given a tuple type containing a set of traits that contain a
1558/// `verifyTrait` method, return the result of verifying the given operation.
1559template <typename TraitTupleT>
1560static LogicalResult verifyTraits(Operation *op) {
1561 return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
1562}
1563} // namespace op_definition_impl
1564
1565//===----------------------------------------------------------------------===//
1566// Operation Definition classes
1567//===----------------------------------------------------------------------===//
1568
1569/// This provides public APIs that all operations should have. The template
1570 /// argument 'ConcreteType' should be the concrete type, per CRTP, and the
1571 /// other arguments are trait base classes, per the policy pattern.
1572template <typename ConcreteType, template <typename T> class... Traits>
1573class Op : public OpState, public Traits<ConcreteType>... {
1574public:
1575 /// Inherit getOperation from `OpState`.
1576 using OpState::getOperation;
1577
1578 /// Return if this operation contains the provided trait.
1579 template <template <typename T> class Trait>
1580 static constexpr bool hasTrait() {
1581 return llvm::is_one_of<Trait<ConcreteType>, Traits<ConcreteType>...>::value;
1582 }
1583
1584 /// Create a deep copy of this operation.
1585 ConcreteType clone() { return cast<ConcreteType>(getOperation()->clone()); }
1586
1587 /// Create a partial copy of this operation without traversing into attached
1588 /// regions. The new operation will have the same number of regions as the
1589 /// original one, but they will be left empty.
1590 ConcreteType cloneWithoutRegions() {
1591 return cast<ConcreteType>(getOperation()->cloneWithoutRegions());
1592 }
1593
1594 /// Return true if this "op class" can match against the specified operation.
1595 static bool classof(Operation *op) {
1596 if (auto info = op->getRegisteredInfo())
1597 return TypeID::get<ConcreteType>() == info->getTypeID();
1598#ifndef NDEBUG
1599 if (op->getName().getStringRef() == ConcreteType::getOperationName())
1600 llvm::report_fatal_error(
1601 "classof on '" + ConcreteType::getOperationName() +
1602 "' failed due to the operation not being registered");
1603#endif
1604 return false;
1605 }
1606 /// Provide `classof` support for other OpBase derived classes, such as
1607 /// Interfaces.
1608 template <typename T>
1609 static std::enable_if_t<std::is_base_of<OpState, T>::value, bool>
1610 classof(const T *op) {
1611 return classof(const_cast<T *>(op)->getOperation());
1612 }
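//
// These `classof` overloads are what let the llvm::isa/cast/dyn_cast family
// work on op wrappers, e.g. (a sketch; `HypotheticalAddOp` as above):
//
//   if (auto addOp = llvm::dyn_cast<HypotheticalAddOp>(op))
//     process(addOp);
//
// Matching is by registered TypeID, so `classof` only succeeds for registered
// operations; in builds with assertions, a name match on an unregistered op
// is reported as a fatal error above.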
1613
1614 /// Expose the type we are instantiated on to template machinery that may want
1615 /// to introspect traits on this operation.
1616 using ConcreteOpType = ConcreteType;
1617
1618 /// This is a public constructor. Any op can be initialized to null.
1619 explicit Op() : OpState(nullptr) {}
1620 Op(std::nullptr_t) : OpState(nullptr) {}
1621
1622 /// This is a public constructor to enable access via the llvm::cast family of
1623 /// methods. This should not be used directly.
1624 explicit Op(Operation *state) : OpState(state) {}
1625
1626 /// Methods for supporting PointerLikeTypeTraits.
1627 const void *getAsOpaquePointer() const {
1628 return static_cast<const void *>((Operation *)*this);
1629 }
1630 static ConcreteOpType getFromOpaquePointer(const void *pointer) {
1631 return ConcreteOpType(
1632 reinterpret_cast<Operation *>(const_cast<void *>(pointer)));
1633 }
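//
// Round-trip sketch: the opaque pointer is simply the underlying
// `Operation *`, which is what the `DenseMapInfo` specialization at the end
// of this file relies on (`MyOp` hypothetical):
//
//   const void *raw = myOp.getAsOpaquePointer();
//   MyOp same = MyOp::getFromOpaquePointer(raw);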
1634
1635 /// Attach the given models as implementations of the corresponding interfaces
1636 /// for the concrete operation.
1637 template <typename... Models>
1638 static void attachInterface(MLIRContext &context) {
1639 Optional<RegisteredOperationName> info = RegisteredOperationName::lookup(
1640 ConcreteType::getOperationName(), &context);
1641 if (!info)
1642 llvm::report_fatal_error(
1643 "Attempting to attach an interface to an unregistered operation " +
1644 ConcreteType::getOperationName() + ".");
1645 info->attachInterface<Models...>();
1646 }
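//
// Usage sketch, assuming an ODS-generated interface (which provides an
// `ExternalModel` CRTP helper in this snapshot); `MyOpClass`, `MyInterface`,
// and `MyModel` are illustrative names:
//
//   void registerModels(MLIRContext &ctx) {
//     MyOpClass::attachInterface<
//         MyInterface::ExternalModel<MyModel, MyOpClass>>(ctx);
//   }
//
// The op must already be registered in `ctx`; otherwise this reports the
// fatal error above.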
1647
1648private:
1649 /// Trait to check if T provides a 'fold' method for a single result op.
1650 template <typename T, typename... Args>
1651 using has_single_result_fold =
1652 decltype(std::declval<T>().fold(std::declval<ArrayRef<Attribute>>()));
1653 template <typename T>
1654 using detect_has_single_result_fold =
1655 llvm::is_detected<has_single_result_fold, T>;
1656 /// Trait to check if T provides a general 'fold' method.
1657 template <typename T, typename... Args>
1658 using has_fold = decltype(std::declval<T>().fold(
1659 std::declval<ArrayRef<Attribute>>(),
1660 std::declval<SmallVectorImpl<OpFoldResult> &>()));
1661 template <typename T>
1662 using detect_has_fold = llvm::is_detected<has_fold, T>;
1663 /// Trait to check if T provides a 'print' method.
1664 template <typename T, typename... Args>
1665 using has_print =
1666 decltype(std::declval<T>().print(std::declval<OpAsmPrinter &>()));
1667 template <typename T>
1668 using detect_has_print = llvm::is_detected<has_print, T>;
1669 /// A tuple type containing the traits that have a `foldTrait` function.
1670 using FoldableTraitsTupleT = typename detail::FilterTypes<
1671 op_definition_impl::detect_has_any_fold_trait,
1672 Traits<ConcreteType>...>::type;
1673 /// A tuple type containing the traits that have a verify function.
1674 using VerifiableTraitsTupleT =
1675 typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
1676 Traits<ConcreteType>...>::type;
1677
1678 /// Returns an interface map containing the interfaces registered to this
1679 /// operation.
1680 static detail::InterfaceMap getInterfaceMap() {
1681 return detail::InterfaceMap::template get<Traits<ConcreteType>...>();
1682 }
1683
1684 /// Return the internal implementations of each of the OperationName
1685 /// hooks.
1686 /// Implementation of `FoldHookFn` OperationName hook.
1687 static OperationName::FoldHookFn getFoldHookFn() {
1688 return getFoldHookFnImpl<ConcreteType>();
1689 }
1690 /// The internal implementation of `getFoldHookFn` above that is invoked if
1691 /// the operation is single result and defines a `fold` method.
1692 template <typename ConcreteOpT>
1693 static std::enable_if_t<llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1694 Traits<ConcreteOpT>...>::value &&
1695 detect_has_single_result_fold<ConcreteOpT>::value,
1696 OperationName::FoldHookFn>
1697 getFoldHookFnImpl() {
1698 return [](Operation *op, ArrayRef<Attribute> operands,
1699 SmallVectorImpl<OpFoldResult> &results) {
1700 return foldSingleResultHook<ConcreteOpT>(op, operands, results);
1701 };
1702 }
1703 /// The internal implementation of `getFoldHookFn` above that is invoked if
1704 /// the operation is not single result and defines a `fold` method.
1705 template <typename ConcreteOpT>
1706 static std::enable_if_t<!llvm::is_one_of<OpTrait::OneResult<ConcreteOpT>,
1707 Traits<ConcreteOpT>...>::value &&
1708 detect_has_fold<ConcreteOpT>::value,
1709 OperationName::FoldHookFn>
1710 getFoldHookFnImpl() {
1711 return [](Operation *op, ArrayRef<Attribute> operands,
1712 SmallVectorImpl<OpFoldResult> &results) {
1713 return foldHook<ConcreteOpT>(op, operands, results);
1714 };
1715 }
1716 /// The internal implementation of `getFoldHookFn` above that is invoked if
1717 /// the operation does not define a `fold` method.
1718 template <typename ConcreteOpT>
1719 static std::enable_if_t<!detect_has_single_result_fold<ConcreteOpT>::value &&
1720 !detect_has_fold<ConcreteOpT>::value,
1721 OperationName::FoldHookFn>
1722 getFoldHookFnImpl() {
1723 return [](Operation *op, ArrayRef<Attribute> operands,
1724 SmallVectorImpl<OpFoldResult> &results) {
1725 // In this case, we only need to fold the traits of the operation.
1726 return op_definition_impl::foldTraits<FoldableTraitsTupleT>(op, operands,
1727 results);
1728 };
1729 }
1730 /// Return the result of folding a single result operation that defines a
1731 /// `fold` method.
1732 template <typename ConcreteOpT>
1733 static LogicalResult
1734 foldSingleResultHook(Operation *op, ArrayRef<Attribute> operands,
1735 SmallVectorImpl<OpFoldResult> &results) {
1736 OpFoldResult result = cast<ConcreteOpT>(op).fold(operands);
1737
1738 // If the fold failed or was in-place, try to fold the traits of the
1739 // operation.
1740 if (!result || result.template dyn_cast<Value>() == op->getResult(0)) {
1741 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1742 op, operands, results)))
1743 return success();
1744 return success(static_cast<bool>(result));
1745 }
1746 results.push_back(result);
1747 return success();
1748 }
1749 /// Return the result of folding an operation that defines a `fold` method.
1750 template <typename ConcreteOpT>
1751 static LogicalResult foldHook(Operation *op, ArrayRef<Attribute> operands,
1752 SmallVectorImpl<OpFoldResult> &results) {
1753 LogicalResult result = cast<ConcreteOpT>(op).fold(operands, results);
1754
1755 // If the fold failed or was in-place, try to fold the traits of the
1756 // operation.
1757 if (failed(result) || results.empty()) {
1758 if (succeeded(op_definition_impl::foldTraits<FoldableTraitsTupleT>(
1759 op, operands, results)))
1760 return success();
1761 }
1762 return result;
1763 }
1764
1765 /// Implementation of `GetCanonicalizationPatternsFn` OperationName hook.
1766 static OperationName::GetCanonicalizationPatternsFn
1767 getGetCanonicalizationPatternsFn() {
1768 return &ConcreteType::getCanonicalizationPatterns;
1769 }
1770 /// Implementation of `HasTraitFn` OperationName hook.
1771 static OperationName::HasTraitFn getHasTraitFn() {
1772 return
1773 [](TypeID id) { return op_definition_impl::hasTrait<Traits...>(id); };
1774 }
1775 /// Implementation of `ParseAssemblyFn` OperationName hook.
1776 static OperationName::ParseAssemblyFn getParseAssemblyFn() {
1777 return &ConcreteType::parse;
1778 }
1779 /// Implementation of `PrintAssemblyFn` OperationName hook.
1780 static OperationName::PrintAssemblyFn getPrintAssemblyFn() {
1781 return getPrintAssemblyFnImpl<ConcreteType>();
1782 }
1783 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1784 /// the concrete operation does not define a `print` method.
1785 template <typename ConcreteOpT>
1786 static std::enable_if_t<!detect_has_print<ConcreteOpT>::value,
1787 OperationName::PrintAssemblyFn>
1788 getPrintAssemblyFnImpl() {
1789 return [](Operation *op, OpAsmPrinter &printer, StringRef defaultDialect) {
1790 return OpState::print(op, printer, defaultDialect);
1791 };
1792 }
1793 /// The internal implementation of `getPrintAssemblyFn` that is invoked when
1794 /// the concrete operation defines a `print` method.
1795 template <typename ConcreteOpT>
1796 static std::enable_if_t<detect_has_print<ConcreteOpT>::value,
1797 OperationName::PrintAssemblyFn>
1798 getPrintAssemblyFnImpl() {
1799 return &printAssembly;
1800 }
1801 static void printAssembly(Operation *op, OpAsmPrinter &p,
1802 StringRef defaultDialect) {
1803 OpState::printOpName(op, p, defaultDialect);
1804 return cast<ConcreteType>(op).print(p);
1805 }
1806 /// Implementation of `VerifyInvariantsFn` OperationName hook.
1807 static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
1808 return &verifyInvariants;
1809 }
1810
1811 static constexpr bool hasNoDataMembers() {
1812 // Check that the derived class does not define any new data members by
1813 // comparing its size to that of an ad-hoc EmptyOp.
1814 class EmptyOp : public Op<EmptyOp, Traits...> {};
1815 return sizeof(ConcreteType) == sizeof(EmptyOp);
1816 }
1817
1818 static LogicalResult verifyInvariants(Operation *op) {
1819 static_assert(hasNoDataMembers(),
1820 "Op class shouldn't define new data members");
1821 return failure(
1822 failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
1823 failed(cast<ConcreteType>(op).verify()));
1824 }
1825
1826 /// Allow access to internal implementation methods.
1827 friend RegisteredOperationName;
1828};
1829
1830/// This class represents the base of an operation interface. See the definition
1831/// of `detail::Interface` for requirements on the `Traits` type.
1832template <typename ConcreteType, typename Traits>
1833class OpInterface
1834 : public detail::Interface<ConcreteType, Operation *, Traits,
1835 Op<ConcreteType>, OpTrait::TraitBase> {
1836public:
1837 using Base = OpInterface<ConcreteType, Traits>;
1838 using InterfaceBase = detail::Interface<ConcreteType, Operation *, Traits,
1839 Op<ConcreteType>, OpTrait::TraitBase>;
1840
1841 /// Inherit the base class constructor.
1842 using InterfaceBase::InterfaceBase;
1843
1844protected:
1845 /// Returns the impl interface instance for the given operation.
1846 static typename InterfaceBase::Concept *getInterfaceFor(Operation *op) {
1847 OperationName name = op->getName();
1848
1849 // Access the raw interface from the operation info.
1850 if (Optional<RegisteredOperationName> rInfo = name.getRegisteredInfo()) {
1851 if (auto *opIface = rInfo->getInterface<ConcreteType>())
1852 return opIface;
1853 // Fall back to the dialect to give it a chance to implement this
1854 // interface for this operation.
1855 return rInfo->getDialect().getRegisteredInterfaceForOp<ConcreteType>(
1856 op->getName());
1857 }
1858 // Fall back to the dialect to give it a chance to implement this
1859 // interface for this operation.
1860 if (Dialect *dialect = name.getDialect())
1861 return dialect->getRegisteredInterfaceForOp<ConcreteType>(name);
1862 return nullptr;
1863 }
1864
1865 /// Allow access to `getInterfaceFor`.
1866 friend InterfaceBase;
1867};
1868
1869//===----------------------------------------------------------------------===//
1870// Common Operation Folders/Parsers/Printers
1871//===----------------------------------------------------------------------===//
1872
1873// These functions are out-of-line implementations of the methods in UnaryOp and
1874 // BinaryOp, which avoids duplicating them in every template instantiation.
1875namespace impl {
1876ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
1877 OperationState &result);
1878
1879void buildBinaryOp(OpBuilder &builder, OperationState &result, Value lhs,
1880 Value rhs);
1881ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
1882 OperationState &result);
1883
1884 // Prints the given binary `op` in custom assembly form if the two operands
1885 // and the result all have the same type. Otherwise, prints the generic
1886 // assembly form.
1887void printOneResultOp(Operation *op, OpAsmPrinter &p);
1888} // namespace impl
1889
1890// These functions are out-of-line implementations of the methods in
1891 // CastOpInterface, which avoids duplicating them in every template instantiation.
1892namespace impl {
1893/// Attempt to fold the given cast operation.
1894LogicalResult foldCastInterfaceOp(Operation *op,
1895 ArrayRef<Attribute> attrOperands,
1896 SmallVectorImpl<OpFoldResult> &foldResults);
1897/// Attempt to verify the given cast operation.
1898LogicalResult verifyCastInterfaceOp(
1899 Operation *op, function_ref<bool(TypeRange, TypeRange)> areCastCompatible);
1900
1901// TODO: Remove the parse/print/build here (new ODS functionality obsoletes the
1902// need for them, but some older ODS code in `std` still depends on them).
1903void buildCastOp(OpBuilder &builder, OperationState &result, Value source,
1904 Type destType);
1905ParseResult parseCastOp(OpAsmParser &parser, OperationState &result);
1906void printCastOp(Operation *op, OpAsmPrinter &p);
1907// TODO: These methods are deprecated in favor of CastOpInterface. Remove them
1908// when all uses have been updated. Also, consider adding functionality to
1909// CastOpInterface to be able to perform the ChainedTensorCast canonicalization
1910// generically.
1911Value foldCastOp(Operation *op);
1912LogicalResult verifyCastOp(Operation *op,
1913 function_ref<bool(Type, Type)> areCastCompatible);
1914} // namespace impl
1915} // namespace mlir
1916
1917namespace llvm {
1918
1919template <typename T>
1920struct DenseMapInfo<
1921 T, std::enable_if_t<std::is_base_of<mlir::OpState, T>::value>> {
1922 static inline T getEmptyKey() {
1923 auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
1924 return T::getFromOpaquePointer(pointer);
1925 }
1926 static inline T getTombstoneKey() {
1927 auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
1928 return T::getFromOpaquePointer(pointer);
1929 }
1930 static unsigned getHashValue(T val) {
1931 return hash_value(val.getAsOpaquePointer());
1932 }
1933 static bool isEqual(T lhs, T rhs) { return lhs == rhs; }
1934};
1935
1936} // namespace llvm
1937
1938#endif